import argparse
import collections
import csv
import pickle

import numpy as np
import torch
from torch.serialization import default_restore_location
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
nq_temp = {}

# Field layout of a DPR bi-encoder checkpoint file.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch',
     'encoder_params'])


def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Deserialize a checkpoint from *model_file* onto CPU as a CheckpointState.

    NOTE(review): torch.load unpickles — only use on trusted checkpoint files.
    """
    def _to_cpu(storage, location):
        # Force every tensor storage onto CPU regardless of where it was saved.
        return default_restore_location(storage, 'cpu')

    state = torch.load(model_file, map_location=_to_cpu)
    return CheckpointState(**state)
class DocPool(Dataset):
    """Dataset over the text column of a tab-separated (id, text) file."""

    def __init__(self, path):
        """Load every row of the TSV at *path*, keeping only the text field."""
        with open(path, "r", encoding="utf8") as handle:
            reader = csv.reader(handle, delimiter='\t')
            self.doc = [text for _, text in reader]

    def __len__(self):
        """Number of documents loaded."""
        return len(self.doc)

    def __getitem__(self, index):
        """Return the (index, document text) pair at *index*."""
        return index, self.doc[index]
def my_collate(batch):
    """Collate a list of (id, doc) pairs into {'id': ids, 'doc': docs} tuples."""
    ids, docs = zip(*batch)
    return {'id': ids, 'doc': docs}
def extract_feature(args):
    """Encode NQ documents or queries with a DPR-trained BERT encoder and
    pickle a {id: embedding} dict to disk.

    args must provide: batch_size, model_name, model_file, doc_or_query.
    Side effects: reads ./nq_{doc,query}.tsv, writes ./{doc,query}_embedding.pickle.
    """
    # Fixed seeds for reproducibility.
    torch.manual_seed(2024)
    torch.cuda.manual_seed(2024)
    np.random.seed(2024)
    if args.doc_or_query == 'doc':
        _path = './nq_doc.tsv'
        _out_path = './doc_embedding.pickle'
        _prefix = 'ctx_model.'        # DPR context-encoder weight prefix
    else:
        _path = './nq_query.tsv'
        _out_path = './query_embedding.pickle'
        _prefix = 'question_model.'   # DPR question-encoder weight prefix
    with torch.no_grad():
        doc_dataset = DocPool(_path)
        print(len(doc_dataset))
        doc_dataloader = DataLoader(dataset=doc_dataset, batch_size=args.batch_size,
                                    shuffle=False, collate_fn=my_collate)
        tokenizer = BertTokenizer.from_pretrained(args.model_name)
        model = BertModel.from_pretrained(args.model_name, return_dict=True)
        saved_state = load_states_from_checkpoint(args.model_file)
        # Strip the bi-encoder prefix so keys line up with a plain BertModel.
        prefix_len = len(_prefix)
        ctx_state = {key[prefix_len:]: value
                     for key, value in saved_state.model_dict.items()
                     if key.startswith(_prefix)}
        # strict=False tolerates head/extra keys, but report mismatches instead
        # of swallowing them silently.
        missing, unexpected = model.load_state_dict(ctx_state, strict=False)
        if missing or unexpected:
            print(f"load_state_dict: {len(missing)} missing, "
                  f"{len(unexpected)} unexpected keys")
        # Read the embedding width from the model instead of hard-coding 768,
        # so non-base BERT variants work too.
        hidden_size = model.config.hidden_size
        model = torch.nn.DataParallel(model)
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        model = model.to(device)
        model.eval()
        ids = []
        idx = 0
        doc_feature = np.zeros((len(doc_dataset), hidden_size))
        for batch_data in tqdm(doc_dataloader):
            doc_id = batch_data['id']
            doc_body = batch_data['doc']
            inputs = tokenizer(doc_body, padding=True, truncation=True,
                               return_tensors="pt", add_special_tokens=True).to(device)
            outputs = model(**inputs)
            # DPR's sentence representation: the [CLS] token of the last layer
            # (not BERT's tanh pooler output).
            cls_embedding = outputs.last_hidden_state[:, 0]
            ids.extend(doc_id)
            doc_feature[idx: idx + cls_embedding.shape[0]] = cls_embedding.cpu().numpy()
            idx += cls_embedding.shape[0]
        feature_dic = {item_id: doc_feature[i] for i, item_id in enumerate(ids)}
        with open(_out_path, 'wb') as f:
            pickle.dump(feature_dic, f)
if __name__ == "__main__":
    # CLI entry point: parse hyper-parameters and run the encoder.
    cli = argparse.ArgumentParser()
    cli.add_argument('--batch_size', type=int, default=512,
                     help='minibatch size')
    cli.add_argument('--model_name', type=str, default='Luyu/co-condenser-wiki',
                     help='model name')
    cli.add_argument('--model_file', type=str, default='./dpr_biencoder.38.602',
                     help='model name')
    cli.add_argument('--doc_or_query', type=str, default='query',
                     help='transfer documents or queries')
    extract_feature(cli.parse_args())