# SmartPDF_Q_A / app.py
import streamlit as st
import logging
import os
from io import BytesIO
import re
import time
from typing import List, Tuple, Optional
import pdfplumber
# Optional OCR (guarded)
try:
    import pytesseract
    OCR_AVAILABLE = True
except Exception:
    OCR_AVAILABLE = False
from rank_bm25 import BM25Okapi
# Embeddings + Vector store
from sentence_transformers import SentenceTransformer
import numpy as np
try:
    import faiss  # direct FAISS for speed and control
    FAISS_OK = True
except Exception:
    FAISS_OK = False
# Lightweight HF pipelines
from transformers import pipeline
# ----------------------------
# App & Logging Setup
# ----------------------------
st.set_page_config(page_title="Smart PDF Chat & Summarizer", page_icon="📄", layout="wide")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("smart_pdf")
# ----------------------------
# Caching: models & utilities
# ----------------------------
@st.cache_resource(show_spinner=False)
def get_embedder(name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    return SentenceTransformer(name)
@st.cache_resource(show_spinner=False)
def get_qa_pipeline():
    # Small, fast instruction model
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-small",
        device=-1,
        max_length=220
    )
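# Illustrative call pattern for the QA pipeline above (a sketch shown as comments only,
# not executed at import time): Hugging Face text2text-generation pipelines return a list
# of dicts, so the answer is read from out[0]["generated_text"], as rag_answer() does below.
#   qa = get_qa_pipeline()
#   out = qa("Context: ...\n\nQuestion: ...\nAnswer:")
#   answer = out[0]["generated_text"]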
@st.cache_resource(show_spinner=False)
def get_summarizer():
    # DistilBART is much faster than bart-large-cnn
    return pipeline(
        "summarization",
        model="sshleifer/distilbart-cnn-12-6",
        device=-1,
        max_length=220,
        min_length=80,
        do_sample=False,
    )
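# Illustrative call pattern for the summarizer (a comment-only sketch): summarization
# pipelines also return a list of dicts, keyed by "summary_text".
#   summarizer = get_summarizer()
#   short = summarizer(long_paragraph)[0]["summary_text"]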
# ----------------------------
# PDF processing
# ----------------------------
def _looks_like_code(line: str) -> bool:
    if len(line.strip()) == 0:
        return False
    # Heuristics for code-y lines
    code_tokens = [
        r"\b(def|class|import|from|return|if|elif|else|for|while|try|except|finally|with)\b",
        r"[{}`;<>]|::|=>|#|//|/\*|\*/",
        r"\(|\)|\[|\]|\{|\}",
    ]
    matches = sum(bool(re.search(p, line)) for p in code_tokens)
    indent = len(line) - len(line.lstrip())
    return matches >= 1 or indent >= 4
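# Worked examples of the heuristic above (illustrative, shown as comments only):
#   _looks_like_code("def parse(row):")        -> True   (keyword plus parentheses)
#   _looks_like_code("        total += x;")    -> True   (semicolon, >= 4 spaces of indent)
#   _looks_like_code("The report has tables")  -> False  (no code tokens, no indentation)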
def extract_text_and_code_from_pdf(file_bytes: bytes, ocr_fallback: bool = True, max_pages: int = 50) -> Tuple[str, List[str]]:
"""Return (plain_text, code_blocks[]) from a PDF with simple OCR fallback."""
text_parts: List[str] = []
code_lines: List[str] = []
with pdfplumber.open(BytesIO(file_bytes)) as pdf:
pages = pdf.pages[:max_pages]
for page in pages:
# 1) Try text extraction
extracted = page.extract_text(x_tolerance=1.5, y_tolerance=1.0) or ""
# 2) OCR fallback if page empty and OCR available
if not extracted.strip() and ocr_fallback and OCR_AVAILABLE:
try:
img = page.to_image(resolution=180).original
extracted = pytesseract.image_to_string(img, config='--psm 6') or ""
except Exception as e:
logger.warning(f"OCR failed on a page: {e}")
# 3) Clean and collect
if extracted:
# Remove common headers/footers by simple rules
lines = [ln for ln in extracted.splitlines() if not re.match(r"^(Page\s*\d+|Copyright.*)$", ln, flags=re.I)]
text_parts.append("\n".join(lines))
# Code detection: fenced blocks first
fenced = re.findall(r"```[\w-]*\n([\s\S]*?)```", extracted, flags=re.M)
for blk in fenced:
blk = blk.strip()
if blk:
code_lines.append(blk)
# Otherwise, line-wise heuristic
for ln in lines:
if _looks_like_code(ln):
code_lines.append(ln)
# 4) Tables -> pipe-separated rows
try:
tables = page.extract_tables() or []
for tb in tables:
for row in tb:
if row and any(str(c).strip() for c in row):
text_parts.append(" | ".join(str(c).strip() for c in row))
except Exception:
pass
full_text = "\n\n".join(tp for tp in text_parts if tp.strip())
# Merge adjacent code lines into blocks
code_blocks: List[str] = []
if code_lines:
current: List[str] = []
for ln in code_lines:
if ln.strip():
current.append(ln)
else:
if current:
code_blocks.append("\n".join(current))
current = []
if current:
code_blocks.append("\n".join(current))
# Deduplicate & trim giant blocks
seen = set()
unique_blocks = []
for blk in code_blocks:
key = blk.strip()
if key and key not in seen:
seen.add(key)
# cap extreme long blocks for UI; still allow download of full
unique_blocks.append(blk[:8000])
return full_text, unique_blocks
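# Minimal usage sketch for the extractor (assumes a local file named "sample.pdf";
# illustrative only, not executed by the app):
#   with open("sample.pdf", "rb") as fh:
#       text, blocks = extract_text_and_code_from_pdf(fh.read(), ocr_fallback=False, max_pages=10)
#   print(len(text), len(blocks))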
# ----------------------------
# Chunking & Indexing
# ----------------------------
def chunk_text(text: str, chunk_size: int = 700, chunk_overlap: int = 120) -> List[str]:
    text = re.sub(r"\n{3,}", "\n\n", text).strip()
    paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
    chunks: List[str] = []
    buf: str = ""
    for para in paras:
        if not buf:
            buf = para
        elif len(buf) + len(para) + 1 <= chunk_size:
            buf += "\n" + para
        else:
            chunks.append(buf)
            # Carry an overlap into the next chunk
            overlap = buf[-chunk_overlap:] if chunk_overlap > 0 else ""
            buf = (overlap + "\n" + para).strip()
    if buf:
        chunks.append(buf)
    return chunks
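# How the chunker behaves (descriptive note): with the defaults chunk_size=700 and
# chunk_overlap=120, paragraphs are packed into a buffer until adding the next one would
# exceed 700 characters; the buffer is then emitted as a chunk and its last 120 characters
# are carried into the next buffer, so neighbouring chunks share context across the boundary.
# For example, three 300-character paragraphs become two chunks: the first holds
# paragraphs 1-2, the second starts with the tail of paragraph 2 followed by paragraph 3.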
@st.cache_resource(show_spinner=False)
def build_indexes(chunks: List[str]):
    embedder = get_embedder()
    matrix = embedder.encode(chunks, show_progress_bar=False, batch_size=64, normalize_embeddings=True)
    matrix = np.asarray(matrix).astype('float32')
    bm25 = BM25Okapi([c.split() for c in chunks])
    if FAISS_OK:
        index = faiss.IndexFlatIP(matrix.shape[1])
        index.add(matrix)
        return {
            "chunks": chunks,
            "embeddings": matrix,
            "faiss": index,
            "bm25": bm25,
        }
    else:
        # Fallback: cosine via numpy (slower but OK for small docs)
        return {
            "chunks": chunks,
            "embeddings": matrix,
            "faiss": None,
            "bm25": bm25,
        }
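# Note on the hybrid index (descriptive): because encode() is called with
# normalize_embeddings=True, every vector has unit length, so scores from FAISS's
# IndexFlatIP (inner product) are cosine similarities. The numpy fallback in retrieve()
# computes the same scores with a plain matrix-vector product, roughly:
#   sims = matrix @ qv            # cosine similarity of each chunk vs. the query
#   best = np.argsort(-sims)[:k]  # indices of the k most similar chunks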
# ----------------------------
# Retrieval + QA
# ----------------------------
def retrieve(topk: int, query: str, idx):
chunks = idx["chunks"]
embeddings = idx["embeddings"]
bm25 = idx["bm25"]
# BM25
bm25_docs = bm25.get_top_n(query.split(), chunks, n=min(topk, len(chunks)))
# FAISS / cosine
embedder = get_embedder()
qv = embedder.encode([query], normalize_embeddings=True)[0].astype('float32')
if idx["faiss"] is not None:
D, I = idx["faiss"].search(np.array([qv]), min(topk, len(chunks)))
faiss_docs = [chunks[i] for i in I[0]]
else:
# cosine with numpy
sims = embeddings @ qv
order = np.argsort(-sims)[:topk]
faiss_docs = [chunks[i] for i in order]
# Merge uniques with preference to BM25 then FAISS
merged: List[str] = []
seen = set()
for c in bm25_docs + faiss_docs:
if c not in seen:
merged.append(c)
seen.add(c)
if len(merged) >= topk:
break
return merged
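# Usage sketch for hybrid retrieval (illustrative; assumes an index built by
# build_indexes over chunk_text output, with a hypothetical some_document_text variable):
#   idx = build_indexes(chunk_text(some_document_text))
#   top_chunks = retrieve(6, "What dataset was used?", idx)
# BM25 favours exact keyword matches, the embedding search favours paraphrases, and the
# merge keeps the first topk unique chunks across both result lists.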
def rag_answer(query: str, idx, max_ctx_chars: int = 3000) -> str:
    ctx_chunks = retrieve(6, query, idx)
    # Concatenate up to a char budget
    ctx = "\n\n".join(ctx_chunks)
    if len(ctx) > max_ctx_chars:
        ctx = ctx[:max_ctx_chars]
    qa = get_qa_pipeline()
    prompt = (
        "Answer the question using ONLY the provided context. "
        "If the answer is not in the context, say 'I couldn't find that in the PDF.'\n\n"
        f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:"
    )
    out = qa(prompt)[0]["generated_text"].strip()
    return out
def summarize_text(full_text: str) -> str:
    summarizer = get_summarizer()
    # Summarize in parts for long docs
    chunks = chunk_text(full_text, chunk_size=1200, chunk_overlap=150)
    partials = []
    for ch in chunks[:8]:  # cap to keep it snappy on CPU
        partials.append(summarizer(ch)[0]["summary_text"].strip())
    # Final stitch summary
    stitched = " ".join(partials)
    if len(stitched) > 2000:
        stitched = summarizer(stitched[:3000])[0]["summary_text"].strip()
    return stitched
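# Summarization strategy (descriptive note): this is a simple map-reduce pass. Up to the
# first 8 chunks (about 1200 characters each) are summarized independently, the partial
# summaries are stitched together, and if the stitched text is still long (> 2000
# characters) it is summarized once more to produce the final summary.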
# ----------------------------
# UI
# ----------------------------
st.markdown(
"""
<style>
.app-header {background: linear-gradient(90deg,#10b981,#22c55e); color: white; padding: 16px; border-radius: 14px; text-align:center; box-shadow: 0 6px 20px rgba(16,185,129,.25)}
.card {border:1px solid #e5e7eb; border-radius: 14px; padding: 16px; background: #fff}
.muted {color:#6b7280}
.kbd {background:#f3f4f6; border:1px solid #e5e7eb; border-radius:6px; padding:2px 6px; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco}
</style>
""",
unsafe_allow_html=True,
)
st.markdown('<div class="app-header"><h1>📄 Smart PDF Chat & Summarizer</h1><p class="muted">Fast answers, focused summaries, and automatic code extraction</p></div>', unsafe_allow_html=True)
# Session state
if "idx" not in st.session_state:
st.session_state.idx = None
if "pdf_text" not in st.session_state:
st.session_state.pdf_text = ""
if "code_blocks" not in st.session_state:
st.session_state.code_blocks = []
# Sidebar
with st.sidebar:
st.subheader("Upload & Options")
file = st.file_uploader("Upload a PDF", type=["pdf"], help="Max ~50 pages for speed. Uses OCR fallback if needed.")
max_pages = st.slider("Max pages to parse", 5, 100, 50, help="Lower = faster")
do_ocr = st.toggle("Enable OCR fallback (slower)", value=False)
chunk_size = st.slider("Chunk size", 300, 1400, 700, step=50)
overlap = st.slider("Chunk overlap", 0, 300, 120, step=10)
colA, colB = st.columns(2)
with colA:
if st.button("⚙️ Build Index", use_container_width=True, type="primary"):
if not file:
st.warning("Please upload a PDF first.")
else:
with st.spinner("Reading & indexing PDF…"):
data = file.read()
text, code_blocks = extract_text_and_code_from_pdf(data, ocr_fallback=do_ocr, max_pages=max_pages)
st.session_state.pdf_text = text
st.session_state.code_blocks = code_blocks
if not text.strip():
st.error("Couldn't extract any text from the PDF.")
else:
chunks = chunk_text(text, chunk_size=chunk_size, chunk_overlap=overlap)
st.session_state.idx = build_indexes(chunks)
st.success(f"Indexed {len(chunks)} chunks. Ready!")
with colB:
if st.button("🧹 Clear", use_container_width=True):
st.session_state.idx = None
st.session_state.pdf_text = ""
st.session_state.code_blocks = []
            st.rerun()  # reset the UI after clearing session state
    if st.session_state.code_blocks:
        st.caption("Detected code blocks. You can copy or download them from the Summary tab.")
# Main area: exactly two tabs, Chat & Summary
chat_tab, summary_tab = st.tabs(["💬 Chat", "📝 Summary (with Code)"])
with chat_tab:
st.markdown("<div class='card'>Ask questions about your PDF. Retrieval-augmented answers use only the document context.</div>", unsafe_allow_html=True)
if st.session_state.idx is None:
st.info("Upload a PDF and click **Build Index** in the sidebar.")
else:
user_q = st.chat_input("Ask anything about the PDF…")
if "chat" not in st.session_state:
st.session_state.chat = []
# Render history
for role, content in st.session_state.get("chat", []):
with st.chat_message(role):
st.markdown(content)
if user_q:
st.session_state.chat.append(("user", user_q))
with st.chat_message("user"):
st.markdown(user_q)
with st.chat_message("assistant"):
with st.spinner("Thinking…"):
try:
ans = rag_answer(user_q, st.session_state.idx)
except Exception as e:
ans = f"Sorry, I hit an error while answering: {e}"
st.markdown(ans)
st.session_state.chat.append(("assistant", ans))
with summary_tab:
st.markdown("<div class='card'>One-click concise summary of the entire document, plus extracted programming code if detected.</div>", unsafe_allow_html=True)
col1, col2 = st.columns([1,1])
with col1:
if st.button("🔎 Summarize PDF", type="primary", use_container_width=True):
if not st.session_state.pdf_text.strip():
st.warning("No parsed text yet. Upload & Build Index first.")
else:
with st.spinner("Summarizing…"):
try:
sm = summarize_text(st.session_state.pdf_text)
st.session_state.summary = sm
st.success("Summary generated.")
except Exception as e:
st.error(f"Summarization failed: {e}")
with col2:
if st.session_state.pdf_text:
st.download_button(
"⬇️ Download raw extracted text",
st.session_state.pdf_text,
file_name="extracted_text.txt",
use_container_width=True,
)
if st.session_state.get("summary"):
st.subheader("Summary")
st.write(st.session_state.summary)
st.divider()
st.subheader("Extracted Code")
if st.session_state.code_blocks:
for i, blk in enumerate(st.session_state.code_blocks, start=1):
with st.expander(f"Code block #{i}"):
st.code(blk, language=None)
st.download_button(
f"Download code #{i}",
blk,
file_name=f"code_block_{i}.txt",
key=f"dl_{i}",
)
all_code = "\n\n\n".join(st.session_state.code_blocks)
st.download_button("⬇️ Download all code", all_code, file_name="all_code.txt")
else:
st.caption("No code-like content detected yet.")
# Footer tips
st.markdown(
"""
<div class="muted" style="margin-top:24px">⚡ Tips for faster responses: use smaller PDFs, lower the "Max pages" and "Chunk size" in the sidebar, and keep OCR off unless needed.</div>
""",
unsafe_allow_html=True,
)