import os
import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from sentence_transformers import SentenceTransformer, models

from src.search import init_faiss

# Hugging Face repo IDs
DATASET_REPO = "param2004/Medilingua-dataset"
MODEL_REPO = "param2004/Medilingua-model"
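
# Illustrative check (not part of the original file): the files available in a
# repo can be listed before downloading, which helps debug wrong paths.
#
#   from huggingface_hub import list_repo_files
#   print(list_repo_files(MODEL_REPO))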

def load_model():
    """Load SapBERT dynamically from the Hugging Face Hub."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.info(f"🔬 Loading SapBERT from Hugging Face Hub on {device.upper()}...")
    try:
        # Pull the whole model directory from the Hub: models.Transformer needs
        # config.json and the tokenizer files alongside the weights, not just
        # the pytorch_model.bin file.
        local_dir = snapshot_download(
            repo_id=MODEL_REPO,
            allow_patterns="models/SapBERT-from-PubMedBERT-fulltext/*",
        )
        model_path = os.path.join(local_dir, "models/SapBERT-from-PubMedBERT-fulltext")

        # Build the SentenceTransformer manually: transformer encoder + mean pooling
        word_embedding_model = models.Transformer(model_path)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
        st.success("✅ SapBERT loaded successfully from Hugging Face Hub.")
    except Exception as e:
        st.error(f"❌ Failed to load SapBERT from Hub: {e}")
        st.warning("⚠️ Falling back to 'all-MiniLM-L6-v2'.")
        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    return model
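
# Usage sketch (added for illustration, not from the original file): in recent
# Streamlit versions the loader is typically wrapped in st.cache_resource so
# the model is loaded once and reused across script reruns.
@st.cache_resource
def get_model():
    return load_model()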

def load_data():
    """Load embeddings and dataset dynamically from the Hugging Face Hub."""
    try:
        # Download embeddings & CSV from the Hub
        question_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/question_embeddings.pkl")
        doctor_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/doctor_embeddings.pkl")
        dataset_csv_path = hf_hub_download(DATASET_REPO, filename="dataset/dataset.csv")
        # Each pickle stores a dict with the embedding matrix under 'embeddings'
        with open(question_emb_path, 'rb') as f:
            question_data = pickle.load(f)
        question_embeddings = question_data['embeddings'].astype('float32')

        with open(doctor_emb_path, 'rb') as f:
            doctor_data = pickle.load(f)
        doctor_embeddings = doctor_data['embeddings'].astype('float32')
        # Load the dataset CSV and drop incomplete or duplicated rows
        df = pd.read_csv(dataset_csv_path)
        df.dropna(subset=['Description', 'Patient', 'Doctor'], inplace=True)
        df.drop_duplicates(inplace=True)

        # Truncate everything to a common length; this assumes the embeddings
        # were generated from the CSV rows in this same cleaned order
        num_samples = min(len(df), len(question_embeddings), len(doctor_embeddings))
        df = df.iloc[:num_samples]
        question_embeddings = question_embeddings[:num_samples]
        doctor_embeddings = doctor_embeddings[:num_samples]
        st.success(f"✅ Loaded {num_samples} rows with SapBERT embeddings ({question_embeddings.shape[1]} dims)")

        # Build the FAISS index over the question embeddings
        init_faiss(question_embeddings)
        return {
            "question_embeddings": question_embeddings,
            "doctor_embeddings": doctor_embeddings,
            "description_column": df["Description"].tolist(),
            "patient_column": df["Patient"].tolist(),
            "original_answers": df["Doctor"].tolist(),
        }

    except Exception as e:
        st.error(f"❌ Error loading dataset or embeddings: {e}")
        return None
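
# For reference, a minimal sketch of what src.search.init_faiss is assumed to
# do (its implementation is not shown in this file): build an exact FAISS
# index over the question embeddings for nearest-neighbour lookup.
#
#   import faiss
#
#   _index = None
#
#   def init_faiss(embeddings):
#       global _index
#       _index = faiss.IndexFlatL2(embeddings.shape[1])
#       _index.add(embeddings)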