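# SchoolSpirit AI mascot chat: a Gradio front end over ibm-granite/granite-3.3-2b-instruct,
# running on a Hugging Face Space with a persistent /data volume.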
import os

# HF_HOME must be set before transformers is imported, or the cache path is ignored.
os.environ["HF_HOME"] = "/data/.huggingface"

import re, time, datetime, threading, traceback, torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.utils import logging as hf_logging

LOG_FILE = "/data/requests.log"
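
# Lightweight logger: echo to stdout and append to the persistent volume when it exists.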
def log(m):
    line = f"[{datetime.datetime.utcnow().strftime('%H:%M:%S.%f')[:-3]}] {m}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except FileNotFoundError:
        pass

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CTX_TOK, MAX_NEW, TEMP = 1800, 64, 0.6  # prompt token budget, max new tokens, sampling temperature
MAX_IN, RATE_N, RATE_T = 300, 5, 60     # max input chars, requests allowed per window, window length (s)

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the friendly digital mascot of "
    "SchoolSpirit AI LLC, founded by Charles Norton in 2025. "
    "The company installs on‑prem AI chat mascots, fine‑tunes language models, "
    "and ships turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Reply in ≤ 4 sentences unless asked for detail.\n"
    "• No personal‑data collection; no medical/legal/financial advice.\n"
    "• If uncertain, say so and suggest contacting a human.\n"
    "• If you can’t answer, politely direct the user to admin@schoolspiritai.com.\n"
    "• Keep language age‑appropriate; avoid profanity, politics, mature themes."
)

WELCOME = "Hi there! I’m SchoolSpirit AI. How can I help?"
strip = lambda s: re.sub(r"\s+", " ", s.strip())
hf_logging.set_verbosity_error()
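
# Load the tokenizer and model once at startup; on failure, keep the error string
# so chat_fn can surface it in the UI instead of crashing the Space.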
try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
        low_cpu_mem_usage=True,
    )
    MODEL_ERR = None
    log("Model loaded")
except Exception as e:
    MODEL_ERR = f"Model load error: {e}"
    log(MODEL_ERR + "\n" + traceback.format_exc())
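
# Per-IP sliding-window rate limit: at most RATE_N requests in any RATE_T-second window.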
VISITS = {}

def allowed(ip):
    now = time.time()
    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_T]
    if len(VISITS[ip]) >= RATE_N:
        return False
    VISITS[ip].append(now)
    return True
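
# Flatten the message list into a plain-text transcript, dropping the oldest
# user/assistant pair until the encoded prompt fits the CTX_TOK budget.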
def build_prompt(raw):
    def render(m):
        if m["role"] == "system":
            return m["content"]
        return f"{'User:' if m['role'] == 'user' else 'AI:'} {m['content']}"

    sys, convo = raw[0], raw[1:]
    while True:
        parts = [sys["content"]] + [render(m) for m in convo] + ["AI:"]
        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CTX_TOK or len(convo) <= 2:
            return "\n".join(parts)
        convo = convo[2:]
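
# Generator handler: validate the request, then stream tokens from a background
# generate() thread so the chatbot fills in incrementally.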
def chat_fn(user_msg, hist, state, request: gr.Request):
    # chat_fn is a generator, so every early exit must *yield* its outputs;
    # a bare return would end the stream without ever updating the UI.
    ip = request.client.host if request else "anon"
    if not allowed(ip):
        hist.append((user_msg, "Rate limit exceeded — please wait a minute."))
        yield hist, state, ""
        return
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield hist, state, ""
        return
    if len(user_msg) > MAX_IN:
        hist.append((user_msg, f"Input >{MAX_IN} chars."))
        yield hist, state, ""
        return
    if MODEL_ERR:
        hist.append((user_msg, MODEL_ERR))
        yield hist, state, ""
        return
    hist.append((user_msg, ""))
    state["raw"].append({"role": "user", "content": user_msg})
    prompt = build_prompt(state["raw"])
    ids = tok(prompt, return_tensors="pt").to(model.device).input_ids
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=ids,
            max_new_tokens=MAX_NEW,
            do_sample=True,  # without sampling, temperature is silently ignored
            temperature=TEMP,
            streamer=streamer,
        ),
    ).start()
    partial = ""
    for piece in streamer:
        partial += piece
        # The model continues a plain-text transcript, so truncate the reply as
        # soon as it starts inventing the next "User:"/"AI:" turn.
        if "User:" in partial or "\nAI:" in partial:
            partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip()
            break
        hist[-1] = (user_msg, partial)
        yield hist, state, ""
    reply = strip(partial)
    hist[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    yield hist, state, ""
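
# UI: a soft-themed Blocks layout; gr.State carries the raw message list that
# build_prompt re-renders on every turn.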
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    bot = gr.Chatbot(value=[(None, WELCOME)], height=480)  # None hides the empty user bubble
    st = gr.State({
        "raw": [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "assistant", "content": WELCOME},
        ]
    })
    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
        send = gr.Button("Send", variant="primary")
    send.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])
    txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])

demo.launch()