import os
import json
import datetime
import threading
from pathlib import Path
import numpy as np
import gradio as gr
from dotenv import load_dotenv
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
# ================= ENV =================
load_dotenv()
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
if not NVIDIA_API_KEY:
    raise RuntimeError("NVIDIA_API_KEY not found")
os.environ["NVIDIA_API_KEY"] = NVIDIA_API_KEY
# ================= CONFIG =================
DAILY_LIMIT = 50
RATE_FILE = Path("rate_limit.json")
EMBEDDINGS_FILE = Path("embeddings.json")
MAX_HISTORY_TURNS = 3 # keep last 3 Q/A pairs
lock = threading.Lock()
# ================= MODELS =================
embedder = NVIDIAEmbeddings(
    model="nvidia/nv-embed-v1",
    truncate="END",
)
llm = ChatNVIDIA(
    model="mistralai/mixtral-8x22b-instruct-v0.1",
    temperature=0.2,
)
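# A low temperature keeps generations close to the retrieved text; raising it
# would trade faithfulness for more varied phrasing.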
# ================= LOAD DOCS =================
with open(EMBEDDINGS_FILE) as f:
    DOCS = json.load(f)
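# embeddings.json is assumed to be a list of {"text": ..., "embedding": [...]}
# records precomputed offline, presumably with the same nv-embed-v1 model used
# above; mixing embedding models would make the cosine scores below meaningless.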
# ================= UTILS =================
def cosine(a, b):
    """Cosine similarity between two vectors."""
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
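# Note: this assumes non-zero vectors; a zero-norm embedding would divide by zero.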
def retrieve(question, k=4):
    """Return the k document chunks most similar to the question."""
    qvec = embedder.embed_query(question)
    scored = [(cosine(qvec, d["embedding"]), d["text"]) for d in DOCS]
    # Sort on the score alone so tied scores never fall back to comparing texts.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [text for _, text in scored[:k]]
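# Example: retrieve("What is the E grade?") returns up to 4 chunks, ranked by
# cosine similarity between the query embedding and each stored chunk embedding.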
def check_rate_limit():
    """Allow at most DAILY_LIMIT queries per calendar day."""
    today = datetime.date.today().isoformat()
    with lock:
        data = json.loads(RATE_FILE.read_text()) if RATE_FILE.exists() else {}
        if data.get(today, 0) >= DAILY_LIMIT:
            return False
        data[today] = data.get(today, 0) + 1
        RATE_FILE.write_text(json.dumps(data))
    return True
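# Note: the limit is global for the whole app, not per user, and counts for
# past days are never pruned from rate_limit.json.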
def build_prompt(context, history, question):
    """Assemble a grounded prompt from retrieved context and recent history."""
    history_text = "\n".join(
        f"User: {q}\nAssistant: {a}" for q, a in history[-MAX_HISTORY_TURNS:]
    )
    context_text = "\n\n---\n\n".join(context)
    return f"""You are a document-grounded assistant.
Answer ONLY using the context.
If the answer is not present, say "I don't know".

Conversation so far:
{history_text}

Context:
{context_text}

User question:
{question}""".strip()
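# There is no token budgeting here: four chunks plus three history turns are
# assumed to fit comfortably within the model's context window.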
# ================= CHAT FN (STREAMING) =================
def chat_stream(question, history):
    """Yield progressive chatbot states while the model streams its answer."""
    if not question.strip():
        yield history
        return
    if not check_rate_limit():
        history.append((question, f"Daily limit reached ({DAILY_LIMIT} queries)."))
        yield history
        return
    context = retrieve(question)
    prompt = build_prompt(context, history, question)
    partial = ""
    for chunk in llm.stream(prompt):
        partial += chunk.content
        yield history + [(question, partial)]
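# Each yield replaces the chatbot state with history plus the in-progress
# (question, partial) pair, so the answer appears to type itself out.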
# ================= UI =================
with gr.Blocks(title="Academic Regulations RAG") as demo:
    gr.Markdown("## 📘 Academic Regulations Queries")
    gr.Markdown(
        "Ask questions about the academic regulations document. "
        "Answers are generated **only** from the official document."
    )
    chatbot = gr.Chatbot(height=420)
    question = gr.Textbox(
        placeholder="e.g. What is the E grade?",
        label="Your question",
        scale=4,
    )
    ask = gr.Button("Ask", scale=1, min_width=100)
    clear = gr.Button("Clear Chat")
    ask.click(chat_stream, inputs=[question, chatbot], outputs=chatbot)
    question.submit(chat_stream, inputs=[question, chatbot], outputs=chatbot)
    clear.click(lambda: [], None, chatbot)
# ================= RUN =================
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
    )