kirubel1738 committed
Commit 19a6024 · verified · 1 parent: 68f0ca6

Update src/streamlit_app.py

Files changed (1):
1. src/streamlit_app.py +44 -25
src/streamlit_app.py CHANGED
@@ -1,8 +1,10 @@
  # streamlit_app.py
  import os
  import streamlit as st
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import torch
+ from transformers import AutoTokenizer, pipeline
  from peft import PeftModel
+ from unsloth import FastLanguageModel

  # -----------------------------
  # Ensure cache dirs are writable in Spaces
@@ -13,17 +15,45 @@ os.environ.setdefault("HF_DATASETS_CACHE", "/tmp/huggingface/datasets")
  os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/tmp/huggingface/hub")
  os.environ.setdefault("XDG_CACHE_HOME", "/tmp/huggingface")

- # Base and adapter model IDs
- BASE_MODEL = "microsoft/BioGPT-Large-PubMedQA"
- ADAPTER_MODEL = "kirubel1738/biogpt-pubmedqa-finetuned"
+ # -----------------------------
+ # Model IDs
+ # -----------------------------
+ BASE_MODEL = "unsloth/llama-3-8b-bnb-4bit"
+ ADAPTER_MODEL = "kirubel1738/llama3-biology-qa"

+ # -----------------------------
+ # Load model once
+ # -----------------------------
  @st.cache_resource
  def load_model():
-     """Load BioGPT with PubMedQA adapter on CPU."""
+     """Load LLaMA-3 8B with PEFT adapter entirely on CPU."""
+     st.info("Loading LLaMA-3 model on CPU... This may take a while.")
+
+     # Load tokenizer
      tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-     base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map=None)
-     model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)  # apply adapter
-     generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
+
+     # Load base model in 4-bit on CPU
+     base_model, _ = FastLanguageModel.from_pretrained(
+         model_name=BASE_MODEL,
+         max_seq_length=2048,
+         dtype=None,
+         load_in_4bit=True,
+         device_map={"": "cpu"}  # force CPU
+     )
+
+     # Apply adapter
+     model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
+
+     # Text-generation pipeline on CPU
+     generator = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         device=-1,  # CPU
+         max_new_tokens=256,
+         do_sample=True,
+         temperature=0.7
+     )
      return generator

  # Load once
@@ -32,24 +62,17 @@ generator = load_model()
  # -----------------------------
  # Streamlit UI
  # -----------------------------
- st.set_page_config(page_title="BioGPT PubMedQA demo", layout="centered")
- st.title("🧬 BioGPTPubMedQA Demo")
-
- st.write("Ask a biomedical question and get an answer ")
- st.write(" generated by BioGPT-Large-PubMedQA fine-tuned on MMLU + SciQ dataset.")
+ st.set_page_config(page_title="LLaMA-3 Biology QA", layout="centered")
+ st.title("🧬 LLaMA-3 Biology QA Demo")
+ st.write("Ask a biology question and get an answer generated by LLaMA-3 fine-tuned on the Biology QA dataset.")

- user_input = st.text_area("Enter your biomedical question:", height=150)
+ user_input = st.text_area("Enter your biology question:", height=150)

  if st.button("Get Answer"):
      if user_input.strip():
          with st.spinner("Generating answer..."):
              try:
-                 result = generator(
-                     user_input,
-                     max_new_tokens=128,
-                     do_sample=True,
-                     temperature=0.7
-                 )
+                 result = generator(user_input)
                  output_text = result[0]["generated_text"]
                  st.success("Answer:")
                  st.write(output_text)
@@ -59,8 +82,4 @@ if st.button("Get Answer"):
      st.warning("Please enter a question.")

  st.markdown("---")
- st.caption("Model: microsoft/biogpt + adapter kirubel1738/biogpt-pubmedqa-finetuned | Runs on CPU")
-
-
-
-
+ st.caption(f"Model: {BASE_MODEL} + adapter {ADAPTER_MODEL} | Runs on CPU")
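A few notes on patterns this file relies on. First, the cache-directory block at the top uses `os.environ.setdefault`, which only assigns when a variable is not already set, so a cache path configured at the Space level still wins over the `/tmp/huggingface` defaults. A minimal illustration (the `/data/datasets` override is hypothetical):

```python
import os

# Suppose the Space already configured a cache location (hypothetical value):
os.environ["HF_DATASETS_CACHE"] = "/data/datasets"

# setdefault is a no-op for variables that are already set...
os.environ.setdefault("HF_DATASETS_CACHE", "/tmp/huggingface/datasets")
# ...and only fills in values for variables that are missing.
os.environ.pop("XDG_CACHE_HOME", None)  # ensure unset for the demo
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/huggingface")

print(os.environ["HF_DATASETS_CACHE"])  # /data/datasets (override preserved)
print(os.environ["XDG_CACHE_HOME"])     # /tmp/huggingface (default applied)
```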
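Second, `@st.cache_resource` is what keeps the 8B model from being reloaded on every interaction: Streamlit reruns the whole script on each button click, and the decorator caches the loader's return value once per server process. A self-contained sketch of that behavior, with a hypothetical `expensive_load` standing in for `load_model`:

```python
import streamlit as st

@st.cache_resource
def expensive_load():
    # Executed once per server process; subsequent script reruns
    # reuse the returned object instead of rebuilding it.
    return {"weights": "stand-in for a large model"}

first = expensive_load()
second = expensive_load()  # cache hit: the loader body does not run again
assert first is second     # both names point at the same cached object
```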
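Finally, the UI writes `result[0]["generated_text"]` straight to the page. With transformers text-generation pipelines that string begins with the prompt itself, so the answer appears with the question echoed in front. If prompt-stripping were wanted, it could look roughly like this (a sketch only; `sshleifer/tiny-gpt2` is a tiny public test model used so the snippet runs quickly on CPU, and the slicing is an assumption about the desired display, not part of this commit):

```python
from transformers import pipeline

generator = pipeline("text-generation", model="sshleifer/tiny-gpt2", device=-1)

prompt = "What does a ribosome do?"
result = generator(prompt, max_new_tokens=16, do_sample=False)

output_text = result[0]["generated_text"]  # prompt + completion
answer_only = output_text[len(prompt):]    # drop the echoed prompt
print(answer_only)
```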