import os
import json
import datetime
import threading
from pathlib import Path

import numpy as np
import gradio as gr
from dotenv import load_dotenv
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

# ================= ENV =================
load_dotenv()
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
if not NVIDIA_API_KEY:
    raise RuntimeError("NVIDIA_API_KEY not found")
os.environ["NVIDIA_API_KEY"] = NVIDIA_API_KEY
# ================= CONFIG =================
DAILY_LIMIT = 50                           # max queries per calendar day (shared across all users)
RATE_FILE = Path("rate_limit.json")
EMBEDDINGS_FILE = Path("embeddings.json")
MAX_HISTORY_TURNS = 3                      # keep last 3 Q/A pairs in the prompt
lock = threading.Lock()                    # guards read-modify-write of RATE_FILE
# ================= MODELS =================
embedder = NVIDIAEmbeddings(
    model="nvidia/nv-embed-v1",
    truncate="END",    # truncate over-long inputs instead of erroring
)
llm = ChatNVIDIA(
    model="mistralai/mixtral-8x22b-instruct-v0.1",
    temperature=0.2,   # low temperature for grounded, repeatable answers
)
# ================= LOAD DOCS =================
with open(EMBEDDINGS_FILE) as f:
    DOCS = json.load(f)
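# DOCS is expected to be a list of {"text": ..., "embedding": [...]} records.
# A minimal offline build sketch, assuming a pre-chunked `chunks` list
# (hypothetical; not part of this app's runtime path):
#
#   chunks = ["...chunk 1...", "...chunk 2..."]
#   vectors = embedder.embed_documents(chunks)
#   EMBEDDINGS_FILE.write_text(json.dumps(
#       [{"text": t, "embedding": v} for t, v in zip(chunks, vectors)]
#   ))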
# ================= UTILS =================
def cosine(a, b):
    """Cosine similarity between two embedding vectors."""
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def retrieve(question, k=4):
    """Embed the question and return the k most similar document chunks."""
    qvec = embedder.embed_query(question)
    scored = [(cosine(qvec, d["embedding"]), d["text"]) for d in DOCS]
    scored.sort(key=lambda pair: pair[0], reverse=True)  # rank by score only, not by text on ties
    return [text for _, text in scored[:k]]
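# For larger corpora, the same ranking can be computed without the
# per-document Python loop. A vectorized sketch (assumes every stored
# embedding shares one dimensionality):
#
#   DOC_MATRIX = np.array([d["embedding"] for d in DOCS])
#   DOC_NORMS = np.linalg.norm(DOC_MATRIX, axis=1)
#
#   def retrieve_fast(question, k=4):
#       q = np.array(embedder.embed_query(question))
#       scores = DOC_MATRIX @ q / (DOC_NORMS * np.linalg.norm(q))
#       return [DOCS[i]["text"] for i in np.argsort(scores)[::-1][:k]]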
def check_rate_limit():
    """Increment today's counter; return False once DAILY_LIMIT is reached."""
    today = datetime.date.today().isoformat()
    with lock:
        data = json.loads(RATE_FILE.read_text()) if RATE_FILE.exists() else {}
        if data.get(today, 0) >= DAILY_LIMIT:
            return False
        data[today] = data.get(today, 0) + 1
        RATE_FILE.write_text(json.dumps(data))
        return True
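# Note: counters for past days accumulate in rate_limit.json indefinitely.
# If that matters, a one-line prune inside the lock (sketch) keeps only
# today's entry:
#
#   data = {today: data.get(today, 0)}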
def build_prompt(context, history, question):
    history_text = "\n".join(
        f"User: {q}\nAssistant: {a}"
        for q, a in history[-MAX_HISTORY_TURNS:]
    )
    context_text = "\n\n---\n\n".join(context)
    return f"""You are a document-grounded assistant.
Answer ONLY using the context.
If the answer is not present, say "I don't know".

Conversation so far:
{history_text}

Context:
{context_text}

User question:
{question}""".strip()
# ================= CHAT FN (STREAMING) =================
def chat_stream(question, history):
    if not question.strip():
        yield history
        return
    if not check_rate_limit():
        history.append((question, f"Daily limit reached ({DAILY_LIMIT} queries)."))
        yield history
        return
    context = retrieve(question)
    prompt = build_prompt(context, history, question)
    partial = ""
    for chunk in llm.stream(prompt):
        partial += chunk.content or ""  # some chunks may arrive with empty content
        yield history + [(question, partial)]
# ================= UI =================
with gr.Blocks(title="Academic Regulations RAG") as demo:
    gr.Markdown("## Academic Regulations Queries")
    gr.Markdown(
        "Ask questions about the academic regulations document. "
        "Answers are generated **only** from the official document."
    )
    chatbot = gr.Chatbot(height=420)
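    # This app uses the tuple-style Chatbot history: [(user, bot), ...].
    # Newer Gradio releases prefer gr.Chatbot(type="messages"), which would
    # require dict-style entries in chat_stream instead.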
    with gr.Row():  # scale= only takes effect inside a Row, giving the intended 4:1 layout
        question = gr.Textbox(
            placeholder="e.g. What is the E grade?",
            label="Your question",
            scale=4,
        )
        ask = gr.Button("Ask", scale=1, min_width=100)
    clear = gr.Button("Clear Chat")
    ask.click(chat_stream, [question, chatbot], chatbot)
    question.submit(
        chat_stream,
        inputs=[question, chatbot],
        outputs=chatbot,
    )
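    # To also clear the input box after each question, the returned event
    # can be chained (sketch):
    #
    #   question.submit(chat_stream, [question, chatbot], chatbot) \
    #       .then(lambda: "", None, question)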
    clear.click(lambda: [], None, chatbot)

# ================= RUN =================
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
    )