# coCondensor-nq / dump.py
# (uploaded by Komorebi660, commit 0e016e5)
import collections
from torch.serialization import default_restore_location
from transformers import BertTokenizer, BertModel
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import pickle
import argparse
import csv
# Scratch dict for NQ data; appears unused anywhere in this file — TODO(review): confirm before removing.
nq_temp = {}
# Mirrors the DPR training-checkpoint layout: model weights plus
# optimizer/scheduler state and training-progress metadata.
CheckpointState = collections.namedtuple("CheckpointState",
                                         ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch',
                                          'encoder_params'])
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Deserialize a DPR training checkpoint, forcing every tensor onto CPU.

    NOTE(review): torch.load unpickles arbitrary objects — only load
    checkpoint files from a trusted source.
    """
    # Remap every storage to CPU so checkpoints saved on GPU load anywhere.
    def to_cpu(storage, location):
        return default_restore_location(storage, 'cpu')

    saved = torch.load(model_file, map_location=to_cpu)
    return CheckpointState(**saved)
class DocPool(Dataset):
    """Dataset over a two-column (id, text) TSV file.

    Only the text column is kept; __getitem__ yields (row_index, text),
    i.e. items are keyed by position in the file, not by the TSV id column.
    """

    def __init__(self, path):
        with open(path, "r", encoding="utf8") as tsv:
            reader = csv.reader(tsv, delimiter='\t')
            # Unpack each row as (id, text) and retain only the text.
            self.doc = [text for _row_id, text in reader]

    def __len__(self):
        return len(self.doc)

    def __getitem__(self, index):
        return index, self.doc[index]
def my_collate(batch):
    """Collate a list of (index, doc) pairs into a columnar batch.

    Returns {'id': tuple of indices, 'doc': tuple of documents}, preserving
    the order of the incoming samples.
    """
    columns = tuple(zip(*batch))
    return {'id': columns[0], 'doc': columns[1]}
def extract_feature(args) -> None:
    """Encode every document (or query) from an NQ TSV file with a DPR-style
    BERT encoder and pickle a {row_index: 768-d vector} dict to disk.

    Expects args to provide: batch_size, model_name, model_file, doc_or_query.
    """
    # Fixed seeds for reproducibility (CUDA seeding is a harmless no-op
    # when no GPU is present).
    torch.manual_seed(2024)
    torch.cuda.manual_seed(2024)
    np.random.seed(2024)
    # Select input TSV, output pickle, and which tower of the DPR bi-encoder
    # checkpoint to load: ctx_model.* for documents, question_model.* otherwise.
    if(args.doc_or_query == 'doc'):
        _path = './nq_doc.tsv'
        _out_path = './doc_embedding.pickle'
        _prefix = 'ctx_model.'
    else:
        _path = './nq_query.tsv'
        _out_path = './query_embedding.pickle'
        _prefix = 'question_model.'
    with torch.no_grad():
        doc_dataset = DocPool(_path)
        print(len(doc_dataset))
        # shuffle=False keeps file order, so the running `idx` below stays
        # aligned with dataset row indices.
        doc_dataloader = DataLoader(dataset=doc_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=my_collate)
        tokenizer = BertTokenizer.from_pretrained(args.model_name)
        model = BertModel.from_pretrained(args.model_name, return_dict=True)
        saved_state = load_states_from_checkpoint(args.model_file)
        # Strip the tower prefix (e.g. 'ctx_model.') so the remaining keys
        # match BertModel's own parameter names.
        prefix_len = len(_prefix)
        ctx_state = {key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if
                     key.startswith(_prefix)}
        # strict=False tolerates keys present on only one side — presumably
        # extra heads in the checkpoint; TODO(review): verify nothing vital is skipped.
        model.load_state_dict(ctx_state, strict=False)
        model = torch.nn.DataParallel(model)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        model = model.to(device)
        model.eval()
        ids = []
        idx = 0
        # Pre-sized output buffer: one 768-d row per document.
        doc_feature = np.zeros((len(doc_dataset), 768))
        for batch_data in tqdm(doc_dataloader):
            doc_id = batch_data['id']
            doc_body = batch_data['doc']
            inputs = tokenizer(doc_body, padding=True, truncation=True, return_tensors="pt",
                               add_special_tokens=True).to(device)
            outputs = model(**inputs)
            # Despite the variable name, this is the [CLS] token's last hidden
            # state (the standard DPR representation), not BertModel's pooler_output.
            pooler_output = outputs.last_hidden_state[:, 0]
            ids.extend(doc_id)
            # Write this batch's vectors into the next contiguous slice.
            doc_feature[idx: idx + pooler_output.shape[0]] = pooler_output.cpu().numpy()
            idx += pooler_output.shape[0]
        # Re-key vectors by dataset row index and persist.
        feature_dic = {}
        for i, id_i in enumerate(ids):
            feature_dic[id_i] = doc_feature[i]
        with open(_out_path, 'wb') as f:
            pickle.dump(feature_dic, f)
if __name__ == "__main__":
    # CLI entry point: dump NQ document or query embeddings to a pickle file.
    parser = argparse.ArgumentParser(
        description='Dump NQ document/query embeddings to a pickle file.')
    parser.add_argument('--batch_size', type=int, default=512,
                        help='minibatch size')
    parser.add_argument('--model_name', type=str, default='Luyu/co-condenser-wiki',
                        help='HuggingFace model name for tokenizer/encoder')
    # Fixed: help text wrongly said 'model name' (copy/paste from --model_name).
    parser.add_argument('--model_file', type=str, default='./dpr_biencoder.38.602',
                        help='path to the DPR bi-encoder checkpoint file')
    # choices rejects typos up front; previously any non-'doc' value silently
    # fell through to the query branch.
    parser.add_argument('--doc_or_query', type=str, default='query',
                        choices=['doc', 'query'],
                        help='transfer documents or queries')
    args = parser.parse_args()
    extract_feature(args)