Upload 17 files
- app.py +139 -0
- src/__pycache__/llm.cpython-313.pyc +0 -0
- src/__pycache__/pdf_utils.cpython-313.pyc +0 -0
- src/__pycache__/query_utils.cpython-313.pyc +0 -0
- src/__pycache__/search.cpython-313.pyc +0 -0
- src/__pycache__/ui.cpython-313.pyc +0 -0
- src/__pycache__/utils.cpython-313.pyc +0 -0
- src/llm.py +52 -0
- src/pdf_utils.py +87 -0
- src/query_utils.py +86 -0
- src/search.py +271 -0
- src/ui.py +124 -0
- src/utils.py +94 -0
app.py
ADDED
@@ -0,0 +1,139 @@
import streamlit as st
from src.utils import load_model, load_data
from src.query_utils import QueryEnhancer
from src.search import hybrid_search, set_description_texts, set_patient_texts, strong_recall_indices
from src.ui import apply_custom_css, render_header, render_sidebar, render_chat_history, bot_typing_animation
from src.pdf_utils import build_chat_pdf
from src.llm import summarize_with_gemini
import numpy as np

def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

def filter_similar_answers(indices, doctor_embeddings, threshold=0.88):
    """
    Filters out semantically similar answers based on cosine similarity.

    Args:
        indices (list[int]): candidate indices of top answers (assumed sorted by relevance)
        doctor_embeddings (np.array): embeddings of all doctor's answers (num_samples x dim)
        threshold (float): similarity above which answers are considered duplicates

    Returns:
        filtered_indices (list[int]): indices of diverse answers
    """
    if len(indices) == 0:
        return []

    filtered = [indices[0]]  # Always keep the first (most relevant) one
    for idx in indices[1:]:
        emb = doctor_embeddings[idx]
        keep = True
        for f_idx in filtered:
            existing_emb = doctor_embeddings[f_idx]
            sim = np.dot(emb, existing_emb)  # Cosine similarity (embeddings are pre-normalized)
            if sim >= threshold:
                keep = False
                break
        if keep:
            filtered.append(idx)
    return filtered

# --- 1. Setup UI ---
apply_custom_css()
render_header()
render_sidebar()

# --- 2. Load Model & Data ---
model = load_model()
data = load_data()
google_api_key = st.secrets.get("GOOGLE_API_KEY")

if not data:
    st.error("❌ Could not load dataset or embeddings. Please check your paths.")
    st.stop()

query_enhancer = QueryEnhancer(model)

# --- 3. Set texts for recall ---
set_description_texts(data['description_column'])
set_patient_texts(data['patient_column'])

# --- 4. Initialize Chat History ---
if "messages" not in st.session_state:
    st.session_state.messages = []

render_chat_history(st.session_state.messages)

# --- 5. PDF export sidebar ---
with st.sidebar:
    st.markdown("---")
    st.subheader("Export")
    if st.session_state.get("messages"):
        pdf_buffer = build_chat_pdf(
            st.session_state.messages, title="MediLingua Chat Transcript"
        )
        if pdf_buffer:
            st.download_button(
                label="📄 Download chat as PDF",
                data=pdf_buffer,
                file_name="medilingua_chat.pdf",
                mime="application/pdf",
            )
        else:
            st.caption(
                "Install `reportlab` to enable PDF export: pip install reportlab"
            )

# --- 6. Handle User Input ---
if prompt := st.chat_input("What is your medical question?"):
    # Show user input immediately
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Placeholder for bot response
    with st.chat_message("assistant"):
        message_placeholder = st.empty()

        # Step 1: Enhance query
        enhanced_query = query_enhancer.enhance_query(prompt)

        # Step 2: Strong recall + hybrid search
        with st.spinner("🔍 Searching for top relevant answers..."):
            # 1) Get top-k candidates
            indices = strong_recall_indices(prompt, top_k=10)
            if not indices or len(indices) < 3:
                indices = hybrid_search(
                    model,
                    data['question_embeddings'],
                    user_query_raw=prompt,
                    user_query_enhanced=enhanced_query,
                    top_k=10,
                    weight_semantic=0.7,
                    faiss_top_candidates=256,
                    use_exact_match=True,
                    use_fuzzy_match=True
                )

            # 2) Filter out semantically similar doctor answers to ensure diversity
            indices = filter_similar_answers(indices, data['doctor_embeddings'], threshold=0.88)

        # Step 3: Summarization / Gemini
        if indices is not None and len(indices) > 0:
            # Gather ALL filtered diverse answers for summarization (used as reference)
            top_answers = [data['original_answers'][i] for i in indices]
            combined_text = " ".join(top_answers)

            summary = summarize_with_gemini(google_api_key, combined_text, prompt)

            # Show only the AI's summarized answer (no doctor's notes displayed)
            bot_typing_animation(message_placeholder, summary)

            response = summary
        else:
            response = "✖️ I couldn't find any contextually similar answer in the dataset."
            bot_typing_animation(message_placeholder, response)

    # Save conversation
    st.session_state.messages.append({"role": "assistant", "content": response})
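For a quick sanity check of the diversity filter, here is a minimal sketch on toy data (the two-dimensional vectors are illustrative stand-ins for the real doctor embeddings, and it assumes filter_similar_answers from app.py is in scope):

import numpy as np

# Toy normalized embeddings: rows 0 and 1 are near-duplicates, row 2 is distinct.
toy_embeddings = np.array([
    [1.0, 0.0],
    [0.999, 0.045],  # cosine similarity with row 0 is ~0.999, above the 0.88 threshold
    [0.0, 1.0],      # orthogonal to row 0
], dtype=np.float32)
toy_embeddings /= np.linalg.norm(toy_embeddings, axis=1, keepdims=True)

# Row 1 is dropped as a near-duplicate of row 0; row 2 survives.
print(filter_similar_answers([0, 1, 2], toy_embeddings, threshold=0.88))  # -> [0, 2]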
src/__pycache__/llm.cpython-313.pyc
ADDED
Binary file (3.77 kB)
src/__pycache__/pdf_utils.cpython-313.pyc
ADDED
Binary file (3.46 kB)
src/__pycache__/query_utils.cpython-313.pyc
ADDED
Binary file (4.41 kB)
src/__pycache__/search.cpython-313.pyc
ADDED
Binary file (11.9 kB)
src/__pycache__/ui.cpython-313.pyc
ADDED
Binary file (5.37 kB)
src/__pycache__/utils.cpython-313.pyc
ADDED
Binary file (4.97 kB)
src/llm.py
ADDED
@@ -0,0 +1,52 @@
import requests
import json
import time

import streamlit as st


def summarize_with_gemini(api_key, doctor_answer, user_question, max_retries=2):
    if not api_key:
        st.warning("⚠️ Google API Key missing. Showing full answer instead.")
        return doctor_answer if isinstance(doctor_answer, str) else "\n\n".join(doctor_answer)

    combined_answer = "\n\n---\n\n".join(doctor_answer) if isinstance(doctor_answer, list) else doctor_answer

    candidate_models = ["gemini-2.0-flash", "gemini-2.5-flash", "gemini-2.0-pro"]

    for model_name in candidate_models:
        prompt = f"""You are a professional AI medical assistant.

Summarize the doctor's responses clearly, accurately, and concisely for the patient.
Focus only on medically relevant information that directly answers the user's question.

User's Question:
"{user_question}"

Doctor's Answer(s):
"{combined_answer}"

Instructions:
- Provide a medically correct, patient-friendly summary in simple, clear language.
- List multiple points as bullets if possible.
- If the user's question lacks personal details (e.g., gender, age, weight), generate a generalized, gender-neutral summary.
- Avoid gender-specific recommendations (e.g., consulting a gynecologist) unless the query explicitly mentions gender or related details.
- Don't forget to mention potential next steps, treatments, or lifestyle changes if the doctor's answer mentions them.
- Always include a recommendation to consult a relevant doctor type (e.g., general practitioner, orthopedist) at the end of the summary, unless the doctor's answers already specify a consultation with a specific doctor type.
- For example, for back pain, recommend consulting an orthopedist or general practitioner unless the query or doctor's answers suggest a more specific specialist.
- If the doctor's response does not address the question, respond:
"There is no information related to your question in the doctor's answer, so I generated the best possible answer based on the information provided."""

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={api_key}"
        payload = {"contents": [{"parts": [{"text": prompt}]}]}
        headers = {"Content-Type": "application/json"}

        for attempt in range(max_retries):
            try:
                resp = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
                resp.raise_for_status()
                result = resp.json()
                if "candidates" in result and result["candidates"]:
                    return result["candidates"][0]["content"]["parts"][0]["text"].strip()
            except requests.exceptions.HTTPError:
                if resp.status_code == 404:
                    break  # Model not available; move on to the next candidate model
                time.sleep(1)
            except Exception:
                time.sleep(1)

    st.warning("✖️ Could not generate summary. Showing original answer.")
    return combined_answer
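A minimal usage sketch of the summarizer (the answers below are made-up placeholders; with api_key=None the function skips the API entirely and returns the joined answers, which makes it easy to smoke-test offline):

from src.llm import summarize_with_gemini

answers = [
    "Rest, fluids, and paracetamol usually settle a mild fever within a few days.",
    "See a doctor if the fever lasts more than three days or keeps rising.",
]
# No API key: falls back to returning the combined answer verbatim.
print(summarize_with_gemini(None, answers, "How do I treat a mild fever at home?"))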
src/pdf_utils.py
ADDED
@@ -0,0 +1,87 @@
import io
from typing import List, Dict

import streamlit as st


def _import_reportlab():
    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib.styles import getSampleStyleSheet
        from reportlab.lib.units import mm
        from reportlab.pdfgen import canvas
        from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
        from reportlab.lib import colors
        return {
            "A4": A4,
            "getSampleStyleSheet": getSampleStyleSheet,
            "mm": mm,
            "canvas": canvas,
            "Paragraph": Paragraph,
            "SimpleDocTemplate": SimpleDocTemplate,
            "Spacer": Spacer,
            "Table": Table,
            "TableStyle": TableStyle,
            "colors": colors,
        }
    except Exception:
        return None


def build_chat_pdf(messages: List[Dict[str, str]], title: str = "MediLingua Chat Transcript") -> io.BytesIO:
    """
    Create a PDF from chat messages and return it as an in-memory BytesIO buffer.
    Each message is a dict with keys: role ('user'|'assistant'), content (str).
    """
    libs = _import_reportlab()
    if libs is None:
        st.error(
            "PDF generation library not found. Install with: `pip install reportlab` and rerun."
        )
        return None

    buffer = io.BytesIO()

    doc = libs["SimpleDocTemplate"](buffer, pagesize=libs["A4"], rightMargin=28, leftMargin=28, topMargin=36, bottomMargin=28)
    styles = libs["getSampleStyleSheet"]()

    elements = []

    # Title
    title_style = styles["Title"]
    elements.append(libs["Paragraph"](title, title_style))
    elements.append(libs["Spacer"](1, 12))

    # Build a table-like layout for messages
    data = []
    table_style_cmds = [
        ("VALIGN", (0, 0), (-1, -1), "TOP"),
        ("INNERGRID", (0, 0), (-1, -1), 0.25, libs["colors"].lightgrey),
        ("BOX", (0, 0), (-1, -1), 0.25, libs["colors"].lightgrey),
        ("LEFTPADDING", (0, 0), (-1, -1), 6),
        ("RIGHTPADDING", (0, 0), (-1, -1), 6),
        ("TOPPADDING", (0, 0), (-1, -1), 6),
        ("BOTTOMPADDING", (0, 0), (-1, -1), 6),
    ]

    role_style = styles["Heading5"]
    body_style = styles["BodyText"]

    for msg in messages:
        role = msg.get("role", "").capitalize()
        content = msg.get("content", "")

        # Left column: role; right column: content paragraph
        role_par = libs["Paragraph"](f"<b>{role}</b>", role_style)
        content_par = libs["Paragraph"](content.replace("\n", "<br/>"), body_style)
        data.append([role_par, content_par])

    table = libs["Table"](data, colWidths=[30 * libs["mm"], None])
    table.setStyle(libs["TableStyle"](table_style_cmds))
    elements.append(table)

    doc.build(elements)
    buffer.seek(0)
    return buffer
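To exercise build_chat_pdf outside the app, a short sketch (assumes reportlab is installed; otherwise the function returns None and the guard skips the write):

from src.pdf_utils import build_chat_pdf

messages = [
    {"role": "user", "content": "What causes migraines?"},
    {"role": "assistant", "content": "Common triggers include stress, poor sleep, and certain foods."},
]
buf = build_chat_pdf(messages, title="Sample Transcript")
if buf is not None:
    with open("sample_transcript.pdf", "wb") as f:
        f.write(buf.getvalue())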
src/query_utils.py
ADDED
@@ -0,0 +1,86 @@
import re
import nltk
from nltk.tokenize import RegexpTokenizer
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from keybert import KeyBERT

# --- Download NLTK resources if needed ---
try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords', quiet=True)
    nltk.download('punkt', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
    nltk.download('wordnet', quiet=True)

# --- Initialize tools ---
tokenizer = RegexpTokenizer(r'\w+')
lemmatizer = WordNetLemmatizer()
custom_stopwords = set(stopwords.words('english')) - {'no', 'not', 'without', 'due', 'to', 'with', 'on', 'in'}

# --- Medical synonym expansion ---
medical_synonyms = {
    "flu": ["influenza"],
    "cold": ["common cold", "rhinitis"],
    "heart attack": ["myocardial infarction"],
    "diabetes": ["high blood sugar", "hyperglycemia"],
    "bp": ["blood pressure", "hypertension"],
    "hypertension": ["high blood pressure"],
    "asthma": ["respiratory disease"],
    "cough": ["dry cough", "wet cough"],
    "fever": ["temperature", "high fever"]
}

def expand_medical_terms(text: str) -> str:
    """Expands known medical terms with their synonyms for better recall."""
    for key, syns in medical_synonyms.items():
        for syn in syns:
            text = re.sub(rf"\b{key}\b", f"{key} {syn}", text, flags=re.IGNORECASE)
    return text

def preprocess_text(text: str) -> str:
    """Minimal preprocessing: lowercase, remove punctuation, collapse spaces."""
    text = str(text).lower()
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

class QueryEnhancer:
    """
    Wrapper class to handle query enhancement with local SapBERT + KeyBERT.
    """
    def __init__(self, sentence_transformer_model):
        """
        sentence_transformer_model: the already-loaded local SapBERT SentenceTransformer
        """
        self.kw_model = KeyBERT(sentence_transformer_model)

    def extract_keywords(self, text: str, top_n: int = 5) -> list:
        """Extracts top keywords using KeyBERT."""
        if not self.kw_model:
            return []
        try:
            keywords = self.kw_model.extract_keywords(
                text,
                keyphrase_ngram_range=(1, 2),
                stop_words='english',
                top_n=top_n
            )
            return [kw[0] for kw in keywords]
        except Exception:
            return []

    def enhance_query(self, user_query: str) -> str:
        """
        Full query enhancement pipeline:
        - Preprocess text
        - Expand medical synonyms
        - Extract keywords
        - Return the combined enhanced query string
        """
        preprocessed = preprocess_text(user_query)
        expanded = expand_medical_terms(preprocessed)
        keywords = self.extract_keywords(user_query)
        enhanced_query = f"{expanded} {' '.join(keywords)}".strip()
        return enhanced_query
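Illustrative behavior of the enhancement pipeline, as a sketch (the MiniLM model here is a lightweight stand-in for SapBERT, and the exact keywords KeyBERT returns will vary by model):

from sentence_transformers import SentenceTransformer
from src.query_utils import QueryEnhancer

model = SentenceTransformer("all-MiniLM-L6-v2")  # stand-in for the real SapBERT model
enhancer = QueryEnhancer(model)

# "bp" and "cough" are expanded with synonyms, then KeyBERT keywords are appended.
print(enhancer.enhance_query("I have high BP and a dry cough."))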
src/search.py
ADDED
@@ -0,0 +1,271 @@
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import faiss
import re

# --- Global caches ---
tfidf_vectorizer = None
tfidf_matrix = None
corpus_texts = None
faiss_index = None
embeddings_array = None        # FAISS requires float32
description_texts = None       # For exact/fuzzy match
patient_texts = None           # For exact/fuzzy match
description_norm_texts = None  # Normalized (punctuation stripped)
patient_norm_texts = None      # Normalized (punctuation stripped)


def encode_question(model, user_question):
    """Encodes the user's question using the embedding model."""
    if model is None or not user_question.strip():
        return None
    return model.encode([user_question], show_progress_bar=False)[0].astype('float32')


def init_tfidf(data_texts):
    """Initialize the TF-IDF matrix for hybrid search."""
    global tfidf_vectorizer, tfidf_matrix, corpus_texts
    corpus_texts = data_texts
    tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
    tfidf_matrix = tfidf_vectorizer.fit_transform(corpus_texts)


def init_faiss(embeddings):
    """
    Initialize the FAISS index for fast semantic search.
    embeddings: np.array (num_samples x 768), already normalized.
    """
    global faiss_index, embeddings_array
    embeddings_array = embeddings.astype('float32')

    # Do not renormalize
    dimension = embeddings_array.shape[1]
    faiss_index = faiss.IndexFlatIP(dimension)  # Inner product == cosine similarity for normalized vectors
    faiss_index.add(embeddings_array)


def set_description_texts(texts):
    """Provide the Description column for exact/fuzzy match search."""
    global description_texts, description_norm_texts
    description_texts = [str(t).lower() for t in texts]
    description_norm_texts = [preprocess_text_for_embeddings(t) for t in texts]


def set_patient_texts(texts):
    """Provide the Patient column for exact/fuzzy match search."""
    global patient_texts, patient_norm_texts
    patient_texts = [str(t).lower() for t in texts]
    patient_norm_texts = [preprocess_text_for_embeddings(t) for t in texts]


def strong_recall_indices(user_query_raw: str, top_k: int = 10):
    """
    Scan the entire dataset (Description + Patient) for:
      1) Exact equality on normalized text
      2) Exact substring presence
      3) High-threshold fuzzy match (if rapidfuzz is available)

    Returns a list of unique indices, in priority order, up to top_k.
    """
    global description_texts, patient_texts, description_norm_texts, patient_norm_texts

    if not user_query_raw:
        return []

    q_lower = str(user_query_raw).lower()
    q_norm = preprocess_text_for_embeddings(user_query_raw)

    N_desc = len(description_texts) if description_texts is not None else 0
    N_pat = len(patient_texts) if patient_texts is not None else 0
    N = max(N_desc, N_pat)
    if N == 0:
        return []

    exact_equal = []
    exact_sub = []

    # 1) Exact equality on normalized text
    if description_norm_texts is not None:
        exact_equal += [i for i in range(len(description_norm_texts)) if description_norm_texts[i] == q_norm]
    if patient_norm_texts is not None:
        exact_equal += [i for i in range(len(patient_norm_texts)) if patient_norm_texts[i] == q_norm]

    # Deduplicate, preserving order
    seen = set()
    ordered = []
    for i in exact_equal:
        if i not in seen:
            seen.add(i)
            ordered.append(i)
    if len(ordered) >= top_k:
        return ordered[:top_k]

    # 2) Exact substring presence (lowercased)
    if description_texts is not None:
        exact_sub += [i for i in range(len(description_texts)) if q_lower in description_texts[i]]
    if patient_texts is not None:
        exact_sub += [i for i in range(len(patient_texts)) if q_lower in patient_texts[i]]

    for i in exact_sub:
        if i not in seen:
            seen.add(i)
            ordered.append(i)
    if len(ordered) >= top_k:
        return ordered[:top_k]

    # 3) High-threshold fuzzy matches
    try:
        from rapidfuzz import fuzz
        # Use partial_ratio and token_set_ratio; take the max as the score
        scored = []
        for i in range(N):
            s_desc = description_texts[i] if (description_texts is not None and i < len(description_texts)) else ""
            s_pat = patient_texts[i] if (patient_texts is not None and i < len(patient_texts)) else ""
            score_desc = max(fuzz.partial_ratio(q_lower, s_desc), fuzz.token_set_ratio(q_lower, s_desc)) if s_desc else 0
            score_pat = max(fuzz.partial_ratio(q_lower, s_pat), fuzz.token_set_ratio(q_lower, s_pat)) if s_pat else 0
            score = max(score_desc, score_pat)
            if score >= 90:
                scored.append((i, score))
        # Sort by score, descending
        scored.sort(key=lambda x: x[1], reverse=True)
        for i, _ in scored:
            if i not in seen:
                seen.add(i)
                ordered.append(i)
            if len(ordered) >= top_k:
                break
    except Exception:
        pass

    return ordered[:top_k]


def hybrid_search(
    model,
    embeddings,
    user_query_raw,
    user_query_enhanced,
    top_k=5,
    weight_semantic=0.7,
    faiss_top_candidates=256,
    use_exact_match=True,
    use_fuzzy_match=True
):
    """
    Hybrid search combining:
      1. FAISS semantic similarity
      2. TF-IDF boosting
      3. Optional exact, fuzzy, and token-overlap boosts across Description & Patient

    Returns: array of top indices into the dataset.
    """
    global tfidf_vectorizer, tfidf_matrix, corpus_texts, faiss_index, embeddings_array, description_texts

    if model is None or embeddings is None or len(embeddings) == 0:
        return []

    # Encode the enhanced query for the semantic/TF-IDF stages
    question_embedding = encode_question(model, user_query_enhanced)
    if question_embedding is None:
        return []

    # --- 1. FAISS semantic search ---
    if faiss_index is not None:
        D, I = faiss_index.search(np.array([question_embedding]), k=min(faiss_top_candidates, embeddings.shape[0]))
        top_candidates = I[0]
        semantic_sim_top = D[0]
    else:
        semantic_sim_full = np.dot(embeddings, question_embedding)
        top_candidates = np.argpartition(semantic_sim_full, -faiss_top_candidates)[-faiss_top_candidates:]
        top_candidates = top_candidates[np.argsort(semantic_sim_full[top_candidates])[::-1]]
        semantic_sim_top = semantic_sim_full[top_candidates]

    # --- 2. TF-IDF similarity ---
    if tfidf_vectorizer is not None and tfidf_matrix is not None:
        tfidf_vec = tfidf_vectorizer.transform([user_query_enhanced])
        tfidf_sim_top = (tfidf_matrix[top_candidates] @ tfidf_vec.T).toarray().ravel()
    else:
        tfidf_sim_top = np.zeros(len(top_candidates))

    # --- 3. Optional exact + fuzzy match across Description & Patient ---
    combined_sim_top = weight_semantic * semantic_sim_top + (1 - weight_semantic) * tfidf_sim_top

    if use_exact_match or use_fuzzy_match:
        query_lower = user_query_raw.lower()

        # Exact substring presence boosts
        exact_desc = np.zeros(len(top_candidates))
        exact_pat = np.zeros(len(top_candidates))
        if description_texts is not None:
            exact_desc = np.array([1.0 if query_lower in description_texts[i] else 0.0 for i in top_candidates])
        if patient_texts is not None:
            exact_pat = np.array([1.0 if query_lower in patient_texts[i] else 0.0 for i in top_candidates])

        # Fuzzy partial ratio via rapidfuzz (graceful fallback)
        fuzzy_desc = np.zeros(len(top_candidates))
        fuzzy_pat = np.zeros(len(top_candidates))
        if use_fuzzy_match:
            try:
                from rapidfuzz import fuzz
                if description_texts is not None:
                    fuzzy_desc = np.array([
                        fuzz.partial_ratio(query_lower, description_texts[i]) / 100.0 for i in top_candidates
                    ])
                if patient_texts is not None:
                    fuzzy_pat = np.array([
                        fuzz.partial_ratio(query_lower, patient_texts[i]) / 100.0 for i in top_candidates
                    ])
            except Exception:
                pass

        # Token overlap (Jaccard) as an additional weak signal
        def jaccard(a: str, b: str) -> float:
            sa = set(a.split())
            sb = set(b.split())
            if not sa or not sb:
                return 0.0
            inter = len(sa & sb)
            union = len(sa | sb)
            return inter / union if union else 0.0

        token_desc = np.zeros(len(top_candidates))
        token_pat = np.zeros(len(top_candidates))
        if description_texts is not None:
            token_desc = np.array([jaccard(query_lower, description_texts[i]) for i in top_candidates])
        if patient_texts is not None:
            token_pat = np.array([jaccard(query_lower, patient_texts[i]) for i in top_candidates])

        # Combine boosters with gentle weights; exact match is strongest
        booster = 0.20 * exact_desc + 0.20 * exact_pat + 0.10 * fuzzy_desc + 0.10 * fuzzy_pat + 0.05 * token_desc + 0.05 * token_pat
        combined_sim_top = combined_sim_top + booster

    # --- 4. Select final top-k indices ---
    sorted_top_indices = top_candidates[np.argsort(combined_sim_top)[::-1][:top_k]]

    return sorted_top_indices


# --- Minimal preprocessing for embeddings ---
def preprocess_text_for_embeddings(text: str) -> str:
    """Lowercase + remove punctuation for embeddings."""
    text = str(text).lower()
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


# --- Minimal preprocessing for keywords ---
def preprocess_text_for_keywords(text: str) -> str:
    """Lowercase + remove punctuation for keywords."""
    text = str(text).lower()
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
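A minimal end-to-end sketch of the retrieval path on a toy corpus (MiniLM stands in for SapBERT here; the real app feeds precomputed embeddings from utils.load_data):

from sentence_transformers import SentenceTransformer
from src.search import init_faiss, init_tfidf, set_description_texts, set_patient_texts, hybrid_search

corpus = [
    "persistent dry cough and mild fever for three days",
    "lower back pain after lifting heavy objects",
    "itchy skin rash on both arms",
]
model = SentenceTransformer("all-MiniLM-L6-v2")  # stand-in for SapBERT
emb = model.encode(corpus, normalize_embeddings=True).astype("float32")

init_faiss(emb)                # semantic index over the corpus
init_tfidf(corpus)             # lexical TF-IDF index
set_description_texts(corpus)  # enables exact/fuzzy/Jaccard boosts
set_patient_texts(corpus)

idx = hybrid_search(model, emb, user_query_raw="back pain",
                    user_query_enhanced="back pain",
                    top_k=2, faiss_top_candidates=3)
print(idx)  # the back-pain document (index 1) should rank first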
src/ui.py
ADDED
@@ -0,0 +1,124 @@
import streamlit as st
import time

def apply_custom_css():
    """Applies custom CSS for proper left-right chat alignment with enhanced colors."""
    css = """
    <style>
    .stApp {
        background-color: #0f172a;
        color: #e2e8f0;
        font-family: 'Inter', sans-serif;
    }
    .block-container {
        max-width: 900px;
        margin: auto;
        padding-top: 1rem;
    }
    /* Chat message layout override */
    [data-testid="stChatMessage"] {
        display: flex !important;
        align-items: flex-start !important;
        margin-bottom: 0.75rem;
    }
    [data-testid="stChatMessage"] > div[data-testid="stMarkdownContainer"] {
        padding: 0.8rem 1rem;
        border-radius: 16px;
        max-width: 70%;
        line-height: 1.5;
        font-size: 0.95rem;
        word-wrap: break-word;
        box-shadow: 0 4px 12px rgba(0,0,0,0.25);
        transition: all 0.2s ease-in-out;
        animation: fadeIn 0.3s ease;
    }
    @keyframes fadeIn {
        from { opacity: 0; transform: translateY(4px); }
        to { opacity: 1; transform: translateY(0); }
    }
    /* Assistant (left) */
    [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="assistant"]) {
        justify-content: flex-start !important;
    }
    [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="assistant"])
    > div[data-testid="stMarkdownContainer"] {
        background-color: #1e293b;
        color: #f1f5f9;
        border: 1px solid #334155;
        text-align: left;
    }
    /* User (right) */
    [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="user"]) {
        justify-content: flex-end !important;
    }
    [data-testid="stChatMessage"]:has(.stChatMessageContent[data-testid="user"])
    > div[data-testid="stMarkdownContainer"] {
        background-color: #2563eb;
        color: white;
        border: 1px solid #1d4ed8;
        text-align: right;
    }
    /* Expander (doctor notes) */
    .streamlit-expanderHeader {
        background: #111827;
        color: #cbd5e1;
        border: 1px solid #374151;
        border-radius: 10px;
    }
    .streamlit-expanderContent {
        background: #0b1220;
        border-left: 2px solid #334155;
    }
    /* Scrollbar style */
    ::-webkit-scrollbar { width: 8px; }
    ::-webkit-scrollbar-thumb { background-color: #334155; border-radius: 10px; }
    /* Header/title */
    h1 {
        text-align: center;
        color: #60a5fa;
        font-weight: 600;
    }
    p[style*='text-align: center;'] {
        color: #94a3b8;
    }
    </style>
    """
    st.markdown(css, unsafe_allow_html=True)


def render_header():
    st.title("🤖 MediLingua: Your Medical Assistant")
    st.markdown(
        "<p style='text-align: center; font-size: 1.1rem;'>Ask medical questions and get summarized answers from real doctor responses.</p>",
        unsafe_allow_html=True
    )


def render_sidebar():
    with st.sidebar:
        st.header("⚙️ Configuration")
        if st.secrets.get("GOOGLE_API_KEY"):
            st.success("✅ Google API Key configured.")
        else:
            st.error("❌ Missing API Key in `.streamlit/secrets.toml`.")
        st.markdown("---")
        st.markdown("💡 Built with **Streamlit** & **Gemini**.")


def render_chat_history(messages):
    """Render previous messages."""
    for message in messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def bot_typing_animation(message_placeholder, final_text, delay=0.02):
    """Simulate a bot typing animation in chat."""
    message_placeholder.markdown("")  # Empty initially
    displayed = ""
    for char in final_text:
        displayed += char
        message_placeholder.markdown(displayed)
        time.sleep(delay)
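The typing animation can be tried in any bare Streamlit script, for example (delay is seconds per character, so long answers render noticeably slower at the default 0.02):

import streamlit as st
from src.ui import bot_typing_animation

placeholder = st.empty()
bot_typing_animation(placeholder, "Hello! How can I help you today?", delay=0.01)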
src/utils.py
ADDED
@@ -0,0 +1,94 @@
import streamlit as st
import pickle
import os
import pandas as pd
from sentence_transformers import SentenceTransformer, models
import torch
import numpy as np
from src.search import init_faiss
from huggingface_hub import hf_hub_download, snapshot_download

# Repo IDs
DATASET_REPO = "param2004/Medilingua-dataset"
MODEL_REPO = "param2004/Medilingua-model"

@st.cache_resource
def load_model():
    """Load SapBERT dynamically from the Hugging Face Hub."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    st.info(f"🔬 Loading SapBERT from Hugging Face Hub on {device.upper()}...")

    # Download the model folder dynamically. models.Transformer expects the
    # model directory (config, tokenizer, weights), not a single weights file,
    # so fetch the whole subfolder rather than just pytorch_model.bin.
    try:
        repo_dir = snapshot_download(
            repo_id=MODEL_REPO,
            allow_patterns="models/SapBERT-from-PubMedBERT-fulltext/*",
        )
        model_path = os.path.join(repo_dir, "models", "SapBERT-from-PubMedBERT-fulltext")

        # Build the SentenceTransformer: transformer backbone + mean pooling
        word_embedding_model = models.Transformer(model_path)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

        st.success("✅ SapBERT loaded successfully from Hub.")
    except Exception as e:
        st.error(f"❌ Failed to load SapBERT from Hub. Details: {e}")
        st.warning("⚠️ Falling back to 'all-MiniLM-L6-v2' model.")
        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    return model


@st.cache_resource
def load_data():
    """Load embeddings and dataset dynamically from the Hugging Face Hub."""
    try:
        # Download embeddings
        question_emb_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename="dataset/question_embeddings.pkl"
        )
        doctor_emb_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename="dataset/doctor_embeddings.pkl"
        )
        dataset_csv_path = hf_hub_download(
            repo_id=DATASET_REPO,
            filename="dataset/dataset.csv"
        )

        # Load embeddings
        with open(question_emb_path, 'rb') as f:
            question_data = pickle.load(f)
            question_embeddings = question_data.get('embeddings').astype('float32')

        with open(doctor_emb_path, 'rb') as f:
            doctor_data = pickle.load(f)
            doctor_embeddings = doctor_data.get('embeddings').astype('float32')

        # Load CSV (assumes the embeddings were built from this cleaned CSV in the same row order)
        df = pd.read_csv(dataset_csv_path)
        df.dropna(subset=['Description', 'Patient', 'Doctor'], inplace=True)
        df.drop_duplicates(inplace=True)

        num_samples = min(len(df), len(question_embeddings), len(doctor_embeddings))
        df = df.iloc[:num_samples]
        question_embeddings = question_embeddings[:num_samples]
        doctor_embeddings = doctor_embeddings[:num_samples]

        st.success(f"✅ Loaded {num_samples} rows with SapBERT embeddings ({question_embeddings.shape[1]} dims)")

        # Initialize FAISS
        init_faiss(question_embeddings)

        return {
            "question_embeddings": question_embeddings,
            "doctor_embeddings": doctor_embeddings,
            "description_column": df["Description"].tolist(),
            "patient_column": df["Patient"].tolist(),
            "original_answers": df["Doctor"].tolist(),
        }

    except Exception as e:
        st.error(f"❌ Error loading dataset or embeddings: {e}")
        return None