Spaces: Running on Zero
Commit: Update app.py

app.py (CHANGED)
@@ -58,12 +58,12 @@ CURRENT_VDB = None
 @spaces.GPU()
 def get_image_description(image: Image.Image) -> str:
     """
-    Lazy-loads the Llava processor + model
+    Lazy-loads the Llava processor + model inside the GPU worker,
     runs captioning, and returns a one-sentence description.
     """
     global processor, vision_model
 
-    #
+    # On first call, instantiate + move to CUDA
     if processor is None or vision_model is None:
         processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
         vision_model = LlavaNextForConditionalGeneration.from_pretrained(
@@ -72,9 +72,9 @@ def get_image_description(image: Image.Image) -> str:
             low_cpu_mem_usage=True
         ).to("cuda")
 
-    # clear and run
     torch.cuda.empty_cache()
     gc.collect()
+
     prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
     output = vision_model.generate(**inputs, max_new_tokens=100)
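The two hunks above keep all heavy Llava initialisation out of import time: the processor and model are created on the first call inside the @spaces.GPU()-decorated worker and cached in module globals. A minimal, self-contained sketch of that lazy-initialisation pattern (the dtype argument and the final decode step are assumptions, since those lines are not visible in the diff):

    import gc

    import spaces
    import torch
    from PIL import Image
    from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

    processor = None       # populated on the first GPU call, then reused
    vision_model = None

    @spaces.GPU()
    def get_image_description(image: Image.Image) -> str:
        """Lazy-load the captioner inside the GPU worker and caption one image."""
        global processor, vision_model
        if processor is None or vision_model is None:   # first call only
            processor = LlavaNextProcessor.from_pretrained(
                "llava-hf/llava-v1.6-mistral-7b-hf"
            )
            vision_model = LlavaNextForConditionalGeneration.from_pretrained(
                "llava-hf/llava-v1.6-mistral-7b-hf",
                torch_dtype=torch.float16,       # assumption: dtype not shown in the diff
                low_cpu_mem_usage=True,
            ).to("cuda")
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inputs = processor(prompt, image, return_tensors="pt").to("cuda")
        output = vision_model.generate(**inputs, max_new_tokens=100)
        # assumption: the app decodes and returns the generated text roughly like this
        return processor.decode(output[0], skip_special_tokens=True)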
@@ -175,21 +175,21 @@ def extract_data_from_pdfs(
 ):
     """
     1) (Optional) OCR setup
-    2)
-    3) Extract text
-    4) Build and
+    2) Vision+Lang model setup & monkey-patch get_image_description
+    3) Extract text & images
+    4) Build and stash vector DB in CURRENT_VDB
     """
     if not docs:
         raise gr.Error("No documents to process")
 
-    # 1) OCR
+    # 1) OCR pipeline if requested
     if do_ocr == "Get Text With OCR":
         db_m, crnn_m = OCR_CHOICES[ocr_choice]
         local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
     else:
         local_ocr = None
 
-    # 2) Vision–language model
+    # 2) Vision–language model
     proc = LlavaNextProcessor.from_pretrained(vlm_choice)
     vis = (
         LlavaNextForConditionalGeneration
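The OCR branch above builds a docTR predictor from the two architecture names kept in OCR_CHOICES and runs it over a DocumentFile. A small sketch of that usage (the OCR_CHOICES mapping shown here is a plausible reconstruction rather than a copy from app.py, and example.pdf is a hypothetical input):

    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    # assumption: OCR_CHOICES maps the dropdown label to (detection, recognition) archs
    OCR_CHOICES = {
        "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
    }

    db_m, crnn_m = OCR_CHOICES["db_resnet50 + crnn_mobilenet_v3_large"]
    local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)

    pdf = DocumentFile.from_pdf("example.pdf")   # hypothetical input file
    result = local_ocr(pdf)
    print(result.render())                       # plain-text rendering of the OCR result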
@@ -197,9 +197,10 @@ def extract_data_from_pdfs(
         .to("cuda")
     )
 
-    # Monkey-patch
+    # Monkey-patch our pipeline for image captions
     def describe(img: Image.Image) -> str:
-        torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
+        gc.collect()
         prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
         inputs = proc(prompt, img, return_tensors="pt").to("cuda")
         output = vis.generate(**inputs, max_new_tokens=100)
@@ -208,13 +209,12 @@ def extract_data_from_pdfs(
     global get_image_description, CURRENT_VDB
     get_image_description = describe
 
-    # 3) Extract text
+    # 3) Extract text + images
     progress(0.2, "Extracting text and images…")
     all_text = ""
     images, names = [], []
 
     for path in docs:
-        # text
         if local_ocr:
             pdf = DocumentFile.from_pdf(path)
             res = local_ocr(pdf)
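The `global get_image_description` / `get_image_description = describe` rebinding in the hunk above works because callers resolve the name through the module namespace at call time, so every helper that captions images silently switches to the processor and model chosen in the UI. A tiny sketch of that rebinding pattern in isolation (all names here are illustrative):

    # a default implementation plus a rebinding hook, all at module level
    def get_image_description(img) -> str:
        return "default captioner"

    def caption_all(images) -> list[str]:
        # looked up at call time, so a later rebinding is picked up automatically
        return [get_image_description(img) for img in images]

    def use_custom_captioner(fn) -> None:
        global get_image_description
        get_image_description = fn      # same trick as `get_image_description = describe`

    use_custom_captioner(lambda img: "custom captioner")
    print(caption_all(["img0", "img1"]))   # -> ['custom captioner', 'custom captioner']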
@@ -223,29 +223,28 @@ def extract_data_from_pdfs(
             txt = PdfReader(path).pages[0].extract_text() or ""
         all_text += txt + "\n\n"
 
-        # images
         if include_images == "Include Images":
             imgs = extract_images([path])
             images.extend(imgs)
             names.extend([os.path.basename(path)] * len(imgs))
 
-    # 4) Build
+    # 4) Build + store the vector DB
     progress(0.6, "Indexing in vector DB…")
     CURRENT_VDB = get_vectordb(all_text, images, names)
 
-    # mark done & return only picklable outputs
     session["processed"] = True
     sample_imgs = images[:4] if include_images == "Include Images" else []
 
+    # ─── return *exactly four* picklable outputs ───
     return (
-        session,
-
-
-
-        "<h3>Done!</h3>"
+        session,                   # gr.State: so UI knows we're ready
+        all_text[:2000] + "...",   # preview text
+        sample_imgs,               # preview images
+        "<h3>Done!</h3>"           # Done message
     )
 
 
+
 # Chat function
 def conversation(
     session: dict,
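The reshaped return statement is the heart of this commit: the diff's own comments insist on returning only picklable values, which fits how Spaces hands results back from the GPU worker to the main process, so the vector-store handle stays behind in the module-level CURRENT_VDB and only plain strings and lists travel to Gradio. A stripped-down sketch of that split (build_vector_db is a hypothetical stand-in for the app's get_vectordb):

    CURRENT_VDB = None   # process-local handle; never returned through Gradio

    def build_vector_db(text: str):           # hypothetical stand-in for get_vectordb
        return {"text_db": text.split("\n\n")}

    def extract(docs: list[str], session: dict):
        global CURRENT_VDB
        all_text = "\n\n".join(docs)
        CURRENT_VDB = build_vector_db(all_text)   # heavy/unpicklable object stays global
        session["processed"] = True
        # exactly four simple, picklable outputs for the four UI components
        return session, all_text[:2000] + "...", [], "<h3>Done!</h3>"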
@@ -258,8 +257,7 @@ def conversation(
     model_id: str
 ):
     """
-
-    calls the HF endpoint, and returns updated chat history.
+    Uses the global CURRENT_VDB (set by extract_data_from_pdfs) to answer.
     """
     global CURRENT_VDB
     if not session.get("processed") or CURRENT_VDB is None:
@@ -272,7 +270,7 @@ def conversation(
         huggingfacehub_api_token=HF_TOKEN
     )
 
-    #
+    # 1) Text retrieval
    text_col = CURRENT_VDB.get_collection("text_db")
     docs = text_col.query(
         query_texts=[question],
@@ -280,6 +278,7 @@ def conversation(
         include=["documents"]
     )["documents"][0]
 
+    # 2) Image retrieval
     img_col = CURRENT_VDB.get_collection("image_db")
     img_q = img_col.query(
         query_texts=[question],
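Both retrieval blocks use the same collection API, get_collection plus query with query_texts and include=["documents"], which is chromadb-style, assuming that is what get_vectordb returns. A self-contained sketch of the pattern (collection name, sample chunks, and n_results are illustrative):

    import chromadb

    client = chromadb.Client()                      # stand-in for CURRENT_VDB
    text_col = client.get_or_create_collection("text_db")
    text_col.add(
        ids=["chunk-0", "chunk-1"],
        documents=["First PDF chunk…", "Second PDF chunk…"],
    )

    def retrieve(question: str, num_context: int = 3) -> list[str]:
        # query_texts are embedded with the collection's embedding function;
        # include=["documents"] returns the raw chunk text for prompt building
        res = text_col.query(
            query_texts=[question],
            n_results=num_context,
            include=["documents"],
        )
        return res["documents"][0]   # results for the first (and only) query text

    print(retrieve("What is the document about?"))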
@@ -296,7 +295,7 @@ def conversation(
             pass
     img_desc = "\n".join(img_descs)
 
-    # Build
+    # 3) Build prompt & call LLM
     prompt = PromptTemplate(
         template="""
 Context:
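The prompt assembly is standard LangChain: a PromptTemplate is filled with the retrieved chunks, the image descriptions, and the question, then sent to a hosted model authenticated with HF_TOKEN. A hedged sketch of that step (the real template in app.py is longer, and the exact endpoint class the app imports is not visible in the diff, so HuggingFaceEndpoint and the repo_id are assumptions):

    import os

    from langchain_core.prompts import PromptTemplate
    from langchain_huggingface import HuggingFaceEndpoint

    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",   # assumption: any hosted chat model
        temperature=0.4,
        max_new_tokens=200,
        huggingfacehub_api_token=os.environ["HF_TOKEN"],
    )

    prompt = PromptTemplate(
        template=(
            "Context:\n{context}\n\n"
            "Image descriptions:\n{img_desc}\n\n"
            "Question: {question}\nAnswer:"
        ),
        input_variables=["context", "img_desc", "question"],
    )

    answer = llm.invoke(prompt.format(context="…", img_desc="…", question="…"))
    print(answer)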
@@ -336,6 +335,7 @@ Answer:
 
 
 
+
 # ─────────────────────────────────────────────────────────────────────────────
 # Gradio UI
 CSS = """
@@ -357,14 +357,13 @@ MODEL_OPTIONS = [
 ]
 
 with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-    vdb_state
+    # We no longer need vdb_state – we keep only session_state
     session_state = gr.State({})
 
     # ─── Welcome Screen ─────────────────────────────────────────────
     with gr.Column(visible=True) as welcome_col:
-
         gr.Markdown(
-
+            f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
             elem_id="welcome_md"
         )
         start_btn = gr.Button("🚀 Start")
@@ -386,6 +385,11 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                 value="Exclude Images",
                 label="Images"
             )
+            ocr_radio = gr.Radio(
+                ["Get Text With OCR", "Get Available Text Only"],
+                value="Get Available Text Only",
+                label="OCR"
+            )
             ocr_dd = gr.Dropdown(
                 choices=[
                     "db_resnet50 + crnn_mobilenet_v3_large",
@@ -405,28 +409,23 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
             extract_btn = gr.Button("Extract")
             preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
             preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
+            preview_html = gr.HTML()  # for the “Done!” message
 
             extract_btn.click(
-                extract_data_from_pdfs,
+                fn=extract_data_from_pdfs,
                 inputs=[
                     docs,
                     session_state,
                     include_dd,
-
-                        ["Get Text With OCR", "Get Available Text Only"],
-                        value="Get Available Text Only",
-                        label="OCR"
-                    ),
+                    ocr_radio,
                     ocr_dd,
                     vlm_dd
                 ],
                 outputs=[
-
-
-
-
-                    preview_img,
-                    gr.HTML()
+                    session_state,  # session “processed” flag
+                    preview_text,   # preview text
+                    preview_img,    # preview images
+                    preview_html    # done HTML
                 ]
             )
 
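The wiring fix above is about making the click handler line up with the UI: the old code constructed a gr.Radio inline inside the inputs list and put a bare gr.HTML() in outputs, whereas the new code defines the components first and references them, so Gradio can pass their values positionally and write the function's return tuple back to them in order. A minimal Blocks sketch of that contract (component names and the extract body are illustrative):

    import gradio as gr

    def extract(files, session, ocr_mode):
        session["processed"] = True
        # one return value per output component, in the same order
        return session, f"{len(files or [])} file(s), OCR={ocr_mode}", "<h3>Done!</h3>"

    with gr.Blocks() as demo:
        session_state = gr.State({})
        docs = gr.File(file_count="multiple")
        ocr_radio = gr.Radio(
            ["Get Text With OCR", "Get Available Text Only"],
            value="Get Available Text Only",
            label="OCR",
        )
        extract_btn = gr.Button("Extract")
        preview_text = gr.Textbox(lines=3, label="Sample Text")
        preview_html = gr.HTML()

        extract_btn.click(
            fn=extract,
            inputs=[docs, session_state, ocr_radio],
            outputs=[session_state, preview_text, preview_html],
        )

    if __name__ == "__main__":
        demo.launch()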
@@ -446,15 +445,15 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                 value=MODEL_OPTIONS[0],
                 label="Choose Chat Model"
             )
-            num_ctx = gr.Slider(1,20,value=3,label="Text Contexts")
-            img_ctx = gr.Slider(1,10,value=2,label="Image Contexts")
-            temp = gr.Slider(0.1,1.0,step=0.1,value=0.4,label="Temperature")
-            max_tok = gr.Slider(10,1000,step=10,value=200,label="Max Tokens")
+            num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
+            img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
+            temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
+            max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")
 
             send.click(
-                conversation,
+                fn=conversation,
                 inputs=[
-
+                    session_state,  # now drives conversation
                     msg,
                     num_ctx,
                     img_ctx,
@@ -465,18 +464,18 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
                 ],
                 outputs=[
                     chat,
-                    gr.Dataframe(),
+                    gr.Dataframe(),  # returned docs
                     gr.Gallery(label="Relevant Images", rows=2, value=[])
                 ]
             )
 
-    # Footer inside app_col
     gr.HTML("<center>Made with ❤️ by Zamal</center>")
 
     # ─── Wire the Start button ───────────────────────────────────────
     start_btn.click(
         fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
-        inputs=[],
+        inputs=[],
+        outputs=[welcome_col, app_col]
     )
 
 if __name__ == "__main__":