import os
import pickle

import pandas as pd
import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from sentence_transformers import SentenceTransformer, models

from src.search import init_faiss  # a commented sketch of this helper appears at the end of this section

# Hugging Face repo IDs
DATASET_REPO = "param2004/Medilingua-dataset"
MODEL_REPO = "param2004/Medilingua-model"


@st.cache_resource
def load_model():
    """Load SapBERT dynamically from the Hugging Face Hub, falling back to MiniLM on failure."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.info(f"🔬 Loading SapBERT from Hugging Face Hub on {device.upper()}...")

    try:
        # Download the full model folder (config, tokenizer and weights) from the Hub.
        # models.Transformer needs a model directory, not a single pytorch_model.bin file,
        # so snapshot_download is used here instead of hf_hub_download.
        snapshot_path = snapshot_download(
            repo_id=MODEL_REPO,
            allow_patterns=["models/SapBERT-from-PubMedBERT-fulltext/*"],
        )
        model_path = os.path.join(snapshot_path, "models", "SapBERT-from-PubMedBERT-fulltext")

        # Build the SentenceTransformer manually: transformer encoder + mean pooling
        word_embedding_model = models.Transformer(model_path)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
        st.success("✅ SapBERT loaded successfully from Hugging Face Hub.")
    except Exception as e:
        st.error(f"❌ Failed to load SapBERT from Hub: {e}")
        st.warning("⚠️ Falling back to 'all-MiniLM-L6-v2' model.")
        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    return model


@st.cache_resource
def load_data():
    """Load precomputed embeddings and the Q&A dataset from the Hugging Face Hub."""
    try:
        # Download embeddings & CSV from the Hub
        question_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/question_embeddings.pkl")
        doctor_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/doctor_embeddings.pkl")
        dataset_csv_path = hf_hub_download(DATASET_REPO, filename="dataset/dataset.csv")

        # Load pickled embeddings (stored as dicts with an 'embeddings' array)
        with open(question_emb_path, 'rb') as f:
            question_data = pickle.load(f)
        question_embeddings = question_data.get('embeddings').astype('float32')

        with open(doctor_emb_path, 'rb') as f:
            doctor_data = pickle.load(f)
        doctor_embeddings = doctor_data.get('embeddings').astype('float32')

        # Load and clean the dataset CSV
        df = pd.read_csv(dataset_csv_path)
        df.dropna(subset=['Description', 'Patient', 'Doctor'], inplace=True)
        df.drop_duplicates(inplace=True)

        # Truncate everything to a common length so rows and embeddings stay aligned
        num_samples = min(len(df), len(question_embeddings), len(doctor_embeddings))
        df = df.iloc[:num_samples]
        question_embeddings = question_embeddings[:num_samples]
        doctor_embeddings = doctor_embeddings[:num_samples]

        st.success(f"✅ Loaded {num_samples} rows with SapBERT embeddings ({question_embeddings.shape[1]} dims)")

        # Build the FAISS index over the question embeddings
        init_faiss(question_embeddings)

        return {
            "question_embeddings": question_embeddings,
            "doctor_embeddings": doctor_embeddings,
            "description_column": df["Description"].tolist(),
            "patient_column": df["Patient"].tolist(),
            "original_answers": df["Doctor"].tolist(),
        }
    except Exception as e:
        st.error(f"❌ Error loading dataset or embeddings: {e}")
        return None
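

# ---------------------------------------------------------------------------
# NOTE: `src/search.py` is not included in this section. The commented sketch
# below is one plausible (assumed) implementation of `init_faiss`, plus a
# matching `search` helper, using faiss-cpu over L2-normalized float32
# embeddings so inner product behaves like cosine similarity. Only `init_faiss`
# is known from the import above; every other name here is an assumption, and
# the code is kept commented out so it does not affect this module.
#
#   import faiss
#   import numpy as np
#
#   _index = None  # module-level FAISS index shared by the app
#
#   def init_faiss(embeddings: np.ndarray) -> None:
#       """Build an inner-product index over the question embeddings."""
#       global _index
#       vecs = np.ascontiguousarray(embeddings, dtype="float32")
#       faiss.normalize_L2(vecs)                  # in-place normalization
#       _index = faiss.IndexFlatIP(vecs.shape[1])
#       _index.add(vecs)
#
#   def search(query_embedding: np.ndarray, top_k: int = 5):
#       """Return (scores, row indices) of the top_k most similar questions."""
#       q = np.ascontiguousarray(query_embedding, dtype="float32").reshape(1, -1)
#       faiss.normalize_L2(q)
#       scores, idx = _index.search(q, top_k)
#       return scores[0], idx[0]
#
# Example retrieval flow using the functions above (assumed wiring):
#
#   model = load_model()
#   data = load_data()
#   q_emb = model.encode([user_question])[0]
#   scores, idx = search(q_emb, top_k=3)
#   answers = [data["original_answers"][i] for i in idx]
# ---------------------------------------------------------------------------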