import streamlit as st
import logging
import re
from io import BytesIO
from typing import List, Tuple

import pdfplumber

# Optional OCR (guarded)
try:
    import pytesseract
    OCR_AVAILABLE = True
except Exception:
    OCR_AVAILABLE = False

from rank_bm25 import BM25Okapi

# Embeddings + Vector store
from sentence_transformers import SentenceTransformer
import numpy as np

try:
    import faiss  # direct FAISS for speed and control
    FAISS_OK = True
except Exception:
    FAISS_OK = False

# Lightweight HF pipelines
from transformers import pipeline
# ----------------------------
# App & Logging Setup
# ----------------------------
st.set_page_config(page_title="Smart PDF Chat & Summarizer", page_icon="📄", layout="wide")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("smart_pdf")

# ----------------------------
# Caching: models & utilities
# ----------------------------
@st.cache_resource
def get_embedder(name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    return SentenceTransformer(name)

@st.cache_resource
def get_qa_pipeline():
    # Small, fast instruction-tuned model
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-small",
        device=-1,
        max_length=220,
    )

@st.cache_resource
def get_summarizer():
    # DistilBART is much faster than bart-large-cnn
    return pipeline(
        "summarization",
        model="sshleifer/distilbart-cnn-12-6",
        device=-1,
        max_length=220,
        min_length=80,
        do_sample=False,
    )
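
# A minimal smoke test for the cached pipelines, kept commented so it does not
# run at app startup (a sketch, assuming the models above download on first use):
#
#   qa = get_qa_pipeline()
#   print(qa("Question: What is 2+2? Answer:")[0]["generated_text"])
#   sm = get_summarizer()
#   print(sm("Some long paragraph of text. " * 40)[0]["summary_text"])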

# ----------------------------
# PDF processing
# ----------------------------
def _looks_like_code(line: str) -> bool:
    if len(line.strip()) == 0:
        return False
    # Heuristics for code-y lines
    code_tokens = [
        r"\b(def|class|import|from|return|if|elif|else|for|while|try|except|finally|with)\b",
        r"[{}`;<>]|::|=>|#|//|/\*|\*/",
        r"\(|\)|\[|\]|\{|\}",
    ]
    matches = sum(bool(re.search(p, line)) for p in code_tokens)
    indent = len(line) - len(line.lstrip())
    return matches >= 1 or indent >= 4
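
# A quick illustration of the heuristic above; any single pattern hit, or an
# indent of four or more spaces, is enough to flag a line as code:
#
#   _looks_like_code("def foo(x):")       -> True   (keyword + parentheses)
#   _looks_like_code("    x = x + 1")     -> True   (indent >= 4)
#   _looks_like_code("Plain prose line")  -> False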

def extract_text_and_code_from_pdf(file_bytes: bytes, ocr_fallback: bool = True, max_pages: int = 50) -> Tuple[str, List[str]]:
    """Return (plain_text, code_blocks[]) from a PDF, with a simple OCR fallback."""
    text_parts: List[str] = []
    code_lines: List[str] = []
    with pdfplumber.open(BytesIO(file_bytes)) as pdf:
        pages = pdf.pages[:max_pages]
        for page in pages:
            # 1) Try text extraction
            extracted = page.extract_text(x_tolerance=1.5, y_tolerance=1.0) or ""
            # 2) OCR fallback if the page is empty and OCR is available
            if not extracted.strip() and ocr_fallback and OCR_AVAILABLE:
                try:
                    img = page.to_image(resolution=180).original
                    extracted = pytesseract.image_to_string(img, config='--psm 6') or ""
                except Exception as e:
                    logger.warning(f"OCR failed on a page: {e}")
            # 3) Clean and collect
            if extracted:
                # Remove common headers/footers with simple rules
                lines = [ln for ln in extracted.splitlines() if not re.match(r"^(Page\s*\d+|Copyright.*)$", ln, flags=re.I)]
                text_parts.append("\n".join(lines))
                # Code detection: fenced blocks first
                fenced = re.findall(r"```[\w-]*\n([\s\S]*?)```", extracted, flags=re.M)
                for blk in fenced:
                    blk = blk.strip()
                    if blk:
                        code_lines.append(blk)
                        code_lines.append("")  # blank separator keeps fenced blocks distinct
                # Otherwise, line-wise heuristic; non-code lines become blank
                # separators so consecutive code lines merge into blocks below
                for ln in lines:
                    if _looks_like_code(ln):
                        code_lines.append(ln)
                    else:
                        code_lines.append("")
            # 4) Tables -> pipe-separated rows
            try:
                tables = page.extract_tables() or []
                for tb in tables:
                    for row in tb:
                        if row and any(str(c).strip() for c in row):
                            text_parts.append(" | ".join(str(c).strip() for c in row))
            except Exception:
                pass
    full_text = "\n\n".join(tp for tp in text_parts if tp.strip())
    # Merge adjacent code lines into blocks (blank entries end a block)
    code_blocks: List[str] = []
    if code_lines:
        current: List[str] = []
        for ln in code_lines:
            if ln.strip():
                current.append(ln)
            else:
                if current:
                    code_blocks.append("\n".join(current))
                current = []
        if current:
            code_blocks.append("\n".join(current))
    # Deduplicate & trim giant blocks
    seen = set()
    unique_blocks = []
    for blk in code_blocks:
        key = blk.strip()
        if key and key not in seen:
            seen.add(key)
            # Cap extremely long blocks for the UI; the raw text download keeps everything
            unique_blocks.append(blk[:8000])
    return full_text, unique_blocks
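
# Usage sketch outside the UI (the "sample.pdf" path is hypothetical):
#
#   with open("sample.pdf", "rb") as fh:
#       text, blocks = extract_text_and_code_from_pdf(fh.read(), ocr_fallback=False)
#   print(len(text), "chars,", len(blocks), "code blocks")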

# ----------------------------
# Chunking & Indexing
# ----------------------------
def chunk_text(text: str, chunk_size: int = 700, chunk_overlap: int = 120) -> List[str]:
    text = re.sub(r"\n{3,}", "\n\n", text).strip()
    paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
    chunks: List[str] = []
    buf: str = ""
    for para in paras:
        if not buf:
            buf = para
        elif len(buf) + len(para) + 1 <= chunk_size:
            buf += "\n" + para
        else:
            chunks.append(buf)
            # overlap: carry the tail of the previous chunk into the next one
            overlap = buf[-chunk_overlap:] if chunk_overlap > 0 else ""
            buf = (overlap + "\n" + para).strip()
    if buf:
        chunks.append(buf)
    return chunks
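
# Sketch of how the overlap behaves: with chunk_size=12 and chunk_overlap=4,
# the second chunk starts with the last 4 characters of the first, so text
# that straddles a boundary stays retrievable from at least one chunk:
#
#   chunk_text("aaaaaaaaaaaa\n\nbbbbbbbbbbbb", chunk_size=12, chunk_overlap=4)
#   -> ["aaaaaaaaaaaa", "aaaa\nbbbbbbbbbbbb"]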

def build_indexes(chunks: List[str]):
    embedder = get_embedder()
    matrix = embedder.encode(chunks, show_progress_bar=False, batch_size=64, normalize_embeddings=True)
    matrix = np.asarray(matrix).astype('float32')
    bm25 = BM25Okapi([c.split() for c in chunks])
    if FAISS_OK:
        index = faiss.IndexFlatIP(matrix.shape[1])
        index.add(matrix)
    else:
        # Fallback: cosine via numpy (slower, but fine for small documents)
        index = None
    return {
        "chunks": chunks,
        "embeddings": matrix,
        "faiss": index,
        "bm25": bm25,
    }
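
# Because the embeddings are L2-normalized above, FAISS inner product
# (IndexFlatIP) is exactly cosine similarity, and the numpy fallback in
# retrieve() relies on the same identity. A minimal sketch, with v and w
# as placeholder vectors:
#
#   a = v / np.linalg.norm(v)
#   b = w / np.linalg.norm(w)
#   cos_sim = float(a @ b)  # equals np.dot(v, w) / (|v| * |w|)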

# ----------------------------
# Retrieval + QA
# ----------------------------
def retrieve(topk: int, query: str, idx):
    chunks = idx["chunks"]
    embeddings = idx["embeddings"]
    bm25 = idx["bm25"]
    # BM25 (lexical)
    bm25_docs = bm25.get_top_n(query.split(), chunks, n=min(topk, len(chunks)))
    # FAISS / cosine (semantic)
    embedder = get_embedder()
    qv = embedder.encode([query], normalize_embeddings=True)[0].astype('float32')
    if idx["faiss"] is not None:
        D, I = idx["faiss"].search(np.array([qv]), min(topk, len(chunks)))
        faiss_docs = [chunks[i] for i in I[0]]
    else:
        # Cosine with numpy (embeddings are already normalized)
        sims = embeddings @ qv
        order = np.argsort(-sims)[:topk]
        faiss_docs = [chunks[i] for i in order]
    # Merge unique hits, preferring BM25 first, then FAISS
    merged: List[str] = []
    seen = set()
    for c in bm25_docs + faiss_docs:
        if c not in seen:
            merged.append(c)
            seen.add(c)
        if len(merged) >= topk:
            break
    return merged
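
# Usage sketch (some_text is a placeholder): because BM25 hits are listed
# first in the merge, exact keyword matches outrank purely semantic neighbors.
#
#   idx = build_indexes(chunk_text(some_text))
#   top = retrieve(4, "What dataset was used?", idx)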

def rag_answer(query: str, idx, max_ctx_chars: int = 3000) -> str:
    ctx_chunks = retrieve(6, query, idx)
    # Concatenate up to a character budget
    ctx = "\n\n".join(ctx_chunks)
    if len(ctx) > max_ctx_chars:
        ctx = ctx[:max_ctx_chars]
    qa = get_qa_pipeline()
    prompt = (
        "Answer the question using ONLY the provided context. "
        "If the answer is not in the context, say 'I couldn't find that in the PDF.'\n\n"
        f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:"
    )
    return qa(prompt)[0]["generated_text"].strip()

def summarize_text(full_text: str) -> str:
    summarizer = get_summarizer()
    # Summarize in parts for long documents
    chunks = chunk_text(full_text, chunk_size=1200, chunk_overlap=150)
    partials = []
    for ch in chunks[:8]:  # cap to keep it snappy on CPU
        partials.append(summarizer(ch)[0]["summary_text"].strip())
    # Final stitched summary
    stitched = " ".join(partials)
    if len(stitched) > 2000:
        stitched = summarizer(stitched[:3000])[0]["summary_text"].strip()
    return stitched
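
# This is a simple map-reduce summary: summarize each chunk ("map"), then
# summarize the concatenated partials ("reduce") if they run long. Usage
# sketch (doc_text is a placeholder for extracted PDF text):
#
#   summary = summarize_text(doc_text)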

# ----------------------------
# UI
# ----------------------------
st.markdown(
    """
    <style>
    .app-header {background: linear-gradient(90deg,#10b981,#22c55e); color: white; padding: 16px; border-radius: 14px; text-align:center; box-shadow: 0 6px 20px rgba(16,185,129,.25)}
    .card {border:1px solid #e5e7eb; border-radius: 14px; padding: 16px; background: #fff}
    .muted {color:#6b7280}
    .kbd {background:#f3f4f6; border:1px solid #e5e7eb; border-radius:6px; padding:2px 6px; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco}
    </style>
    """,
    unsafe_allow_html=True,
)
st.markdown('<div class="app-header"><h1>📄 Smart PDF Chat & Summarizer</h1><p class="muted">Fast answers, focused summaries, and automatic code extraction</p></div>', unsafe_allow_html=True)

# Session state
if "idx" not in st.session_state:
    st.session_state.idx = None
if "pdf_text" not in st.session_state:
    st.session_state.pdf_text = ""
if "code_blocks" not in st.session_state:
    st.session_state.code_blocks = []

# Sidebar
with st.sidebar:
    st.subheader("Upload & Options")
    file = st.file_uploader("Upload a PDF", type=["pdf"], help="Max ~50 pages for speed. Uses OCR fallback if needed.")
    max_pages = st.slider("Max pages to parse", 5, 100, 50, help="Lower = faster")
    do_ocr = st.toggle("Enable OCR fallback (slower)", value=False)
    chunk_size = st.slider("Chunk size", 300, 1400, 700, step=50)
    overlap = st.slider("Chunk overlap", 0, 300, 120, step=10)
    colA, colB = st.columns(2)
    with colA:
        if st.button("⚙️ Build Index", use_container_width=True, type="primary"):
            if not file:
                st.warning("Please upload a PDF first.")
            else:
                with st.spinner("Reading & indexing PDF…"):
                    data = file.read()
                    text, code_blocks = extract_text_and_code_from_pdf(data, ocr_fallback=do_ocr, max_pages=max_pages)
                    st.session_state.pdf_text = text
                    st.session_state.code_blocks = code_blocks
                    if not text.strip():
                        st.error("Couldn't extract any text from the PDF.")
                    else:
                        chunks = chunk_text(text, chunk_size=chunk_size, chunk_overlap=overlap)
                        st.session_state.idx = build_indexes(chunks)
                        st.success(f"Indexed {len(chunks)} chunks. Ready!")
    with colB:
        if st.button("🧹 Clear", use_container_width=True):
            st.session_state.idx = None
            st.session_state.pdf_text = ""
            st.session_state.code_blocks = []
            st.rerun()
    if st.session_state.code_blocks:
        st.caption("Detected code blocks. You can copy or download them from the Summary tab.")

# Main area: exactly two sections, Chat & Summary
chat_tab, summary_tab = st.tabs(["💬 Chat", "📝 Summary (with Code)"])

with chat_tab:
    st.markdown("<div class='card'>Ask questions about your PDF. Retrieval-augmented answers use only the document context.</div>", unsafe_allow_html=True)
    if st.session_state.idx is None:
        st.info("Upload a PDF and click **Build Index** in the sidebar.")
    else:
        user_q = st.chat_input("Ask anything about the PDF…")
        if "chat" not in st.session_state:
            st.session_state.chat = []
        # Render history
        for role, content in st.session_state.get("chat", []):
            with st.chat_message(role):
                st.markdown(content)
        if user_q:
            st.session_state.chat.append(("user", user_q))
            with st.chat_message("user"):
                st.markdown(user_q)
            with st.chat_message("assistant"):
                with st.spinner("Thinking…"):
                    try:
                        ans = rag_answer(user_q, st.session_state.idx)
                    except Exception as e:
                        ans = f"Sorry, I hit an error while answering: {e}"
                    st.markdown(ans)
                    st.session_state.chat.append(("assistant", ans))

with summary_tab:
    st.markdown("<div class='card'>One-click concise summary of the entire document, plus extracted programming code if detected.</div>", unsafe_allow_html=True)
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("🔎 Summarize PDF", type="primary", use_container_width=True):
            if not st.session_state.pdf_text.strip():
                st.warning("No parsed text yet. Upload & Build Index first.")
            else:
                with st.spinner("Summarizing…"):
                    try:
                        sm = summarize_text(st.session_state.pdf_text)
                        st.session_state.summary = sm
                        st.success("Summary generated.")
                    except Exception as e:
                        st.error(f"Summarization failed: {e}")
    with col2:
        if st.session_state.pdf_text:
            st.download_button(
                "⬇️ Download raw extracted text",
                st.session_state.pdf_text,
                file_name="extracted_text.txt",
                use_container_width=True,
            )
    if st.session_state.get("summary"):
        st.subheader("Summary")
        st.write(st.session_state.summary)
    st.divider()
    st.subheader("Extracted Code")
    if st.session_state.code_blocks:
        for i, blk in enumerate(st.session_state.code_blocks, start=1):
            with st.expander(f"Code block #{i}"):
                st.code(blk, language=None)
                st.download_button(
                    f"Download code #{i}",
                    blk,
                    file_name=f"code_block_{i}.txt",
                    key=f"dl_{i}",
                )
        all_code = "\n\n\n".join(st.session_state.code_blocks)
        st.download_button("⬇️ Download all code", all_code, file_name="all_code.txt")
    else:
        st.caption("No code-like content detected yet.")

# Footer tips
st.markdown(
    """
    <div class="muted" style="margin-top:24px">⚡ Tips for faster responses: use smaller PDFs, lower the "Max pages" and "Chunk size" in the sidebar, and keep OCR off unless needed.</div>
    """,
    unsafe_allow_html=True,
)