import os
import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from sentence_transformers import SentenceTransformer, models

from src.search import init_faiss

# Hugging Face repo IDs
DATASET_REPO = "param2004/Medilingua-dataset"
MODEL_REPO = "param2004/Medilingua-model"
@st.cache_resource
def load_model():
    """Load SapBERT dynamically from the Hugging Face Hub."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.info(f"🔬 Loading SapBERT from Hugging Face Hub on {device.upper()}...")
    try:
        # Download the whole model folder from the Hub. Downloading only
        # pytorch_model.bin is not enough: models.Transformer needs the
        # accompanying config/tokenizer files, which this assumes are stored
        # alongside the weights in the same repo folder.
        snapshot_dir = snapshot_download(
            repo_id=MODEL_REPO,
            allow_patterns="models/SapBERT-from-PubMedBERT-fulltext/*",
        )
        model_path = os.path.join(snapshot_dir, "models", "SapBERT-from-PubMedBERT-fulltext")

        # Build a SentenceTransformer manually: transformer encoder + mean pooling
        word_embedding_model = models.Transformer(model_path)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
        st.success("✅ SapBERT loaded successfully from Hugging Face Hub.")
    except Exception as e:
        st.error(f"❌ Failed to load SapBERT from Hub: {e}")
        st.warning("⚠️ Falling back to 'all-MiniLM-L6-v2' model.")
        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    return model
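
# NOTE: the MiniLM fallback keeps the app usable, but all-MiniLM-L6-v2 produces
# 384-dimensional embeddings while the precomputed SapBERT vectors loaded in
# load_data() are 768-dimensional, so queries encoded with the fallback model
# will likely not match the dimensionality of the FAISS index built below.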
@st.cache_resource
def load_data():
    """Load embeddings and dataset dynamically from the Hugging Face Hub."""
    try:
        # Download precomputed embeddings & CSV from the Hub
        question_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/question_embeddings.pkl")
        doctor_emb_path = hf_hub_download(DATASET_REPO, filename="dataset/doctor_embeddings.pkl")
        dataset_csv_path = hf_hub_download(DATASET_REPO, filename="dataset/dataset.csv")

        # Load pickled embeddings and cast to float32 (the dtype FAISS expects)
        with open(question_emb_path, 'rb') as f:
            question_data = pickle.load(f)
        question_embeddings = question_data.get('embeddings').astype('float32')

        with open(doctor_emb_path, 'rb') as f:
            doctor_data = pickle.load(f)
        doctor_embeddings = doctor_data.get('embeddings').astype('float32')

        # Load and clean the dataset CSV
        df = pd.read_csv(dataset_csv_path)
        df.dropna(subset=['Description', 'Patient', 'Doctor'], inplace=True)
        df.drop_duplicates(inplace=True)

        # Truncate everything to a common length so rows and embeddings stay aligned
        num_samples = min(len(df), len(question_embeddings), len(doctor_embeddings))
        df = df.iloc[:num_samples]
        question_embeddings = question_embeddings[:num_samples]
        doctor_embeddings = doctor_embeddings[:num_samples]

        st.success(f"✅ Loaded {num_samples} rows with SapBERT embeddings ({question_embeddings.shape[1]} dims)")

        # Build the FAISS index over the question embeddings
        init_faiss(question_embeddings)

        return {
            "question_embeddings": question_embeddings,
            "doctor_embeddings": doctor_embeddings,
            "description_column": df["Description"].tolist(),
            "patient_column": df["Patient"].tolist(),
            "original_answers": df["Doctor"].tolist(),
        }
    except Exception as e:
        st.error(f"❌ Error loading dataset or embeddings: {e}")
        return None
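
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how a Streamlit page (e.g. a hypothetical app.py) might
# consume these helpers. Only load_model()/load_data() and the standard
# SentenceTransformer.encode() API are shown; any search against the FAISS
# index would go through src/search.py, whose query function is not shown here.
#
#     import streamlit as st
#     from src.utils import load_model, load_data
#
#     model = load_model()   # cached across reruns by @st.cache_resource
#     data = load_data()     # downloads artifacts and builds the FAISS index
#
#     query = st.text_input("Describe your symptoms")
#     if query and data is not None:
#         # Encode the query with the same model used for the corpus and cast
#         # to float32, the dtype FAISS expects for index and query vectors.
#         query_vec = model.encode([query]).astype("float32")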