Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| from langchain_experimental.graph_transformers import LLMGraphTransformer | |
| from langchain.chains import GraphQAChain | |
| from langchain_core.documents import Document | |
| from langchain_community.graphs.networkx_graph import NetworkxEntityGraph | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_groq import ChatGroq | |
| import pandas as pd | |
| from gradio_client import Client | |
| import numpy as np | |
| from PIL import Image as PILImage | |
| import base64 | |
| from io import BytesIO | |
| # Set the base directory | |
| BASE_DIR = os.getcwd() | |
| GROQ_API_KEY = os.environ.get('GROQ_API_KEY') | |
| # Set up LLM and Flux client | |
| llm = ChatGroq(temperature=0, model_name='llama-3.1-8b-instant', groq_api_key=GROQ_API_KEY) | |
| flux_client = Client("black-forest-labs/Flux.1-schnell") | |
| def create_graph(text): | |
| documents = [Document(page_content=text)] | |
| llm_transformer_filtered = LLMGraphTransformer(llm=llm) | |
| graph_documents_filtered = llm_transformer_filtered.convert_to_graph_documents(documents) | |
| graph = NetworkxEntityGraph() | |
| for node in graph_documents_filtered[0].nodes: | |
| graph.add_node(node.id) | |
| for edge in graph_documents_filtered[0].relationships: | |
| graph._graph.add_edge( | |
| edge.source.id, | |
| edge.target.id, | |
| relation=edge.type | |
| ) | |
| return graph, graph_documents_filtered | |
| def visualize_graph(graph): | |
| plt.figure(figsize=(12, 8)) | |
| pos = nx.spring_layout(graph._graph) | |
| nx.draw(graph._graph, pos, with_labels=True, node_color='lightblue', node_size=500, font_size=8, font_weight='bold') | |
| edge_labels = nx.get_edge_attributes(graph._graph, 'relation') | |
| nx.draw_networkx_edge_labels(graph._graph, pos, edge_labels=edge_labels, font_size=6) | |
| plt.title("Graph Visualization") | |
| plt.axis('off') | |
| # Save the plot as an image file | |
| graph_viz_path = os.path.join(BASE_DIR, 'graph_visualization.png') | |
| plt.savefig(graph_viz_path) | |
| plt.close() | |
| return graph_viz_path | |
| def generate_image(prompt): | |
| try: | |
| print(f"Generating image with prompt: {prompt}") | |
| result = flux_client.predict( | |
| prompt=prompt, | |
| seed=0, | |
| randomize_seed=True, | |
| width=1024, | |
| height=1024, | |
| num_inference_steps=4, | |
| api_name="/infer" | |
| ) | |
| if isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], str): | |
| img_str = result[0] | |
| img_str += '=' * (-len(img_str) % 4) | |
| img_data = base64.b64decode(img_str) | |
| image = PILImage.open(BytesIO(img_data)) | |
| elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], np.ndarray): | |
| image = PILImage.fromarray((result[0] * 255).astype(np.uint8)) | |
| elif isinstance(result, PILImage.Image): | |
| image = result | |
| else: | |
| raise ValueError(f"Unexpected result format from flux_client.predict: {type(result)}") | |
| image_path = os.path.join(BASE_DIR, 'generated_image.png') | |
| image.save(image_path) | |
| print(f"Image saved to: {image_path}") | |
| return image_path | |
| except Exception as e: | |
| print(f"Error in generate_image: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| return None | |
| def process_text(text, question): | |
| try: | |
| print("Creating graph...") | |
| graph, graph_documents_filtered = create_graph(text) | |
| print("Setting up GraphQAChain...") | |
| graph_rag = GraphQAChain.from_llm( | |
| llm=llm, | |
| graph=graph, | |
| verbose=True | |
| ) | |
| print("Running question through GraphQAChain...") | |
| answer = graph_rag.run(question) | |
| print(f"Answer: {answer}") | |
| print("Visualizing graph...") | |
| graph_viz_path = visualize_graph(graph) | |
| print(f"Graph visualization saved to: {graph_viz_path}") | |
| print("Generating summary...") | |
| summary_prompt = f"Summarize the following text in one sentence: {text}" | |
| summary = llm.invoke(summary_prompt).content | |
| print(f"Summary: {summary}") | |
| print("Generating image...") | |
| image_path = generate_image(summary) | |
| if image_path and os.path.exists(image_path): | |
| print(f"Generated image saved to: {image_path}") | |
| else: | |
| print("Failed to generate or save image") | |
| return answer, graph_viz_path, summary, image_path | |
| except Exception as e: | |
| print(f"An error occurred in process_text: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| return str(e), None, str(e), None | |
| def ui_function(text, question): | |
| answer, graph_viz_path, summary, image_path = process_text(text, question) | |
| if isinstance(answer, str) and answer.startswith("An error occurred"): | |
| return answer, None, answer, None | |
| return answer, graph_viz_path, summary, image_path | |
| # Create Gradio interface | |
| iface = gr.Interface( | |
| fn=ui_function, | |
| inputs=[ | |
| gr.Textbox(label="Input Text"), | |
| gr.Textbox(label="Question") | |
| ], | |
| outputs=[ | |
| gr.Textbox(label="Answer"), | |
| gr.Image(label="Graph Visualization", type="filepath"), | |
| gr.Textbox(label="Summary"), | |
| gr.Image(label="Generated Image", type="filepath") | |
| ], | |
| title="GraphRAG and Image Generation UI", | |
| description="Enter text to create a graph, ask a question, and generate a relevant image." | |
| ) | |
| iface.launch() |