phanerozoic committed on
Commit 2966210 · verified · 1 Parent(s): 94cf8c8

Update app.py

Files changed (1)
  1. app.py +130 -79
app.py CHANGED
@@ -1,11 +1,20 @@
- import os, re, time, datetime, traceback, torch
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from transformers.utils import logging as hf_logging
 
- # -------------------------------------------------------------------
- # 1. Logging helpers
- # -------------------------------------------------------------------
  os.environ["HF_HOME"] = "/data/.huggingface"
  LOG_FILE = "/data/requests.log"
 
@@ -21,38 +30,38 @@ def log(msg: str):
          pass
 
 
- # -------------------------------------------------------------------
- # 2. Configuration
- # -------------------------------------------------------------------
- MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
- MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
 
  SYSTEM_MSG = (
-     "You are **SchoolSpiritAI**, the digital mascot for SchoolSpirit AI LLC, "
-     "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
-     "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
-     "K‑12 schools.\n\n"
-     "GUIDELINES:\n"
-     "• Warm, encouraging tone for students, parents, staff.\n"
-     "• Replies ≤ 4 sentences unless asked for detail.\n"
-     "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
-     "• No personal‑data collection or sensitive advice.\n"
-     "• No profanity, politics, or mature themes."
  )
- WELCOME_MSG = "Welcome to SchoolSpiritAI! Do you have any questions?"
 
 
- def strip(s: str) -> str:
-     return re.sub(r"\s+", " ", s.strip())
 
-
- # -------------------------------------------------------------------
- # 3. Load model (GPU FP‑16 → CPU fallback)
- # -------------------------------------------------------------------
  hf_logging.set_verbosity_error()
  try:
      log("Loading tokenizer …")
-     tok = AutoTokenizer.from_pretrained(MODEL_ID)
 
      if torch.cuda.is_available():
          log("GPU detected → FP‑16")
@@ -62,90 +71,132 @@ try:
      else:
          log("CPU fallback")
          model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID, device_map="cpu", torch_dtype="auto", low_cpu_mem_usage=True
          )
 
-     gen = pipeline(
          "text-generation",
          model=model,
-         tokenizer=tok,
-         max_new_tokens=MAX_TOKENS,
          do_sample=True,
-         temperature=0.6,
      )
      MODEL_ERR = None
      log("Model loaded ✔")
- except Exception as exc:  # noqa: BLE001
-     MODEL_ERR, gen = f"Model load error: {exc}", None
      log(MODEL_ERR)
 
 
- # -------------------------------------------------------------------
- # 4. Chat callback
- # -------------------------------------------------------------------
- def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
      """
-     history: list of (user, assistant) tuples (Gradio default)
-     state  : dict carrying system_prompt + raw_history for the model
-     Returns updated history (for UI) and state (for next round)
      """
-     if MODEL_ERR:
-         return history + [(user_msg, MODEL_ERR)], state
 
      user_msg = strip(user_msg or "")
      if not user_msg:
-         return history + [(user_msg, "Please type something.")], state
      if len(user_msg) > MAX_INPUT_CH:
-         warn = f"Message too long (>{MAX_INPUT_CH} chars)."
-         return history + [(user_msg, warn)], state
-
-     # ------------------------------------------------ Prompt assembly
-     raw_hist = state.get("raw", [])
-     raw_hist.append({"role": "user", "content": user_msg})
-     # keep system + last N exchanges
-     convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
-     raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo
-
-     prompt = "\n".join(
-         [
-             m["content"]
-             if m["role"] == "system"
-             else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
-             for m in raw_hist
-         ]
-         + ["AI:"]
-     )
 
      try:
-         raw = gen(prompt)[0]["generated_text"]
-         reply = strip(raw.split("AI:", 1)[-1])
          reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
      except Exception:
          log("❌ Inference error:\n" + traceback.format_exc())
-         reply = "Sorry, backend crashed. Please try again later."
 
-     # ------------------------------------------------ Update state + UI history
-     raw_hist.append({"role": "assistant", "content": reply})
-     state["raw"] = raw_hist
-     history.append((user_msg, reply))
-     return history, state
 
 
- # -------------------------------------------------------------------
- # 5. Launch
- # -------------------------------------------------------------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
      chatbot = gr.Chatbot(
-         value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI"
      )
-     state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
      with gr.Row():
          txt = gr.Textbox(
-             scale=4, placeholder="Type your question here...", show_label=False
          )
-         send = gr.Button("Send", variant="primary")
 
-     send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
      txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
 
  demo.launch()
 
@@ -1,11 +1,20 @@
+ ##############################################################################
+ # SchoolSpirit AI – Full‑context Chatbot Space
+ # -------------------------------------------------
+ # • Maintains a rolling prompt that only truncates when token budget exceeded
+ # • GPU FP‑16 load, CPU fallback
+ # • Detailed logging, duplicate‑proof welcome
+ # • Gradio Blocks UI (textbox + Send button)
+ ##############################################################################
+
+ import os, re, time, datetime, traceback, torch, json, math
  import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from transformers.utils import logging as hf_logging
 
+ # ---------------------------------------------------------------------------
+ # 0. Paths & basic logging helper
+ # ---------------------------------------------------------------------------
  os.environ["HF_HOME"] = "/data/.huggingface"
  LOG_FILE = "/data/requests.log"
 
@@ -21,38 +30,38 @@ def log(msg: str):
          pass
 
 
+ # ---------------------------------------------------------------------------
+ # 1. Configuration constants
+ # ---------------------------------------------------------------------------
+ MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"  # 2 B model fits Spaces
+ CONTEXT_TOKENS = 1800  # leave head‑room for reply inside 2k window
+ MAX_NEW_TOKENS = 64
+ TEMPERATURE = 0.6
+ MAX_INPUT_CH = 300  # UI safeguard
 
  SYSTEM_MSG = (
+     "You are **SchoolSpirit AI**, the official digital mascot of "
+     "SchoolSpirit AI LLC. Founded by Charles Norton in 2025, the company "
+     "deploys on‑prem AI chat mascots, fine‑tunes language models, and ships "
+     "turnkey GPU servers to K‑12 schools.\n\n"
+     "RULES:\n"
+     "• Friendly, concise (≤4 sentences unless prompted).\n"
+     "• No personal data collection; no medical/legal/financial advice.\n"
+     "• If uncertain, admit it & suggest human follow‑up.\n"
+     "• Avoid profanity, politics, mature themes."
  )
+ WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
 
+ strip = lambda s: re.sub(r"\s+", " ", s.strip())
 
 
+ # ---------------------------------------------------------------------------
+ # 2. Load tokenizer + model (GPU FP‑16 → CPU)
+ # ---------------------------------------------------------------------------
  hf_logging.set_verbosity_error()
  try:
      log("Loading tokenizer …")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
      if torch.cuda.is_available():
          log("GPU detected → FP‑16")
 
@@ -62,90 +71,132 @@ try:
      else:
          log("CPU fallback")
          model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID,
+             device_map="cpu",
+             torch_dtype="auto",
+             low_cpu_mem_usage=True,
          )
 
+     generator = pipeline(
          "text-generation",
          model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=MAX_NEW_TOKENS,
          do_sample=True,
+         temperature=TEMPERATURE,
      )
      MODEL_ERR = None
      log("Model loaded ✔")
+ except Exception as exc:
+     MODEL_ERR = f"Model load error: {exc}"
+     generator = None
      log(MODEL_ERR)
 
 
+ # ---------------------------------------------------------------------------
+ # 3. Helper: build prompt under token budget
+ # ---------------------------------------------------------------------------
+ def build_prompt(raw_history: list[dict]) -> str:
      """
+     raw_history: list [{'role': 'system'|'user'|'assistant', 'content': str}, ...]
+     Keeps trimming oldest user/assistant pair until total tokens < CONTEXT_TOKENS
      """
+     def render(msg):
+         if msg["role"] == "system":
+             return msg["content"]
+         prefix = "User:" if msg["role"] == "user" else "AI:"
+         return f"{prefix} {msg['content']}"
+
+     # always include system
+     system_msg = [msg for msg in raw_history if msg["role"] == "system"][0]
+     convo = [m for m in raw_history if m["role"] != "system"]
+
+     # iterative trim
+     while True:
+         prompt_parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
+         token_len = len(tokenizer.encode("\n".join(prompt_parts), add_special_tokens=False))
+         if token_len <= CONTEXT_TOKENS or len(convo) <= 2:
+             break
+         # drop oldest user+assistant pair
+         convo = convo[2:]
 
+     return "\n".join(prompt_parts)
+
+
+ # ---------------------------------------------------------------------------
+ # 4. Chat callback
+ # ---------------------------------------------------------------------------
+ def chat_fn(user_msg: str, display_history: list, state: dict):
+     """
+     display_history : list[tuple[str,str]] for UI
+     state["raw"]    : list[dict] for prompting
+     """
      user_msg = strip(user_msg or "")
      if not user_msg:
+         return display_history, state
+
      if len(user_msg) > MAX_INPUT_CH:
+         display_history.append((user_msg, f"Input >{MAX_INPUT_CH} chars."))
+         return display_history, state
+
+     if MODEL_ERR:
+         display_history.append((user_msg, MODEL_ERR))
+         return display_history, state
+
+     # --- Update raw history
+     state["raw"].append({"role": "user", "content": user_msg})
 
+     # --- Build prompt within token budget
+     prompt = build_prompt(state["raw"])
+
+     # --- Generate
      try:
+         start = time.time()
+         completion = generator(prompt)[0]["generated_text"]
+         reply = strip(completion.split("AI:", 1)[-1])
          reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
+         log(f"Reply in {time.time()-start:.2f}s ({len(reply)} chars)")
      except Exception:
          log("❌ Inference error:\n" + traceback.format_exc())
+         reply = "Apologies, an internal error occurred. Please try again."
 
+     # --- Append assistant reply to both histories
+     display_history.append((user_msg, reply))
+     state["raw"].append({"role": "assistant", "content": reply})
+     return display_history, state
 
 
+ # ---------------------------------------------------------------------------
+ # 5. Launch Gradio Blocks UI
+ # ---------------------------------------------------------------------------
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
+     gr.Markdown("### SchoolSpirit AI Chat")
+
      chatbot = gr.Chatbot(
+         value=[("", WELCOME_MSG)],
+         height=480,
+         label="SchoolSpirit AI",
      )
+
+     state = gr.State(
+         {
+             "raw": [
+                 {"role": "system", "content": SYSTEM_MSG},
+                 {"role": "assistant", "content": WELCOME_MSG},
+             ]
+         }
+     )
+
      with gr.Row():
          txt = gr.Textbox(
+             placeholder="Type your question here",
+             show_label=False,
+             scale=4,
+             lines=1,
          )
+         send_btn = gr.Button("Send", variant="primary")
 
+     send_btn.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
      txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
 
  demo.launch()
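
The heart of this commit is the token-budget trimming loop in build_prompt, which replaces the old fixed MAX_TURNS window. Below is a minimal, self-contained sketch of that loop that can be sanity-checked without downloading the Granite model: the real AutoTokenizer is swapped for a hypothetical whitespace-counting stub (StubTokenizer), and the 40-token budget is likewise made up so the trimming is visible on tiny inputs.

    CONTEXT_TOKENS = 40  # made-up tiny budget so trimming is visible

    class StubTokenizer:
        """Hypothetical stand-in for AutoTokenizer: one token per whitespace word."""
        def encode(self, text, add_special_tokens=False):
            return text.split()

    tokenizer = StubTokenizer()

    def render(msg):
        # Same rendering as app.py: system text is bare, turns get role prefixes.
        if msg["role"] == "system":
            return msg["content"]
        prefix = "User:" if msg["role"] == "user" else "AI:"
        return f"{prefix} {msg['content']}"

    def build_prompt(raw_history):
        # Same loop as the commit: always keep the system message, then drop the
        # oldest user/assistant pair until the rendered prompt fits the budget
        # (or only one exchange remains).
        system_msg = [m for m in raw_history if m["role"] == "system"][0]
        convo = [m for m in raw_history if m["role"] != "system"]
        while True:
            parts = [system_msg["content"]] + [render(m) for m in convo] + ["AI:"]
            if len(tokenizer.encode("\n".join(parts))) <= CONTEXT_TOKENS or len(convo) <= 2:
                return "\n".join(parts)
            convo = convo[2:]

    history = [{"role": "system", "content": "You are a helpful mascot."}]
    for i in range(6):
        history.append({"role": "user", "content": f"question {i} pad pad pad pad pad"})
        history.append({"role": "assistant", "content": f"answer {i}"})

    print(build_prompt(history))
    # "question 0" has been trimmed away; the system line and "question 5" remain.

Running the sketch prints a prompt in which the oldest question/answer pairs are gone while the system message and the most recent turns survive, which is the "rolling prompt that only truncates when token budget exceeded" behaviour promised in the new file's header comment.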