"""
SchoolSpiritΒ AI – Granite‑3.3‑2B chatbot (GradioΒ 4.3, messages API)
────────────────────────────────────────────────────────────────────
β€’ Persistent HF cache: HF_HOME=/data/.huggingface (25Β GB tier)
β€’ Persistent request log: /data/requests.log
β€’ Detailed system prompt (brand + guardrails)
β€’ Traces every request: Received β†’ Prompt β†’ generate() timing
β€’ Cleans replies & removes any stray β€œUser:” / β€œAI:” echoes
"""
# ──────────────────── standard libraries ───────────────────────────────────
from __future__ import annotations
import os, re, time, datetime, traceback
# ───── gradio + hf transformers ────────────────────────────────────────────
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers.utils import logging as hf_logging
# ──────────────────── persistent disk paths ────────────────────────────────
os.environ["HF_HOME"] = "/data/.huggingface" # model / tokenizer cache
LOG_FILE = "/data/requests.log" # simple persistent log
def log(msg: str) -> None:
    """Print + append to /data/requests.log with a UTC timestamp."""
    ts = datetime.datetime.utcnow().strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:  # /data may not exist yet on first run – logging is best-effort
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        pass
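# Example log line this produces (illustrative): "[14:02:11.532] Model loaded ✔"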
# ──────────────────── chatbot configuration ────────────────────────────────
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2 B params, Apache-2.0
MAX_TURNS = 6  # keep last N user/assistant pairs
MAX_TOKENS = 128  # reply length (raise if you have patience)
MAX_INPUT_CH = 400  # user message length guard
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the friendly digital mascot for a company "
    "that provides on-prem AI chat mascots, fine-tuning services, and turnkey "
    "GPU hardware for schools.\n\n"
    "• Keep answers concise, upbeat, and age-appropriate (K-12).\n"
    "• If you are unsure, say so and suggest contacting a human staff member.\n"
    "• Never request personal data beyond an email if the user volunteers it.\n"
    "• Do **not** provide medical, legal, or financial advice.\n"
    "• No politics, mature content, or profanity.\n"
    "Respond in a friendly, encouraging tone, as a helpful school mascot!"
)
# ──────────────────── load model & pipeline ────────────────────────────────
hf_logging.set_verbosity_error()
try:
    log("Loading tokenizer & model …")
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype="auto"
    )
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,
        do_sample=True,
        temperature=0.7,
    )
    MODEL_ERR = None
    log("Model loaded ✔")
except Exception as exc:  # noqa: BLE001
    MODEL_ERR, gen = f"Model load error: {exc}", None
    log(MODEL_ERR)
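
# Optional (sketch, not wired in): on a smaller GPU the same model could be
# loaded 4-bit quantised instead. This assumes the `bitsandbytes` package and
# a CUDA device are available – a hedged alternative, not part of this app.
#
#   from transformers import BitsAndBytesConfig
#   quant_cfg = BitsAndBytesConfig(load_in_4bit=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID, device_map="auto", quantization_config=quant_cfg
#   )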
# ──────────────────── small helpers ────────────────────────────────────────
def clean(txt: str) -> str:
    """Collapse whitespace & guarantee a non-empty string."""
    return re.sub(r"\s+", " ", txt.strip()) or "…"


def trim_history(msgs: list[dict]) -> list[dict]:
    """Keep the system prompt plus the last MAX_TURNS user/assistant pairs."""
    return msgs if len(msgs) <= 1 + MAX_TURNS * 2 else [msgs[0]] + msgs[-MAX_TURNS * 2 :]
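
# Illustrative check (assumes MAX_TURNS = 6): a long history collapses to the
# system prompt plus the last 12 messages.
#
#   msgs = [{"role": "system", "content": SYSTEM_MSG}] + [
#       {"role": "user", "content": str(i)} for i in range(20)
#   ]
#   assert len(trim_history(msgs)) == 1 + MAX_TURNS * 2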
# ──────────────────── core chat function ───────────────────────────────────
def chat_fn(user_msg: str, history: list[dict] | None):
    log(f"User sent {len(user_msg)} chars")
    # Gradio's messages-format history never includes the system prompt, so
    # prepend it; replacing the whole list here would drop earlier turns
    if not history or history[0]["role"] != "system":
        history = [{"role": "system", "content": SYSTEM_MSG}] + (history or [])
    # fatal model-load failure
    if MODEL_ERR:
        return MODEL_ERR
    # basic user-input checks
    user_msg = clean(user_msg or "")
    if not user_msg:
        return "Please type something."
    if len(user_msg) > MAX_INPUT_CH:
        return f"Message too long (>{MAX_INPUT_CH} chars)."
    # add user message & trim
    history.append({"role": "user", "content": user_msg})
    history = trim_history(history)
    # build prompt string
    prompt_lines: list[str] = []
    for m in history:
        if m["role"] == "system":
            prompt_lines.append(m["content"])
        elif m["role"] == "user":
            prompt_lines.append(f"User: {m['content']}")
        else:
            prompt_lines.append(f"AI: {m['content']}")
    prompt_lines.append("AI:")
    prompt = "\n".join(prompt_lines)
log(f"Prompt {len(prompt)} chars β€’ generating…")
# call generator
t0 = time.time()
try:
raw = gen(prompt)[0]["generated_text"]
reply = clean(raw.split("AI:", 1)[-1])
# βœ‚ remove any echoed tags
reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
log(f"generate() {time.time() - t0:.2f}s β€’ reply {len(reply)} chars")
except Exception: # noqa: BLE001
log("❌ Inference exception:\n" + traceback.format_exc())
reply = "Sorryβ€”AI backend crashed. Please try again later."
return reply
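
# Alternative prompt construction (sketch, not wired in): Granite ships a chat
# template, so the hand-rolled "User:/AI:" format above could be replaced with
# the tokenizer's own template. Assumes `tok` loaded successfully.
#
#   def chat_template_prompt(history: list[dict]) -> str:
#       return tok.apply_chat_template(
#           history, tokenize=False, add_generation_prompt=True
#       )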
# ──────────────────── Gradio UI ────────────────────────────────────────────
gr.ChatInterface(
    fn=chat_fn,
    chatbot=gr.Chatbot(height=480, type="messages"),
    title="SchoolSpirit AI Chat",
    theme=gr.themes.Soft(primary_hue="blue"),  # light-blue accent
    type="messages",  # modern message dicts
).launch()
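
# Note (assumption, for local runs outside a Space): you may want explicit
# networking, e.g. .launch(server_name="0.0.0.0", server_port=7860).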