# SLR_DataSet, Learner, and MetaTask are provided by the ML_SLRC module
from ML_SLRC import *

import os
import re
import gc
import json
import time
import random
import warnings
from collections import defaultdict
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchmetrics import functional as fn
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display, clear_output
# SEED = 2222
# gen_seed = torch.Generator().manual_seed(SEED)

# Random seed function
def random_seed(value):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    np.random.seed(value)
    random.seed(value)
# Tasks for the meta-learner
def create_batch_of_tasks(taskset, is_shuffle=True, batch_size=4):
    idxs = list(range(len(taskset)))
    if is_shuffle:
        random.shuffle(idxs)
    for i in range(0, len(idxs), batch_size):
        # Yield one batch of (optionally shuffled) tasks at a time
        yield [taskset[idxs[j]] for j in range(i, min(i + batch_size, len(taskset)))]
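# Usage sketch (illustrative): any indexable collection works as a taskset;
# a plain list stands in for a MetaTask object here.
# toy_tasks = ["task_a", "task_b", "task_c", "task_d", "task_e"]
# list(create_batch_of_tasks(toy_tasks, is_shuffle=False, batch_size=2))
# -> [['task_a', 'task_b'], ['task_c', 'task_d'], ['task_e']]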
# Prepare data to be processed by the Domain-learner
def prepare_data(data, batch_size, tokenizer, max_seq_length,
                 input='text', output='label',
                 train_size_per_class=5, global_datasets=False,
                 treat_text_fun=None):
    data = data.reset_index(drop=True)
    if global_datasets:
        # NB: a `global` declaration applies to the whole function body,
        # so these names are always bound at module level
        global data_train, data_test

    # Sample the training task: train_size_per_class rows per class
    data_train = data.groupby('label').sample(train_size_per_class, replace=False)
    # The test set to be labeled by the model is the whole domain
    data_test = data

    # Wrap in datasets the model can consume
    ## Train
    dataset_train = SLR_DataSet(
        data=data_train.sample(frac=1),
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)
    ## Test
    dataset_test = SLR_DataSet(
        data=data_test,
        input=input,
        output=output,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        treat_text=treat_text_fun)

    # Dataloaders
    ## Train
    data_train_loader = DataLoader(dataset_train,
                                   shuffle=True,
                                   batch_size=batch_size['train'])
    ## Test: drop the last batch only when it would contain a single
    ## example, which breaks the .squeeze() calls downstream
    drop_last = (len(dataset_test) % batch_size['test'] == 1)
    data_test_loader = DataLoader(dataset_test,
                                  batch_size=batch_size['test'],
                                  drop_last=drop_last)

    return data_train_loader, data_test_loader, data_train, data_test
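# Usage sketch (illustrative). `domain_df` is a hypothetical DataFrame with
# 'text' and 'label' columns, and `tokenizer` a Hugging Face tokenizer; the
# SLR_DataSet signature is assumed from ML_SLRC.
# train_loader, test_loader, df_train, df_test = prepare_data(
#     domain_df,
#     batch_size={'train': 4, 'test': 100},
#     tokenizer=tokenizer,
#     max_seq_length=512,
#     train_size_per_class=10)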
# Meta trainer
def meta_train(data, model, device, Info,
               print_epoch=True,
               Test_resource=None,
               treat_text_fun=None):
    # Meta-learner model
    learner = Learner(model=model, device=device, **Info)

    # Testing tasks
    if isinstance(Test_resource, pd.DataFrame):
        test = MetaTask(Test_resource, num_task=0, k_support=10, k_query=10,
                        training=False, treat_text=treat_text_fun, **Info)

    torch.clear_autocast_cache()
    gc.collect()
    torch.cuda.empty_cache()

    # Meta epoch (outer epoch)
    for epoch in tqdm(range(Info['meta_epoch']), desc="Meta epoch ", ncols=80):
        # Train tasks: k_spt support and k_qry query examples per task
        train = MetaTask(data,
                         num_task=Info['num_task_train'],
                         k_support=Info['k_spt'],
                         k_query=Info['k_qry'],
                         treat_text=treat_text_fun, **Info)

        # Batch of train tasks
        db = create_batch_of_tasks(train, is_shuffle=True,
                                   batch_size=Info["outer_batch_size"])

        if print_epoch:
            # Outer-loop batch training
            for step, task_batch in enumerate(db):
                print("\n-----------------Training Mode", "Meta_epoch:", epoch, "-----------------\n")
                # Meta-feedforward (outer feedforward)
                acc = learner(task_batch, valid_train=print_epoch)
                print('Step:', step, '\ttraining Acc:', acc)

                if isinstance(Test_resource, pd.DataFrame):
                    # Validate on the first step of every fourth meta-epoch
                    if ((epoch + 1) % 4) + step == 0:
                        random_seed(123)
                        print("\n-----------------Testing Mode-----------------\n")
                        # Batch of test tasks
                        db_test = create_batch_of_tasks(test, is_shuffle=False, batch_size=1)
                        acc_all_test = []
                        # Loop over the testing tasks
                        for test_batch in db_test:
                            acc = learner(test_batch, training=False)
                            acc_all_test.append(acc)
                        print('Test acc:', np.mean(acc_all_test))
                        del acc_all_test, db_test
                        # Re-seed training pseudo-randomly
                        random_seed(int(time.time() % 10))
        else:
            for step, task_batch in enumerate(db):
                # Meta-feedforward (outer feedforward)
                acc = learner(task_batch, valid_train=print_epoch)

        torch.clear_autocast_cache()
        gc.collect()
        torch.cuda.empty_cache()
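# Usage sketch (illustrative). Learner and MetaTask come from ML_SLRC, so
# only the keys meta_train itself reads are listed; the rest of Info is
# forwarded to those classes as-is.
# Info = {'meta_epoch': 10, 'num_task_train': 8, 'k_spt': 10, 'k_qry': 10,
#         'outer_batch_size': 4}   # plus Learner/MetaTask kwargs
# meta_train(train_df, model, device='cuda', Info=Info,
#            print_epoch=True, Test_resource=valid_df)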
# Domain-learner training loop
def train_loop(data_train_loader, data_test_loader, model, device,
               epoch=4, lr=1, print_info=True, name='name', weight_decay=1):
    # Start from a fresh copy of the model's parameters
    model_meta = deepcopy(model)
    optimizer = Adam(model_meta.parameters(), lr=lr, weight_decay=weight_decay)

    model_meta.to(device)
    model_meta.train()

    # Task epoch (inner epoch)
    for i in range(epoch):
        all_loss = []
        # Inner training batch (support set)
        for inner_step, batch in enumerate(data_train_loader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, q_token_type_ids, label_id = batch
            # Inner feedforward
            loss, _, _ = model_meta(input_ids, attention_mask, q_token_type_ids,
                                    labels=label_id.squeeze())
            # Compute gradients and update the parameters
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            all_loss.append(loss.item())

        if (i % 2 == 0) and print_info:
            print("Loss: ", np.mean(all_loss))

    # Test evaluation (accuracy and the other metrics are computed later, in plot())
    model_meta.eval()
    features = []
    labels = []
    predi_logit = []

    with torch.no_grad():
        # Test batch loop
        for inner_step, batch in enumerate(tqdm(data_test_loader,
                                                desc="Test validation | " + name,
                                                ncols=80)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, q_token_type_ids, label_id = batch

            # Predictions: `feature` holds (latent features, logits)
            _, feature, _ = model_meta(input_ids, attention_mask, q_token_type_ids,
                                       labels=label_id.squeeze())

            logit = feature[1].detach().cpu()
            feature_lat = feature[0].detach().cpu()
            label_id = label_id.detach().cpu()

            labels.append(label_id.numpy().squeeze())
            features.append(feature_lat.numpy())
            predi_logit.append(logit.numpy())

            del input_ids, attention_mask, label_id, batch

    model_meta.to('cpu')
    gc.collect()
    torch.cuda.empty_cache()
    del model_meta, optimizer

    # Project the feature map with t-SNE and return
    # (logits, X_embedded, labels, features), the 4-tuple that the callers
    # (diagnosis and pipeline_simulation) unpack
    return map_feature_tsne(features, labels, predi_logit)
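# Usage sketch (illustrative), chaining the helpers above with hypothetical
# hyperparameters:
# train_loader, test_loader, _, _ = prepare_data(domain_df, ...)
# logits, X_embedded, labels, features = train_loop(
#     train_loader, test_loader, model, device='cuda',
#     epoch=4, lr=2e-5, name='my_domain')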
# Process the predictions and project the feature map with t-SNE
def map_feature_tsne(features, labels, predi_logit):
    features = np.concatenate(np.array(features, dtype=object))
    features = torch.tensor(features.astype(np.float32)).detach().clone()

    labels = np.concatenate(np.array(labels, dtype=object))
    labels = torch.tensor(labels.astype(int)).detach().clone()

    logits = np.concatenate(np.array(predi_logit, dtype=object))
    logits = torch.tensor(logits.astype(np.float32)).detach().clone()

    # Dimension reduction to 2-D
    X_embedded = TSNE(n_components=2, learning_rate='auto',
                      init='random').fit_transform(features.detach().clone())

    return logits.detach().clone(), X_embedded, labels.detach().clone(), features.detach().clone()
# Work Saved over Sampling (WSS) at a given threshold
def wss_calc(logit, labels, trsh=0.5):
    # Predicted label given the threshold
    predict_thresh = torch.sigmoid(logit).squeeze() >= trsh

    # Confusion-matrix values (`fne` = false negatives, named to avoid
    # clashing with the torchmetrics alias `fn`)
    CM = confusion_matrix(labels, predict_thresh.to(int))
    tn, fp, fne, tp = CM.ravel()
    P = tp + fne
    N = tn + fp
    recall = tp / (tp + fne)

    # WSS: work saved over random sampling at this recall
    wss = (tn + fne) / len(labels) - (1 - recall)
    # Adjusted WSS
    awss = tn / N - fne / P

    return {
        "wss": round(wss, 4),
        "awss": round(awss, 4),
        "R": round(recall, 4),
        "CM": CM
    }
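# Worked example (illustrative, synthetic data): six documents with a clean
# separation at threshold 0.5, so recall is 1.0 and WSS reduces to the
# excluded fraction: (tn + fne)/n - (1 - R) = 4/6 - 0 ≈ 0.6667.
# example_logits = torch.tensor([2.0, 1.5, -1.0, -2.0, -1.5, -0.5])
# example_labels = torch.tensor([1, 1, 0, 0, 0, 0])
# wss_calc(example_logits, example_labels, trsh=0.5)
# -> {'wss': 0.6667, 'awss': 1.0, 'R': 1.0, 'CM': array([[4, 0], [0, 2]])}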
# Compute the metrics and plot the predictions, WSS cuts, and ROC curve
def plot(logits, X_embedded, labels, threshold, show=True,
         namefig="plot", make_plot=True, print_stats=True, save=True):
    col = pd.MultiIndex.from_tuples([
        ("Predict", "0"),
        ("Predict", "1")
    ])
    index = pd.MultiIndex.from_tuples([
        ("Real", "0"),
        ("Real", "1")
    ])

    predict = torch.sigmoid(logits).detach().clone()

    # ROC curve
    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())

    # Metrics at a recall of 95% (threshold evaluation)
    ## WSS@95
    ### Index of the first point with recall >= 0.95
    idx_wss95 = sum(tpr < 0.95)
    ### Corresponding threshold
    thresholds95 = thresholds[idx_wss95]
    ### Compute the metrics
    wss95_info = wss_calc(logits, labels, thresholds95)
    acc_wss95 = fn.accuracy(predict, labels, threshold=thresholds95)
    f1_wss95 = fn.f1_score(predict, labels, threshold=thresholds95)

    # Metrics at the given threshold (recall evaluation)
    wss_info = wss_calc(logits, labels, threshold)
    acc_wssR = fn.accuracy(predict, labels, threshold=threshold)
    f1_wssR = fn.f1_score(predict, labels, threshold=threshold)

    metrics = {
        # WSS
        "WSS@95": wss95_info['wss'],
        "AWSS@95": wss95_info['awss'],
        "WSS@R": wss_info['wss'],
        "AWSS@R": wss_info['awss'],
        # Recall
        "Recall_WSS@95": wss95_info['R'],
        "Recall_WSS@R": wss_info['R'],
        # Accuracy
        "acc@95": acc_wss95.item(),
        "acc@R": acc_wssR.item(),
        # F1
        "f1@95": f1_wss95.item(),
        "f1@R": f1_wssR.item(),
        # Threshold at 95% recall
        "threshold@95": thresholds95
    }

    # Print stats
    if print_stats:
        print(f"WSS@95:{wss95_info['wss']}, R: {wss95_info['R']}")
        print(f"AWSS@95:{wss95_info['awss']}")
        print('Acc.:', round(acc_wss95.item(), 4))
        print('F1-score:', round(f1_wss95.item(), 4))
        print(f"threshold to wss95: {round(thresholds95, 4)}")

        cm = pd.DataFrame(wss95_info['CM'],
                          index=index,
                          columns=col)
        print("\nConfusion matrix:")
        print(cm)

        print("\n---Metrics with threshold:", threshold, "----\n")
        print(f"WSS@R:{wss_info['wss']}, R: {wss_info['R']}")
        print(f"AWSS@R:{wss_info['awss']}")
        print('Acc.:', round(acc_wssR.item(), 4))
        print('F1-score:', round(f1_wssR.item(), 4))

        cm = pd.DataFrame(wss_info['CM'],
                          index=index,
                          columns=col)
        print("\nConfusion matrix:")
        print(cm)

    # Plots
    if make_plot:
        fig, axes = plt.subplots(1, 4, figsize=(25, 10))
        alpha = torch.squeeze(predict).numpy()

        # t-SNE of the predictions
        sns.scatterplot(x=X_embedded[:, 0],
                        y=X_embedded[:, 1],
                        hue=labels,
                        alpha=alpha, ax=axes[0]).set_title('Predictions-TSNE', size=20)

        # WSS@95
        t_wss = (predict >= thresholds95).squeeze().numpy()
        sns.scatterplot(x=X_embedded[t_wss, 0],
                        y=X_embedded[t_wss, 1],
                        hue=labels[t_wss],
                        alpha=alpha[t_wss], ax=axes[1]).set_title('WSS@95', size=20)

        # WSS@R
        t = (predict >= threshold).squeeze().numpy()
        sns.scatterplot(x=X_embedded[t, 0],
                        y=X_embedded[t, 1],
                        hue=labels[t],
                        alpha=alpha[t], ax=axes[2]).set_title(f'Predictions-threshold {threshold}', size=20)

        # ROC curve
        roc_auc = auc(fpr, tpr)
        lw = 2
        axes[3].plot(
            fpr,
            tpr,
            color="darkorange",
            lw=lw,
            label="ROC curve (area = %0.2f)" % roc_auc)
        axes[3].plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
        axes[3].axhline(y=0.95, color='r', linestyle='-')
        axes[3].legend(loc="lower right")
        axes[3].set_title(label="ROC", size=20)
        axes[3].set_ylabel("True Positive Rate", fontsize=15)
        axes[3].set_xlabel("False Positive Rate", fontsize=15)

        if show:
            plt.show()
        if save:
            fig.savefig(namefig, dpi=fig.dpi)

    return metrics
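# Usage sketch (illustrative): compute the metrics dict without drawing or
# saving a figure.
# metrics = plot(logits, X_embedded, labels, threshold=0.9,
#                show=False, make_plot=False, print_stats=True, save=False)
# print(metrics['WSS@95'], metrics['threshold@95'])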
# Overlayable ROC curve for a single result
def auc_plot(logits, labels, color="darkorange", label="test"):
    predict = torch.sigmoid(logits).detach().clone()
    fpr, tpr, thresholds = roc_curve(labels, predict.squeeze())
    roc_auc = auc(fpr, tpr)
    lw = 2
    label = label + str(round(roc_auc, 2))
    plt.plot(
        fpr,
        tpr,
        color=color,
        lw=lw,
        label=label
    )
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.axhline(y=0.95, color='r', linestyle='-')
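# Usage sketch (illustrative): overlay two ROC curves on the same axes.
# auc_plot(logits_a, labels_a, color="darkorange", label="model A, AUC: ")
# auc_plot(logits_b, labels_b, color="seagreen", label="model B, AUC: ")
# plt.legend(loc="lower right")
# plt.show()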
# Interactive interface for evaluation
class diagnosis():
    def __init__(self, names, Valid_resource, batch_size_test,
                 model, Info, device, treat_text_fun=None, start=0):
        self.names = names
        self.Valid_resource = Valid_resource
        self.batch_size_test = batch_size_test
        self.model = model
        self.start = start
        self.Info = Info
        self.device = device
        self.treat_text_fun = treat_text_fun

        # Input boxes
        self.value_trash = widgets.FloatText(
            value=0.95,
            description='threshold',
            disabled=False
        )
        self.valueb = widgets.IntText(
            value=10,
            description='size',
            disabled=False
        )

        # Buttons
        self.train_b = widgets.Button(description="Train")
        self.next_b = widgets.Button(description="Next")
        self.eval_b = widgets.Button(description="Evaluation")
        self.hbox = widgets.HBox([self.train_b, self.valueb])

        # Button callbacks
        self.next_b.on_click(self.Next_button)
        self.train_b.on_click(self.Train_button)
        self.eval_b.on_click(self.Evaluation_button)

    # Next button: advance to the next domain
    def Next_button(self, p):
        clear_output()
        self.i = self.i + 1

        # Select the domain data
        self.domain = self.names[self.i]
        self.data = self.Valid_resource[self.Valid_resource['domain'] == self.domain]

        print("Name:", self.domain)
        print(self.data['label'].value_counts())

        display(self.hbox)
        display(self.next_b)

    # Train button: fit the domain-learner on the selected domain
    def Train_button(self, y):
        clear_output()
        print(self.domain)

        # Prepare the data for training (domain-learner)
        self.data_train_loader, self.data_test_loader, self.data_train, self.data_test = prepare_data(
            self.data,
            train_size_per_class=self.valueb.value,
            batch_size={'train': self.Info['inner_batch_size'],
                        'test': self.batch_size_test},
            max_seq_length=self.Info['max_seq_length'],
            tokenizer=self.Info['tokenizer'],
            input="text",
            output="label",
            treat_text_fun=self.treat_text_fun)

        # Train the model and predict on the test set
        self.logits, self.X_embedded, self.labels, self.features = train_loop(
            self.data_train_loader, self.data_test_loader,
            self.model, self.device,
            epoch=self.Info['inner_update_step'],
            lr=self.Info['inner_update_lr'],
            print_info=True,
            name=self.domain)

        tresh_box = widgets.HBox([self.eval_b, self.value_trash])
        display(self.hbox)
        display(tresh_box)
        display(self.next_b)

    # Evaluation button: report metrics for the trained domain-learner
    def Evaluation_button(self, te):
        clear_output()
        tresh_box = widgets.HBox([self.eval_b, self.value_trash])
        print(self.domain)

        print("-------Train data-------")
        print(self.data_train['label'].value_counts())
        print("-------Test data-------")
        print(self.data_test['label'].value_counts())

        display(self.next_b)
        display(tresh_box)
        display(self.hbox)

        # Compute the metrics
        metrics = plot(self.logits, self.X_embedded, self.labels,
                       threshold=self.Info['threshold'], show=True,
                       namefig='test',
                       make_plot=True,
                       print_stats=True,
                       save=False)

    def __call__(self):
        self.i = self.start - 1
        clear_output()
        display(self.next_b)
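# Usage sketch (illustrative): drive the widget UI from a notebook cell.
# `Valid_resource` needs 'text', 'label' and 'domain' columns; Info carries
# the keys the buttons read (inner_batch_size, max_seq_length, tokenizer,
# inner_update_step, inner_update_lr, threshold).
# ui = diagnosis(names=domain_names, Valid_resource=Valid_resource,
#                batch_size_test=100, model=model, Info=Info, device='cuda')
# ui()   # shows "Next"; then use Train / Evaluation per domain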
# Simulation attempts of the domain-learner
def pipeline_simulation(Valid_resource, names_to_valid, path_save,
                        model, Info, device, initializer_model,
                        treat_text_fun=None):
    n_attempt = 5
    batch_test = 100

    # Create a directory per domain to save its information
    for name in names_to_valid:
        name = re.sub(r"\.csv", "", name)
        Path(path_save + name + "/img").mkdir(parents=True, exist_ok=True)

    # Dict to save the ROC curves
    roc_stats = defaultdict(lambda: defaultdict(
        lambda: defaultdict(
            list
        )
    )
    )
    all_metrics = []

    # Loop over the list of domains
    for name in names_to_valid:
        # Select a domain dataset
        data = Valid_resource[Valid_resource['domain'] == name].reset_index(drop=True)

        # Attempt simulation
        for attempt in range(n_attempt):
            print("---" * 4, "attempt", attempt, "---" * 4)

            # Prepare the data to pass to the model
            data_train_loader, data_test_loader, _, _ = prepare_data(
                data,
                train_size_per_class=Info['k_spt'],
                batch_size={'train': Info['inner_batch_size'],
                            'test': batch_test},
                max_seq_length=Info['max_seq_length'],
                tokenizer=Info['tokenizer'],
                input="text",
                output="label",
                treat_text_fun=treat_text_fun)

            # Train the model and evaluate on the domain's test set
            logits, X_embedded, labels, features = train_loop(
                data_train_loader, data_test_loader,
                model, device,
                epoch=Info['inner_update_step'],
                lr=Info['inner_update_lr'],
                print_info=False,
                name=name)

            name_domain = re.sub(r"\.csv", "", name)

            # Compute the metrics
            metrics = plot(logits, X_embedded, labels,
                           threshold=Info['threshold'], show=False,
                           namefig=path_save + name_domain + "/img/" + str(attempt) + 'plots',
                           make_plot=True, print_stats=False, save=True)

            # Compute the ROC curve
            fpr, tpr, _ = roc_curve(labels, torch.sigmoid(logits).squeeze())

            # Save the corresponding information for the domain
            metrics['name'] = name_domain
            metrics['layer_size'] = Info['bert_layers']
            metrics['attempt'] = attempt
            roc_stats[name_domain][str(Info['bert_layers'])]['fpr'].append(fpr.tolist())
            roc_stats[name_domain][str(Info['bert_layers'])]['tpr'].append(tpr.tolist())
            all_metrics.append(metrics)

            # Save the metrics and the ROC curves after each attempt
            pd.DataFrame(all_metrics).to_csv(path_save + "metrics.csv")
            roc_path = path_save + "roc_stats.json"
            with open(roc_path, 'w') as fp:
                json.dump(roc_stats, fp)

            del fpr, tpr, logits, X_embedded, labels
            del features, metrics, _

    # Save the configuration used to evaluate the validation resource
    save_info = Info.copy()
    save_info['model'] = initializer_model.tokenizer.name_or_path
    save_info.pop("tokenizer")
    save_info.pop("bert_layers")
    info_path = path_save + "info.json"
    with open(info_path, 'w') as fp:
        json.dump(save_info, fp)
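# Usage sketch (illustrative): five attempts per domain, writing figures to
# `<path_save>/<domain>/img/` plus metrics.csv and roc_stats.json.
# pipeline_simulation(Valid_resource, names_to_valid=['domain_a.csv'],
#                     path_save='results/', model=model, Info=Info,
#                     device='cuda', initializer_model=initializer_model)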
# Load dataset statistics
def load_data_statistics(paths, names):
    size = []
    pos = []
    neg = []
    for p in paths:
        data = pd.read_csv(p)
        data = data.dropna()
        # Dataset size
        size.append(len(data))
        # Number of positive labels
        pos.append(data['labels'].value_counts()[1])
        # Number of negative labels
        neg.append(data['labels'].value_counts()[0])
        del data

    info_load = pd.DataFrame({
        "size": size,
        "pos": pos,
        "neg": neg,
        "names": names,
        "paths": paths})
    return info_load
# Load the datasets
def load_data(train_info_load):
    col = ['abstract', 'title', 'labels', 'domain']
    data_train = pd.DataFrame(columns=col)

    for p in train_info_load['paths']:
        data_temp = pd.read_csv(p).loc[:, ['labels', 'title', 'abstract']]
        data_temp['domain'] = os.path.basename(p)
        data_train = pd.concat([data_train, data_temp])

    # Concatenate title and abstract into the model's input text
    data_train['text'] = data_train['title'] + data_train['abstract'].replace(np.nan, '')

    return (data_train
            .replace({"labels": {0: "negative", 1: 'positive'}})
            .rename({"labels": "label"}, axis=1)
            .loc[:, ("text", "domain", "label")])
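# Usage sketch (illustrative), with hypothetical CSV paths whose files carry
# 'labels', 'title' and 'abstract' columns:
# paths = ['data/domain_a.csv', 'data/domain_b.csv']
# names = [os.path.basename(p) for p in paths]
# stats = load_data_statistics(paths, names)
# train_df = load_data(stats)
# print(train_df['label'].value_counts())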