import argparse
import collections
import csv
import pickle

import numpy as np
import torch
from torch.serialization import default_restore_location
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertModel

# Matches the checkpoint layout produced by DPR's biencoder trainer.
CheckpointState = collections.namedtuple(
    "CheckpointState",
    ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch', 'encoder_params'])


def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    # Remap all tensors to CPU so the checkpoint loads regardless of the
    # devices it was saved from.
    state_dict = torch.load(
        model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    return CheckpointState(**state_dict)


class DocPool(Dataset):
    """Reads a two-column TSV (id, text) and serves the text column."""

    def __init__(self, path):
        doc = []
        with open(path, "r", encoding="utf8") as f:
            lines = csv.reader(f, delimiter='\t')
            for _id, _text in lines:
                doc.append(_text)
        self.doc = doc

    def __len__(self):
        return len(self.doc)

    def __getitem__(self, index):
        doc = self.doc[index]
        return index, doc


def my_collate(batch):
    # Transpose [(id, doc), ...] into ([ids...], [docs...]), keeping the raw
    # strings so the tokenizer can batch them later.
    ids, docs = zip(*batch)
    return {'id': ids, 'doc': docs}


def extract_feature(args):
    torch.manual_seed(2024)
    torch.cuda.manual_seed(2024)
    np.random.seed(2024)

    if args.doc_or_query == 'doc':
        _path = './nq_doc.tsv'
        _out_path = './doc_embedding.pickle'
        _prefix = 'ctx_model.'       # context (passage) encoder weights in the checkpoint
    else:
        _path = './nq_query.tsv'
        _out_path = './query_embedding.pickle'
        _prefix = 'question_model.'  # question encoder weights in the checkpoint

    with torch.no_grad():
        doc_dataset = DocPool(_path)
        print(len(doc_dataset))
        doc_dataloader = DataLoader(dataset=doc_dataset, batch_size=args.batch_size,
                                    shuffle=False, collate_fn=my_collate)

        tokenizer = BertTokenizer.from_pretrained(args.model_name)
        model = BertModel.from_pretrained(args.model_name, return_dict=True)

        # Strip the encoder prefix from the checkpoint keys so they line up
        # with the plain BertModel parameter names, then load them.
        saved_state = load_states_from_checkpoint(args.model_file)
        prefix_len = len(_prefix)
        ctx_state = {key[prefix_len:]: value
                     for (key, value) in saved_state.model_dict.items()
                     if key.startswith(_prefix)}
        model.load_state_dict(ctx_state, strict=False)

        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        model = torch.nn.DataParallel(model)
        model = model.to(device)
        model.eval()

        ids = []
        idx = 0
        doc_feature = np.zeros((len(doc_dataset), 768))  # BERT-base hidden size
        for batch_data in tqdm(doc_dataloader):
            doc_id = batch_data['id']
            doc_body = batch_data['doc']
            inputs = tokenizer(list(doc_body), padding=True, truncation=True,
                               return_tensors="pt", add_special_tokens=True).to(device)
            outputs = model(**inputs)
            # DPR-style representation: the [CLS] token of the last layer.
            cls_embedding = outputs.last_hidden_state[:, 0]
            ids.extend(doc_id)
            doc_feature[idx: idx + cls_embedding.shape[0]] = cls_embedding.cpu().numpy()
            idx += cls_embedding.shape[0]

        # Map each row index back to its embedding and persist the lookup table.
        feature_dic = {id_i: doc_feature[i] for i, id_i in enumerate(ids)}
        with open(_out_path, 'wb') as f:
            pickle.dump(feature_dic, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=512, help='minibatch size')
    parser.add_argument('--model_name', type=str, default='Luyu/co-condenser-wiki',
                        help='pretrained model name')
    parser.add_argument('--model_file', type=str, default='./dpr_biencoder.38.602',
                        help='path to the trained biencoder checkpoint')
    parser.add_argument('--doc_or_query', type=str, default='query',
                        help='encode documents or queries')
    args = parser.parse_args()
    extract_feature(args)
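
# Usage sketch (assumptions: both TSVs exist, the checkpoint path is valid,
# and this file is saved as extract_feature.py — a hypothetical name):
#
#   python extract_feature.py --doc_or_query doc   --model_file ./dpr_biencoder.38.602
#   python extract_feature.py --doc_or_query query --model_file ./dpr_biencoder.38.602
#
# Each pickle maps an id to a 768-d vector, so retrieval reduces to an inner
# product between query and document vectors (DPR's dot-product similarity).
# A minimal consumer, kept as a comment so it never runs on import:
#
#   import pickle
#   import numpy as np
#   with open('./query_embedding.pickle', 'rb') as f:
#       queries = pickle.load(f)
#   with open('./doc_embedding.pickle', 'rb') as f:
#       docs = pickle.load(f)
#   doc_ids = list(docs.keys())
#   doc_matrix = np.stack([docs[i] for i in doc_ids])      # (num_docs, 768)
#   q = next(iter(queries.values()))                       # one query vector
#   scores = doc_matrix @ q                                # dot-product relevance
#   top10 = [doc_ids[i] for i in np.argsort(-scores)[:10]] # best-scoring doc ids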