param2004 committed on
Commit
690bcb6
·
verified ·
1 Parent(s): 1cc7aa6

Upload 17 files

app.py ADDED
@@ -0,0 +1,139 @@
+ import streamlit as st
+ from src.utils import load_model, load_data
+ from src.query_utils import QueryEnhancer
+ from src.search import hybrid_search, set_description_texts, set_patient_texts, strong_recall_indices
+ from src.ui import apply_custom_css, render_header, render_sidebar, render_chat_history, bot_typing_animation
+ from src.pdf_utils import build_chat_pdf
+ from src.llm import summarize_with_gemini
+ import numpy as np
+
+ def cosine_similarity(vec1, vec2):
+     return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
+
+ def filter_similar_answers(indices, doctor_embeddings, threshold=0.88):
+     """
+     Filters out semantically similar answers based on cosine similarity.
+
+     Args:
+         indices (list[int]): candidate indices of top answers (assumed sorted by relevance)
+         doctor_embeddings (np.array): embeddings of all doctor's answers (num_samples x dim)
+         threshold (float): similarity above which answers are considered duplicates
+
+     Returns:
+         filtered_indices (list[int]): indices of diverse answers
+     """
+     if len(indices) == 0:
+         return []
+
+     filtered = [indices[0]]  # Always keep the first (most relevant) one
+     for idx in indices[1:]:
+         emb = doctor_embeddings[idx]
+         keep = True
+         for f_idx in filtered:
+             existing_emb = doctor_embeddings[f_idx]
+             sim = np.dot(emb, existing_emb)  # Cosine sim (since normalized)
+             if sim >= threshold:
+                 keep = False
+                 break
+         if keep:
+             filtered.append(idx)
+     return filtered
+
+ # --- 1. Setup UI ---
+ apply_custom_css()
+ render_header()
+ render_sidebar()
+
+ # --- 2. Load Model & Data ---
+ model = load_model()
+ data = load_data()
+ google_api_key = st.secrets.get("GOOGLE_API_KEY")
+
+ if not data:
+     st.error("❌ Could not load dataset or embeddings. Please check your paths.")
+     st.stop()
+
+ query_enhancer = QueryEnhancer(model)
+
+ # --- 3. Set texts for recall ---
+ set_description_texts(data['description_column'])
+ set_patient_texts(data['patient_column'])
+
+ # --- 4. Initialize Chat History ---
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ render_chat_history(st.session_state.messages)
+
+ # --- 5. PDF export sidebar ---
+ with st.sidebar:
+     st.markdown("---")
+     st.subheader("Export")
+     if st.session_state.get("messages"):
+         pdf_buffer = build_chat_pdf(
+             st.session_state.messages, title="MediLingua Chat Transcript"
+         )
+         if pdf_buffer:
+             st.download_button(
+                 label="📄 Download chat as PDF",
+                 data=pdf_buffer,
+                 file_name="medilingua_chat.pdf",
+                 mime="application/pdf",
+             )
+         else:
+             st.caption(
+                 "Install `reportlab` to enable PDF export: pip install reportlab"
+             )
+
+ # --- 6. Handle User Input ---
+ if prompt := st.chat_input("What is your medical question?"):
+     # Show user input immediately
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     # Placeholder for bot response
+     with st.chat_message("assistant"):
+         message_placeholder = st.empty()
+
+     # Step 1: Enhance query
+     enhanced_query = query_enhancer.enhance_query(prompt)
+
+     # Step 2: Strong recall + hybrid search
+     with st.spinner("🔍 Searching for top relevant answers..."):
+         # 1️⃣ Get top-k candidates
+         indices = strong_recall_indices(prompt, top_k=10)
+         if not indices or len(indices) < 3:
+             indices = hybrid_search(
+                 model,
+                 data['question_embeddings'],
+                 user_query_raw=prompt,
+                 user_query_enhanced=enhanced_query,
+                 top_k=10,
+                 weight_semantic=0.7,
+                 faiss_top_candidates=256,
+                 use_exact_match=True,
+                 use_fuzzy_match=True
+             )
+
+         # 2️⃣ Filter out semantically similar doctor answers to ensure diversity
+         indices = filter_similar_answers(indices, data['doctor_embeddings'], threshold=0.88)
+
+     # Step 3: Summarization / Gemini
+     if indices is not None and len(indices) > 0:
+         # Gather ALL filtered diverse answers for summarization (used as reference)
+         top_answers = [data['original_answers'][i] for i in indices]
+         combined_text = " ".join(top_answers)
+
+         summary = summarize_with_gemini(google_api_key, combined_text, prompt)
+
+         # Show only the AI's summarized answer (no doctor's notes displayed)
+         bot_typing_animation(message_placeholder, summary)
+
+         response = summary
+     else:
+         response = "⚕️ I couldn't find any contextually similar answer in the dataset."
+         bot_typing_animation(message_placeholder, response)
+
+     # Save conversation
+     st.session_state.messages.append({"role": "assistant", "content": response})
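A quick way to sanity-check `filter_similar_answers` outside Streamlit — a minimal sketch with hand-built unit-normalized vectors (the function's raw dot product assumes pre-normalized doctor embeddings, which is why the `cosine_similarity` helper above goes unused in the filter):

import numpy as np

# Three unit vectors: rows 0 and 1 are near-duplicates, row 2 is orthogonal.
doctor_embeddings = np.array([
    [1.0, 0.0],
    [0.999, 0.0447],   # cosine vs. row 0 ~ 0.999 -> dropped at threshold 0.88
    [0.0, 1.0],        # cosine vs. row 0 = 0.0   -> kept
], dtype="float32")
doctor_embeddings /= np.linalg.norm(doctor_embeddings, axis=1, keepdims=True)

# With filter_similar_answers from above pasted into the same session:
print(filter_similar_answers([0, 1, 2], doctor_embeddings, threshold=0.88))
# -> [0, 2]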
src/__pycache__/llm.cpython-313.pyc ADDED
Binary file (3.77 kB)

src/__pycache__/pdf_utils.cpython-313.pyc ADDED
Binary file (3.46 kB)

src/__pycache__/query_utils.cpython-313.pyc ADDED
Binary file (4.41 kB)

src/__pycache__/search.cpython-313.pyc ADDED
Binary file (11.9 kB)

src/__pycache__/ui.cpython-313.pyc ADDED
Binary file (5.37 kB)

src/__pycache__/utils.cpython-313.pyc ADDED
Binary file (4.97 kB)
src/llm.py ADDED
@@ -0,0 +1,52 @@
+ def summarize_with_gemini(api_key, doctor_answer, user_question, max_retries=2):
+     import requests, json, time, streamlit as st
+
+     if not api_key:
+         st.warning("⚠️ Google API Key missing. Showing full answer instead.")
+         return doctor_answer if isinstance(doctor_answer, str) else "\n\n".join(doctor_answer)
+
+     combined_answer = "\n\n---\n\n".join(doctor_answer) if isinstance(doctor_answer, list) else doctor_answer
+
+     candidate_models = ["gemini-2.0-flash", "gemini-2.5-flash", "gemini-2.0-pro"]
+
+     for model_name in candidate_models:
+         prompt = f"""You are a professional AI medical assistant.
+
+ Summarize the doctor's responses clearly, accurately, and concisely for the patient.
+ Focus only on medically relevant information that directly answers the user's question.
+
+ User's Question:
+ "{user_question}"
+
+ Doctor's Answer(s):
+ "{combined_answer}"
+
+ Instructions:
+ - Provide a medically correct, patient-friendly summary in simple, clear language.
+ - List multiple points as bullets if possible.
+ - If the user's question lacks personal details (e.g., gender, age, weight), generate a generalized, gender-neutral summary.
+ - Avoid gender-specific recommendations (e.g., consulting a gynecologist) unless the query explicitly mentions gender or related details.
+ - Don't forget to mention potential next steps, treatments, or lifestyle changes if the doctor's answer mentions them.
+ - Always include a recommendation to consult a relevant doctor type (e.g., general practitioner, orthopedist) at the end of the summary, unless the doctor's answers already specify a consultation with a specific doctor type.
+ - For example, for back pain, recommend consulting an orthopedist or general practitioner unless the query or doctor's answers suggest a more specific specialist.
+ - If the doctor's response does not address the question, respond:
+   "There is no information related to your question in the doctor's answer, so I generated the best possible answer based on the information provided." """
+
+         url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={api_key}"
+         payload = {"contents": [{"parts": [{"text": prompt}]}]}
+         headers = {"Content-Type": "application/json"}
+
+         for attempt in range(max_retries):
+             try:
+                 resp = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
+                 resp.raise_for_status()
+                 result = resp.json()
+                 if "candidates" in result and result["candidates"]:
+                     return result["candidates"][0]["content"]["parts"][0]["text"].strip()
+             except requests.exceptions.HTTPError:
+                 if resp.status_code == 404:
+                     break  # Model not served under this endpoint; try the next candidate
+                 time.sleep(1)
+             except Exception:
+                 time.sleep(1)
+
+     st.warning("⚕️ Could not generate summary. Showing original answer.")
+     return combined_answer
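For reference, the happy-path parsing above assumes the usual generateContent response shape. A standalone sketch with a mocked response dict (no network call, no API key) that exercises the same extraction:

# Mocked generateContent-style response, mirroring the keys the code above reads.
result = {
    "candidates": [
        {"content": {"parts": [{"text": "  Drink fluids and rest.  "}]}}
    ]
}
if "candidates" in result and result["candidates"]:
    summary = result["candidates"][0]["content"]["parts"][0]["text"].strip()
    print(summary)  # -> "Drink fluids and rest."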
src/pdf_utils.py ADDED
@@ -0,0 +1,87 @@
+ import io
+ from typing import List, Dict
+
+ import streamlit as st
+
+
+ def _import_reportlab():
+     try:
+         from reportlab.lib.pagesizes import A4
+         from reportlab.lib.styles import getSampleStyleSheet
+         from reportlab.lib.units import mm
+         from reportlab.pdfgen import canvas
+         from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
+         from reportlab.lib import colors
+         return {
+             "A4": A4,
+             "getSampleStyleSheet": getSampleStyleSheet,
+             "mm": mm,
+             "canvas": canvas,
+             "Paragraph": Paragraph,
+             "SimpleDocTemplate": SimpleDocTemplate,
+             "Spacer": Spacer,
+             "Table": Table,
+             "TableStyle": TableStyle,
+             "colors": colors,
+         }
+     except Exception:
+         return None
+
+
+ def build_chat_pdf(messages: List[Dict[str, str]], title: str = "MediLingua Chat Transcript") -> io.BytesIO:
+     """
+     Create a PDF from chat messages and return it as an in-memory BytesIO buffer.
+     Each message is a dict with keys: role ('user'|'assistant'), content (str).
+     """
+     libs = _import_reportlab()
+     if libs is None:
+         st.error(
+             "PDF generation library not found. Install with: `pip install reportlab` and rerun."
+         )
+         return None
+
+     buffer = io.BytesIO()
+
+     doc = libs["SimpleDocTemplate"](buffer, pagesize=libs["A4"], rightMargin=28, leftMargin=28, topMargin=36, bottomMargin=28)
+     styles = libs["getSampleStyleSheet"]()
+
+     elements = []
+
+     # Title
+     title_style = styles["Title"]
+     elements.append(libs["Paragraph"](title, title_style))
+     elements.append(libs["Spacer"](1, 12))
+
+     # Build a table-like layout for messages
+     data = []
+     table_style_cmds = [
+         ("VALIGN", (0, 0), (-1, -1), "TOP"),
+         ("INNERGRID", (0, 0), (-1, -1), 0.25, libs["colors"].lightgrey),
+         ("BOX", (0, 0), (-1, -1), 0.25, libs["colors"].lightgrey),
+         ("LEFTPADDING", (0, 0), (-1, -1), 6),
+         ("RIGHTPADDING", (0, 0), (-1, -1), 6),
+         ("TOPPADDING", (0, 0), (-1, -1), 6),
+         ("BOTTOMPADDING", (0, 0), (-1, -1), 6),
+     ]
+
+     role_style = styles["Heading5"]
+     body_style = styles["BodyText"]
+
+     for msg in messages:
+         role = msg.get("role", "").capitalize()
+         content = msg.get("content", "")
+
+         # Left column: role; right column: content paragraph
+         role_par = libs["Paragraph"](f"<b>{role}</b>", role_style)
+         content_par = libs["Paragraph"](content.replace("\n", "<br/>"), body_style)
+         data.append([role_par, content_par])
+
+     table = libs["Table"](data, colWidths=[30 * libs["mm"], None])
+     table.setStyle(libs["TableStyle"](table_style_cmds))
+     elements.append(table)
+
+     doc.build(elements)
+     buffer.seek(0)
+     return buffer
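A minimal usage sketch for `build_chat_pdf` (requires `pip install reportlab`; the message list and output file name are illustrative):

from src.pdf_utils import build_chat_pdf

messages = [
    {"role": "user", "content": "What helps with a sore throat?"},
    {"role": "assistant", "content": "Warm fluids, rest, and lozenges.\nSee a GP if it persists."},
]
buf = build_chat_pdf(messages, title="Demo Transcript")
if buf is not None:
    # The function returns an in-memory buffer, so it works with
    # st.download_button or a plain file write:
    with open("demo_transcript.pdf", "wb") as f:
        f.write(buf.read())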
src/query_utils.py ADDED
@@ -0,0 +1,86 @@
+ import re
+ import nltk
+ from nltk.tokenize import RegexpTokenizer
+ from nltk.corpus import stopwords
+ from nltk.stem import WordNetLemmatizer
+ from keybert import KeyBERT
+
+ # --- Download NLTK resources if needed ---
+ try:
+     stopwords.words('english')
+ except LookupError:
+     nltk.download('stopwords', quiet=True)
+     nltk.download('punkt', quiet=True)
+     nltk.download('averaged_perceptron_tagger', quiet=True)
+     nltk.download('wordnet', quiet=True)
+
+ # --- Initialize tools ---
+ tokenizer = RegexpTokenizer(r'\w+')
+ lemmatizer = WordNetLemmatizer()
+ custom_stopwords = set(stopwords.words('english')) - {'no', 'not', 'without', 'due', 'to', 'with', 'on', 'in'}
+
+ # --- Medical synonym expansion ---
+ medical_synonyms = {
+     "flu": ["influenza"],
+     "cold": ["common cold", "rhinitis"],
+     "heart attack": ["myocardial infarction"],
+     "diabetes": ["high blood sugar", "hyperglycemia"],
+     "bp": ["blood pressure", "hypertension"],
+     "hypertension": ["high blood pressure"],
+     "asthma": ["respiratory disease"],
+     "cough": ["dry cough", "wet cough"],
+     "fever": ["temperature", "high fever"]
+ }
+
+ def expand_medical_terms(text: str) -> str:
+     """Expands known medical terms with their synonyms for better recall."""
+     for key, syns in medical_synonyms.items():
+         for syn in syns:
+             text = re.sub(rf"\b{key}\b", f"{key} {syn}", text, flags=re.IGNORECASE)
+     return text
+
+ def preprocess_text(text: str) -> str:
+     """Minimal preprocessing: lowercase, remove punctuation, collapse spaces."""
+     text = str(text).lower()
+     text = re.sub(r'[^\w\s]', ' ', text)
+     text = re.sub(r'\s+', ' ', text).strip()
+     return text
+
+ class QueryEnhancer:
+     """
+     Wrapper class to handle query enhancement with local SapBERT + KeyBERT.
+     """
+     def __init__(self, sentence_transformer_model):
+         """
+         sentence_transformer_model: the already-loaded local SapBERT SentenceTransformer
+         """
+         self.kw_model = KeyBERT(sentence_transformer_model)
+
+     def extract_keywords(self, text: str, top_n: int = 5) -> list:
+         """Extracts top keywords using KeyBERT."""
+         if not self.kw_model:
+             return []
+         try:
+             keywords = self.kw_model.extract_keywords(
+                 text,
+                 keyphrase_ngram_range=(1, 2),
+                 stop_words='english',
+                 top_n=top_n
+             )
+             return [kw[0] for kw in keywords]
+         except Exception:
+             return []
+
+     def enhance_query(self, user_query: str) -> str:
+         """
+         Full query enhancement pipeline:
+         - Preprocess text
+         - Expand medical synonyms
+         - Extract keywords
+         - Return combined enhanced query string
+         """
+         preprocessed = preprocess_text(user_query)
+         expanded = expand_medical_terms(preprocessed)
+         keywords = self.extract_keywords(user_query)
+         enhanced_query = f"{expanded} {' '.join(keywords)}".strip()
+         return enhanced_query
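A small sketch of what the deterministic half of the pipeline produces (KeyBERT keywords depend on the embedding model, so only the preprocessing and synonym expansion are shown):

from src.query_utils import preprocess_text, expand_medical_terms

q = "What should I do about the flu?"
print(preprocess_text(q))
# -> "what should i do about the flu"
print(expand_medical_terms(preprocess_text(q)))
# -> "what should i do about the flu influenza"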
src/search.py ADDED
@@ -0,0 +1,271 @@
+ import numpy as np
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ import faiss
+ import re
+
+ # --- Global caches ---
+ tfidf_vectorizer = None
+ tfidf_matrix = None
+ corpus_texts = None
+ faiss_index = None
+ embeddings_array = None        # FAISS requires float32
+ description_texts = None       # For exact/fuzzy match
+ patient_texts = None           # For exact/fuzzy match
+ description_norm_texts = None  # Normalized (punctuation stripped)
+ patient_norm_texts = None      # Normalized (punctuation stripped)
+
+
+ def encode_question(model, user_question):
+     """Encodes the user's question using the embedding model."""
+     if model is None or not user_question.strip():
+         return None
+     return model.encode([user_question], show_progress_bar=False)[0].astype('float32')
+
+
+ def init_tfidf(data_texts):
+     """
+     Initialize the TF-IDF matrix for hybrid search.
+     """
+     global tfidf_vectorizer, tfidf_matrix, corpus_texts
+     corpus_texts = data_texts
+     tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
+     tfidf_matrix = tfidf_vectorizer.fit_transform(corpus_texts)
+
+
+ def init_faiss(embeddings):
+     """
+     Initialize the FAISS index for fast semantic search.
+     embeddings: np.array (num_samples x 768), already normalized
+     """
+     global faiss_index, embeddings_array
+     embeddings_array = embeddings.astype('float32')
+
+     # Do not renormalize
+     dimension = embeddings_array.shape[1]
+     faiss_index = faiss.IndexFlatIP(dimension)  # Inner product equals cosine similarity for unit vectors
+     faiss_index.add(embeddings_array)
+
+
+ def set_description_texts(texts):
+     """
+     Provide the Description column for exact/fuzzy match search.
+     """
+     global description_texts, description_norm_texts
+     description_texts = [str(t).lower() for t in texts]
+     description_norm_texts = [preprocess_text_for_embeddings(t) for t in texts]
+
+
+ def set_patient_texts(texts):
+     """
+     Provide the Patient column for exact/fuzzy match search.
+     """
+     global patient_texts, patient_norm_texts
+     patient_texts = [str(t).lower() for t in texts]
+     patient_norm_texts = [preprocess_text_for_embeddings(t) for t in texts]
+
+
+ def strong_recall_indices(user_query_raw: str, top_k: int = 10):
+     """
+     Scan the entire dataset (Description + Patient) for:
+     1) Exact equality on normalized text
+     2) Exact substring presence
+     3) High-threshold fuzzy match (if rapidfuzz is available)
+
+     Returns a list of unique indices, in priority order, up to top_k.
+     """
+     global description_texts, patient_texts, description_norm_texts, patient_norm_texts
+
+     if not user_query_raw:
+         return []
+
+     q_lower = str(user_query_raw).lower()
+     q_norm = preprocess_text_for_embeddings(user_query_raw)
+
+     N_desc = len(description_texts) if description_texts is not None else 0
+     N_pat = len(patient_texts) if patient_texts is not None else 0
+     N = max(N_desc, N_pat)
+     if N == 0:
+         return []
+
+     exact_equal = []
+     exact_sub = []
+
+     # 1) Exact equality on normalized text
+     if description_norm_texts is not None:
+         exact_equal += [i for i in range(len(description_norm_texts)) if description_norm_texts[i] == q_norm]
+     if patient_norm_texts is not None:
+         exact_equal += [i for i in range(len(patient_norm_texts)) if patient_norm_texts[i] == q_norm]
+
+     # Deduplicate, preserving order
+     seen = set()
+     ordered = []
+     for i in exact_equal:
+         if i not in seen:
+             seen.add(i)
+             ordered.append(i)
+             if len(ordered) >= top_k:
+                 return ordered[:top_k]
+
+     # 2) Exact substring presence (lowercased)
+     if description_texts is not None:
+         exact_sub += [i for i in range(len(description_texts)) if q_lower in description_texts[i]]
+     if patient_texts is not None:
+         exact_sub += [i for i in range(len(patient_texts)) if q_lower in patient_texts[i]]
+
+     for i in exact_sub:
+         if i not in seen:
+             seen.add(i)
+             ordered.append(i)
+             if len(ordered) >= top_k:
+                 return ordered[:top_k]
+
+     # 3) High-threshold fuzzy matches
+     try:
+         from rapidfuzz import fuzz
+         # Use partial_ratio and token_set_ratio; take the max as the score
+         scored = []
+         for i in range(N):
+             s_desc = description_texts[i] if (description_texts is not None and i < len(description_texts)) else ""
+             s_pat = patient_texts[i] if (patient_texts is not None and i < len(patient_texts)) else ""
+             score_desc = max(fuzz.partial_ratio(q_lower, s_desc), fuzz.token_set_ratio(q_lower, s_desc)) if s_desc else 0
+             score_pat = max(fuzz.partial_ratio(q_lower, s_pat), fuzz.token_set_ratio(q_lower, s_pat)) if s_pat else 0
+             score = max(score_desc, score_pat)
+             if score >= 90:
+                 scored.append((i, score))
+         # Sort by score, descending
+         scored.sort(key=lambda x: x[1], reverse=True)
+         for i, _ in scored:
+             if i not in seen:
+                 seen.add(i)
+                 ordered.append(i)
+                 if len(ordered) >= top_k:
+                     break
+     except Exception:
+         pass
+
+     return ordered[:top_k]
+
+
+ def hybrid_search(
+     model,
+     embeddings,
+     user_query_raw,
+     user_query_enhanced,
+     top_k=5,
+     weight_semantic=0.7,
+     faiss_top_candidates=256,
+     use_exact_match=True,
+     use_fuzzy_match=True
+ ):
+     """
+     Hybrid search combining:
+     1. FAISS semantic similarity
+     2. TF-IDF boosting
+     3. Optional exact/fuzzy matching against Description and Patient texts
+
+     Returns: list of top indices into the dataset
+     """
+     global tfidf_vectorizer, tfidf_matrix, corpus_texts, faiss_index, embeddings_array, description_texts
+
+     if model is None or embeddings is None or len(embeddings) == 0:
+         return []
+
+     # Encode the enhanced query for the semantic/TF-IDF stages
+     question_embedding = encode_question(model, user_query_enhanced)
+     if question_embedding is None:
+         return []
+
+     # --- 1. FAISS semantic search ---
+     if faiss_index is not None:
+         D, I = faiss_index.search(np.array([question_embedding]), k=min(faiss_top_candidates, embeddings.shape[0]))
+         top_candidates = I[0]
+         semantic_sim_top = D[0]
+     else:
+         semantic_sim_full = np.dot(embeddings, question_embedding)
+         top_candidates = np.argpartition(semantic_sim_full, -faiss_top_candidates)[-faiss_top_candidates:]
+         top_candidates = top_candidates[np.argsort(semantic_sim_full[top_candidates])[::-1]]
+         semantic_sim_top = semantic_sim_full[top_candidates]
+
+     # --- 2. TF-IDF similarity ---
+     if tfidf_vectorizer is not None and tfidf_matrix is not None:
+         tfidf_vec = tfidf_vectorizer.transform([user_query_enhanced])
+         tfidf_sim_top = (tfidf_matrix[top_candidates] @ tfidf_vec.T).toarray().ravel()
+     else:
+         tfidf_sim_top = np.zeros(len(top_candidates))
+
+     # --- 3. Optional exact + fuzzy match across Description & Patient ---
+     combined_sim_top = weight_semantic * semantic_sim_top + (1 - weight_semantic) * tfidf_sim_top
+
+     if use_exact_match or use_fuzzy_match:
+         query_lower = user_query_raw.lower()
+
+         # Exact substring presence boosts
+         exact_desc = np.zeros(len(top_candidates))
+         exact_pat = np.zeros(len(top_candidates))
+         if description_texts is not None:
+             exact_desc = np.array([1.0 if query_lower in description_texts[i] else 0.0 for i in top_candidates])
+         if patient_texts is not None:
+             exact_pat = np.array([1.0 if query_lower in patient_texts[i] else 0.0 for i in top_candidates])
+
+         # Fuzzy partial ratio via rapidfuzz (graceful fallback if unavailable)
+         fuzzy_desc = np.zeros(len(top_candidates))
+         fuzzy_pat = np.zeros(len(top_candidates))
+         if use_fuzzy_match:
+             try:
+                 from rapidfuzz import fuzz
+                 if description_texts is not None:
+                     fuzzy_desc = np.array([
+                         fuzz.partial_ratio(query_lower, description_texts[i]) / 100.0 for i in top_candidates
+                     ])
+                 if patient_texts is not None:
+                     fuzzy_pat = np.array([
+                         fuzz.partial_ratio(query_lower, patient_texts[i]) / 100.0 for i in top_candidates
+                     ])
+             except Exception:
+                 pass
+
+         # Token overlap (Jaccard) as an additional weak signal
+         def jaccard(a: str, b: str) -> float:
+             sa = set(a.split())
+             sb = set(b.split())
+             if not sa or not sb:
+                 return 0.0
+             inter = len(sa & sb)
+             union = len(sa | sb)
+             return inter / union if union else 0.0
+
+         token_desc = np.zeros(len(top_candidates))
+         token_pat = np.zeros(len(top_candidates))
+         if description_texts is not None:
+             token_desc = np.array([jaccard(query_lower, description_texts[i]) for i in top_candidates])
+         if patient_texts is not None:
+             token_pat = np.array([jaccard(query_lower, patient_texts[i]) for i in top_candidates])
+
+         # Combine boosters with gentle weights; exact match is strongest
+         booster = 0.20 * exact_desc + 0.20 * exact_pat + 0.10 * fuzzy_desc + 0.10 * fuzzy_pat + 0.05 * token_desc + 0.05 * token_pat
+         combined_sim_top = combined_sim_top + booster
+
+     # --- 4. Select final top-k indices ---
+     sorted_top_indices = top_candidates[np.argsort(combined_sim_top)[::-1][:top_k]]
+
+     # Return plain Python ints, matching the documented list return type
+     return sorted_top_indices.tolist()
+
+
+ # --- Minimal preprocessing for embeddings ---
+ def preprocess_text_for_embeddings(text: str) -> str:
+     """Lowercase + remove punctuation for embeddings."""
+     text = str(text).lower()
+     text = re.sub(r'[^\w\s]', ' ', text)
+     text = re.sub(r'\s+', ' ', text).strip()
+     return text
+
+
+ # --- Minimal preprocessing for keywords ---
+ def preprocess_text_for_keywords(text: str) -> str:
+     """Lowercase + remove punctuation for keywords."""
+     text = str(text).lower()
+     text = re.sub(r'[^\w\s]', ' ', text)
+     text = re.sub(r'\s+', ' ', text).strip()
+     return text
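A self-contained sketch of the retrieval pipeline, using a toy hashing "model" in place of SapBERT (an assumption purely for illustration; it only mimics the `encode` interface). Embeddings are unit-normalized so the inner-product index behaves as cosine similarity:

import numpy as np
from src.search import (init_tfidf, init_faiss, set_description_texts,
                        set_patient_texts, hybrid_search, strong_recall_indices)

class ToyModel:
    """Stand-in for the SentenceTransformer: hashes words into a fixed space."""
    def __init__(self, dim=64):
        self.dim = dim
    def encode(self, texts, show_progress_bar=False):
        out = []
        for t in texts:
            v = np.zeros(self.dim, dtype="float32")
            for w in t.lower().split():
                v[hash(w) % self.dim] += 1.0  # deterministic within one process
            n = np.linalg.norm(v)
            out.append(v / n if n else v)
        return np.array(out)

docs = ["persistent dry cough and fever",
        "lower back pain after lifting",
        "high blood pressure management"]
model = ToyModel()
emb = model.encode(docs)

init_tfidf(docs)
init_faiss(emb)
set_description_texts(docs)
set_patient_texts(docs)

print(strong_recall_indices("back pain", top_k=3))       # substring hit -> starts with [1, ...]
print(hybrid_search(model, emb, "back pain", "back pain lower back", top_k=2))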
src/ui.py ADDED
@@ -0,0 +1,124 @@
+ import streamlit as st
+ import time
+
+ def apply_custom_css():
+     """Applies custom CSS for proper left-right chat alignment with enhanced colors."""
+     css = """
+     <style>
+     .stApp {
+         background-color: #0f172a;
+         color: #e2e8f0;
+         font-family: 'Inter', sans-serif;
+     }
+     .block-container {
+         max-width: 900px;
+         margin: auto;
+         padding-top: 1rem;
+     }
+     /* Chat message layout override */
+     [data-testid="stChatMessage"] {
+         display: flex !important;
+         align-items: flex-start !important;
+         margin-bottom: 0.75rem;
+     }
+     [data-testid="stChatMessage"] > div[data-testid="stMarkdownContainer"] {
+         padding: 0.8rem 1rem;
+         border-radius: 16px;
+         max-width: 70%;
+         line-height: 1.5;
+         font-size: 0.95rem;
+         word-wrap: break-word;
+         box-shadow: 0 4px 12px rgba(0,0,0,0.25);
+         transition: all 0.2s ease-in-out;
+         animation: fadeIn 0.3s ease;
+     }
+     @keyframes fadeIn {
+         from { opacity: 0; transform: translateY(4px); }
+         to { opacity: 1; transform: translateY(0); }
+     }
+     /* Assistant (left) */
+     [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="assistant"]) {
+         justify-content: flex-start !important;
+     }
+     [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="assistant"])
+     > div[data-testid="stMarkdownContainer"] {
+         background-color: #1e293b;
+         color: #f1f5f9;
+         border: 1px solid #334155;
+         text-align: left;
+     }
+     /* User (right) */
+     [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="user"]) {
+         justify-content: flex-end !important;
+     }
+     [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="user"])
+     > div[data-testid="stMarkdownContainer"] {
+         background-color: #2563eb;
+         color: white;
+         border: 1px solid #1d4ed8;
+         text-align: right;
+     }
+     /* Expander (doctor notes) */
+     .streamlit-expanderHeader {
+         background: #111827;
+         color: #cbd5e1;
+         border: 1px solid #374151;
+         border-radius: 10px;
+     }
+     .streamlit-expanderContent {
+         background: #0b1220;
+         border-left: 2px solid #334155;
+     }
+     /* Scrollbar style */
+     ::-webkit-scrollbar { width: 8px; }
+     ::-webkit-scrollbar-thumb { background-color: #334155; border-radius: 10px; }
+     /* Header/title */
+     h1 {
+         text-align: center;
+         color: #60a5fa;
+         font-weight: 600;
+     }
+     p[style*='text-align: center;'] {
+         color: #94a3b8;
+     }
+     </style>
+     """
+     st.markdown(css, unsafe_allow_html=True)
+
+
+ def render_header():
+     st.title("🤖 MediLingua: Your Medical Assistant")
+     st.markdown(
+         "<p style='text-align: center; font-size: 1.1rem;'>Ask medical questions and get summarized answers from real doctor responses.</p>",
+         unsafe_allow_html=True
+     )
+
+
+ def render_sidebar():
+     with st.sidebar:
+         st.header("⚙️ Configuration")
+         if st.secrets.get("GOOGLE_API_KEY"):
+             st.success("✅ Google API Key configured.")
+         else:
+             st.error("❌ Missing API Key in `.streamlit/secrets.toml`.")
+         st.markdown("---")
+         st.markdown("💡 Built with **Streamlit** & **Gemini**.")
+
+
+ def render_chat_history(messages):
+     """Render previous messages."""
+     for message in messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+
+ def bot_typing_animation(message_placeholder, final_text, delay=0.02):
+     """
+     Simulate a bot typing animation in chat.
+     """
+     message_placeholder.markdown("")  # Empty initially
+     displayed = ""
+     for char in final_text:
+         displayed += char
+         message_placeholder.markdown(displayed)
+         time.sleep(delay)
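A minimal demo of the typing animation on its own (save as, say, demo_ui.py — the file name is illustrative — and launch with `streamlit run demo_ui.py`):

import streamlit as st
from src.ui import apply_custom_css, render_header, bot_typing_animation

apply_custom_css()
render_header()
with st.chat_message("assistant"):
    placeholder = st.empty()  # bot_typing_animation rewrites this placeholder char by char
    bot_typing_animation(placeholder, "Hello! Ask me a medical question.", delay=0.01)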
src/utils.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ import streamlit as st
+ import pickle
+ import pandas as pd
+ from sentence_transformers import SentenceTransformer, models
+ import torch
+ import numpy as np
+ from src.search import init_faiss
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ # Repo IDs
+ DATASET_REPO = "param2004/Medilingua-dataset"
+ MODEL_REPO = "param2004/Medilingua-model"
+
+ @st.cache_resource
+ def load_model():
+     """Load SapBERT dynamically from the Hugging Face Hub."""
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     st.info(f"🔬 Loading SapBERT from Hugging Face Hub on {device.upper()}...")
+
+     # Download the model folder dynamically. models.Transformer expects a
+     # directory with config.json, tokenizer files, and weights, so fetch the
+     # whole folder rather than a single pytorch_model.bin file.
+     try:
+         repo_path = snapshot_download(
+             repo_id=MODEL_REPO,
+             allow_patterns=["models/SapBERT-from-PubMedBERT-fulltext/*"]
+         )
+         model_path = os.path.join(repo_path, "models", "SapBERT-from-PubMedBERT-fulltext")
+
+         # Build the SentenceTransformer: transformer encoder + mean pooling
+         word_embedding_model = models.Transformer(model_path)
+         pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
+         model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
+
+         st.success("✅ SapBERT loaded successfully from Hub.")
+     except Exception as e:
+         st.error(f"❌ Failed to load SapBERT from Hub. Details: {e}")
+         st.warning("⚠️ Falling back to 'all-MiniLM-L6-v2' model.")
+         model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+
+     return model
+
+
+ @st.cache_resource
+ def load_data():
+     """Load embeddings and dataset dynamically from the Hugging Face Hub."""
+     try:
+         # Download embeddings
+         question_emb_path = hf_hub_download(
+             repo_id=DATASET_REPO,
+             filename="dataset/question_embeddings.pkl"
+         )
+         doctor_emb_path = hf_hub_download(
+             repo_id=DATASET_REPO,
+             filename="dataset/doctor_embeddings.pkl"
+         )
+         dataset_csv_path = hf_hub_download(
+             repo_id=DATASET_REPO,
+             filename="dataset/dataset.csv"
+         )
+
+         # Load embeddings
+         with open(question_emb_path, 'rb') as f:
+             question_data = pickle.load(f)
+             question_embeddings = question_data.get('embeddings').astype('float32')
+
+         with open(doctor_emb_path, 'rb') as f:
+             doctor_data = pickle.load(f)
+             doctor_embeddings = doctor_data.get('embeddings').astype('float32')
+
+         # Load the CSV and keep rows aligned with the embeddings
+         df = pd.read_csv(dataset_csv_path)
+         df.dropna(subset=['Description', 'Patient', 'Doctor'], inplace=True)
+         df.drop_duplicates(inplace=True)
+
+         num_samples = min(len(df), len(question_embeddings), len(doctor_embeddings))
+         df = df.iloc[:num_samples]
+         question_embeddings = question_embeddings[:num_samples]
+         doctor_embeddings = doctor_embeddings[:num_samples]
+
+         st.success(f"✅ Loaded {num_samples} rows with SapBERT embeddings ({question_embeddings.shape[1]} dims)")
+
+         # Initialize FAISS
+         init_faiss(question_embeddings)
+
+         return {
+             "question_embeddings": question_embeddings,
+             "doctor_embeddings": doctor_embeddings,
+             "description_column": df["Description"].tolist(),
+             "patient_column": df["Patient"].tolist(),
+             "original_answers": df["Doctor"].tolist(),
+         }
+
+     except Exception as e:
+         st.error(f"❌ Error loading dataset or embeddings: {e}")
+         return None
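For anyone regenerating the dataset artifacts: load_data() reads each pickle as a dict with an "embeddings" key holding a (num_samples x dim) float array — inferred from the .get('embeddings') calls above; the exact layout of the hosted files is an assumption. Rows should be unit-normalized, since init_faiss uses an inner-product index without renormalizing. A sketch:

import pickle
import numpy as np

emb = np.random.rand(100, 768).astype("float32")
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # unit-normalize so IndexFlatIP acts as cosine
with open("question_embeddings.pkl", "wb") as f:
    pickle.dump({"embeddings": emb}, f)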