import os
import json
import datetime
import threading
from pathlib import Path
import numpy as np
import gradio as gr
from dotenv import load_dotenv
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
# ================= ENV =================
load_dotenv()
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
if not NVIDIA_API_KEY:
    raise RuntimeError("NVIDIA_API_KEY not found")
os.environ["NVIDIA_API_KEY"] = NVIDIA_API_KEY
# ================= CONFIG =================
DAILY_LIMIT = 50
RATE_FILE = Path("rate_limit.json")
EMBEDDINGS_FILE = Path("embeddings.json")
MAX_HISTORY_TURNS = 3 # keep last 3 Q/A pairs
lock = threading.Lock()
# ================= MODELS =================
embedder = NVIDIAEmbeddings(
    model="nvidia/nv-embed-v1",
    truncate="END",
)
llm = ChatNVIDIA(
    model="mistralai/mixtral-8x22b-instruct-v0.1",
    temperature=0.2,
)
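# Both model IDs come from the NVIDIA API catalog; in principle any other
# chat / embedding model served by these endpoints could be dropped in here,
# though that swap is untested in this app.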
# ================= LOAD DOCS =================
with open(EMBEDDINGS_FILE) as f:
    DOCS = json.load(f)
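# EMBEDDINGS_FILE is expected to hold a JSON list of {"text": ..., "embedding": [...]}
# objects (that is the shape retrieve() reads below). A minimal sketch of the
# offline step that could produce it, assuming a chunking helper of your own
# (my_chunker is hypothetical, not part of this app):
#
#   chunks = my_chunker("regulations.pdf")
#   vecs = embedder.embed_documents(chunks)
#   docs = [{"text": c, "embedding": v} for c, v in zip(chunks, vecs)]
#   EMBEDDINGS_FILE.write_text(json.dumps(docs))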
# ================= UTILS =================
def cosine(a, b):
    """Cosine similarity between two vectors: dot(a, b) / (|a| * |b|)."""
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
def retrieve(question, k=4):
    """Return the k document chunks most similar to the question."""
    qvec = embedder.embed_query(question)
    scored = [(cosine(qvec, d["embedding"]), d["text"]) for d in DOCS]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [t for _, t in scored[:k]]
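# retrieve() is a brute-force O(N) scan over every chunk, which is fine for a
# single document. For a larger corpus a vector index would be the usual swap;
# a rough sketch with FAISS (assumes `pip install faiss-cpu` and that all
# embeddings share one dimension):
#
#   import faiss
#   mat = np.array([d["embedding"] for d in DOCS], dtype="float32")
#   faiss.normalize_L2(mat)  # cosine == inner product on unit vectors
#   index = faiss.IndexFlatIP(mat.shape[1])
#   index.add(mat)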
def check_rate_limit():
    """Increment today's query counter; return False once DAILY_LIMIT is hit."""
    today = datetime.date.today().isoformat()
    with lock:
        data = json.loads(RATE_FILE.read_text()) if RATE_FILE.exists() else {}
        if data.get(today, 0) >= DAILY_LIMIT:
            return False
        data[today] = data.get(today, 0) + 1
        RATE_FILE.write_text(json.dumps(data))
        return True
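# Note: the limit is global (every user shares one budget of DAILY_LIMIT
# queries per day), each date gets its own key so it resets at midnight, and
# past days' keys accumulate in rate_limit.json. The threading.Lock only
# guards against concurrent request threads within this single process.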
def build_prompt(context, history, question):
    """Assemble the grounded prompt from context chunks, recent history, and the question."""
    history_text = "\n".join(
        f"User: {q}\nAssistant: {a}"
        for q, a in history[-MAX_HISTORY_TURNS:]
    )
    context_text = "\n\n---\n\n".join(context)
    return f"""You are a document-grounded assistant.
Answer ONLY using the context.
If the answer is not present, say "I don't know".

Conversation so far:
{history_text}

Context:
{context_text}

User question:
{question}""".strip()
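# `history` here is Gradio's tuples-format chat log, a list of
# (user, assistant) pairs, so only the last MAX_HISTORY_TURNS exchanges are
# replayed to the model.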
# ================= CHAT FN (STREAMING) =================
def chat_stream(question, history):
    """Generator: yields the chat history with the answer streamed in."""
    if not question.strip():
        yield history
        return
    if not check_rate_limit():
        history.append((question, f"Daily limit reached ({DAILY_LIMIT} queries)."))
        yield history
        return
    context = retrieve(question)
    prompt = build_prompt(context, history, question)
    partial = ""
    for chunk in llm.stream(prompt):
        partial += chunk.content
        yield history + [(question, partial)]
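# Each yield hands Gradio a fresh history list ending in the partially
# generated answer, so the chatbot re-renders as tokens arrive; the final
# yield already contains the complete turn, so nothing is appended afterwards.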
# ================= UI =================
with gr.Blocks(title="Academic Regulations RAG") as demo:
    gr.Markdown("## 📘 Academic Regulations Queries")
    gr.Markdown(
        "Ask questions about the academic regulations document. "
        "Answers are generated **only** from the official document."
    )
    chatbot = gr.Chatbot(height=420)
    with gr.Row():  # scale= only takes effect inside a Row
        question = gr.Textbox(
            placeholder="e.g. What is the E grade?",
            label="Your question",
            scale=4,
        )
        ask = gr.Button("Ask", scale=1, min_width=100)
    clear = gr.Button("Clear Chat")
    ask.click(chat_stream, inputs=[question, chatbot], outputs=chatbot)
    question.submit(chat_stream, inputs=[question, chatbot], outputs=chatbot)
    clear.click(lambda: [], None, chatbot)
# ================= RUN =================
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
    )
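# Run locally with `python app.py`; NVIDIA_API_KEY can come from a .env file
# and PORT defaults to 7860.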