"""Knowledge Graph RAG demo.

Extracts entities and relationships from free text with LangChain's
``LLMGraphTransformer`` (backed by a Groq-hosted LLM), stores them in a
NetworkX graph, answers questions using that graph as context, and exposes
the pipeline through a Gradio UI.
"""

import logging
import os
import tempfile

import gradio as gr
import networkx as nx
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend; must run before pyplot is imported
import matplotlib.pyplot as plt
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_core.documents import Document
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from langchain_groq import ChatGroq
import pandas as pd

logger = logging.getLogger(__name__)

# Set the base directory
BASE_DIR = os.getcwd()
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')

# Column names shared by the relations DataFrame and the UI table headers.
RELATION_COLUMNS = ("Source", "Relation", "Target")

# Joins several relation types that connect the same (source, target) pair.
_RELATION_SEPARATOR = " / "

# Set up LLM
llm = ChatGroq(temperature=0, model_name='llama-3.1-8b-instant', groq_api_key=GROQ_API_KEY)


def _merge_relation(existing, new):
    """Return an edge label combining ``existing`` and ``new`` without duplicates."""
    if not existing:
        return new
    labels = existing.split(_RELATION_SEPARATOR)
    if new in labels:
        return existing
    return _RELATION_SEPARATOR.join(labels + [new])


def create_graph(text):
    """Create a knowledge graph from text using the LLM.

    Args:
        text: Free text to extract entities and relationships from.

    Returns:
        ``(graph, graph_documents)``: the populated ``NetworkxEntityGraph`` and
        the raw graph documents returned by
        ``LLMGraphTransformer.convert_to_graph_documents``.
    """
    documents = [Document(page_content=text)]
    llm_transformer = LLMGraphTransformer(llm=llm)
    graph_documents = llm_transformer.convert_to_graph_documents(documents)

    graph = NetworkxEntityGraph()
    # Edges carrying a ``relation`` attribute are written to the underlying
    # networkx graph directly.
    nx_graph = graph._graph

    for graph_doc in graph_documents:
        for node in graph_doc.nodes:
            graph.add_node(node.id)

        for edge in graph_doc.relationships:
            source, target = edge.source.id, edge.target.id
            # On a simple (non-multi) networkx graph, add_edge on an existing
            # pair overwrites its attributes; merge labels so no relation is lost.
            existing = (
                nx_graph.edges[source, target].get('relation')
                if nx_graph.has_edge(source, target)
                else None
            )
            nx_graph.add_edge(source, target, relation=_merge_relation(existing, edge.type))

    return graph, graph_documents


def visualize_graph(graph, output_path=None):
    """Render the graph to a PNG file and return the file's path.

    Args:
        graph: ``NetworkxEntityGraph`` built by ``create_graph``.
        output_path: Destination PNG path. When omitted, a uniquely named file
            is created in the system temp directory so concurrent requests do
            not overwrite each other's images.

    Returns:
        Path of the written PNG.
    """
    nx_graph = graph._graph

    if output_path is None:
        fd, output_path = tempfile.mkstemp(prefix='graph_visualization_', suffix='.png')
        os.close(fd)  # only the unique name is needed; savefig reopens the file

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        pos = nx.spring_layout(nx_graph, k=0.5, iterations=50)

        # Draw nodes and edges
        nx.draw_networkx_nodes(nx_graph, pos, ax=ax, node_color='lightblue', node_size=1500, alpha=0.9)
        nx.draw_networkx_labels(nx_graph, pos, ax=ax, font_size=10, font_weight='bold')
        nx.draw_networkx_edges(nx_graph, pos, ax=ax, edge_color='gray', arrows=True, arrowsize=20, width=2)

        # Draw edge labels
        edge_labels = nx.get_edge_attributes(nx_graph, 'relation')
        nx.draw_networkx_edge_labels(nx_graph, pos, ax=ax, edge_labels=edge_labels, font_size=8)

        ax.set_title("Knowledge Graph Visualization", fontsize=16, fontweight='bold')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if layout or drawing fails.
        plt.close(fig)

    return output_path


def create_relations_table(graph_documents):
    """Create a DataFrame of all extracted relationships.

    Returns one row per relationship, with columns ``Source``, ``Relation``
    and ``Target``. The columns are present even when nothing was extracted,
    so the UI table keeps its headers.
    """
    rows = [
        {'Source': edge.source.id, 'Relation': edge.type, 'Target': edge.target.id}
        for graph_doc in graph_documents
        for edge in graph_doc.relationships
    ]
    return pd.DataFrame(rows, columns=list(RELATION_COLUMNS))


def query_graph(graph, question):
    """Answer ``question`` with the LLM, using the whole graph as context.

    Returns:
        The text content of the LLM response.
    """
    nx_graph = graph._graph

    # Node ids are not guaranteed to be strings, so stringify before joining.
    node_list = ', '.join(str(node) for node in nx_graph.nodes())
    relation_lines = "".join(
        f"- {source} --[{data.get('relation', 'related_to')}]--> {target}\n"
        for source, target, data in nx_graph.edges(data=True)
    )
    context = f"Nodes: {node_list}\nRelationships:\n{relation_lines}"

    prompt = f"""Based on the following knowledge graph, answer the question.

Knowledge Graph:
{context}

Question: {question}

Please provide a clear and concise answer based on the information in the graph."""

    response = llm.invoke(prompt)
    return response.content


def generate_summary(text):
    """Generate a one-sentence summary of the text via the LLM."""
    prompt = f"""Provide a single sentence summary of the following text. Be concise and capture the main point.

Text: {text}

Summary:"""

    response = llm.invoke(prompt)
    return response.content.strip()


def process_text(text, question):
    """Run the full pipeline for the Gradio UI.

    Returns:
        ``(answer, graph_image_path, relations_table, summary)``. For invalid
        input or an empty graph, the first element carries the message and the
        rest are ``None``/``""``. On an unexpected error, the message fills both
        text outputs and an empty relations table is returned.
    """
    try:
        if not text or not text.strip():
            return "Please enter some text to analyze.", None, None, ""
        if not question or not question.strip():
            return "Please enter a question.", None, None, ""

        logger.info("Creating graph...")
        graph, graph_documents = create_graph(text)

        if graph._graph.number_of_nodes() == 0:
            return "No entities found in the text. Try with more detailed content.", None, None, ""

        logger.info("Querying graph...")
        answer = query_graph(graph, question)

        logger.info("Visualizing graph...")
        graph_viz_path = visualize_graph(graph)

        logger.info("Creating relations table...")
        relations_table = create_relations_table(graph_documents)

        logger.info("Generating summary...")
        summary = generate_summary(text)

        return answer, graph_viz_path, relations_table, summary

    except Exception as e:
        # Top-level UI boundary: log the full traceback and show the error in the UI.
        logger.exception("Error while processing text")
        error_msg = f"An error occurred: {str(e)}"
        return error_msg, None, pd.DataFrame(columns=list(RELATION_COLUMNS)), error_msg


# Example text
example_text = (
    "The Apollo 11 mission, launched by NASA in July 1969, was the first manned mission "
    "to land on the Moon. Commanded by Neil Armstrong and piloted by Buzz Aldrin and "
    "Michael Collins, it was the culmination of the Space Race between the United States "
    "and the Soviet Union. On July 20, 1969, Armstrong and Aldrin became the first humans "
    "to set foot on the lunar surface, while Collins orbited above in the command module."
)

# Create Gradio interface
with gr.Blocks(title="Knowledge Graph RAG") as iface:
    gr.Markdown("# 🔍 Knowledge Graph RAG Application")
    gr.Markdown("Extract entities and relationships from text, then query the knowledge graph.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                lines=8,
                placeholder="Enter your text here...",
                value=example_text,
            )
            question = gr.Textbox(
                label="Question",
                lines=2,
                placeholder="Ask a question about the text...",
                value="Who were the astronauts on Apollo 11?",
            )
            run_btn = gr.Button("🚀 Analyze", variant="primary", size="lg")

        with gr.Column():
            answer = gr.Textbox(
                label="Answer",
                lines=4,
                placeholder="The answer will appear here.",
            )
            graph_viz = gr.Image(
                label="Graph Visualization",
                type="filepath",
            )

    with gr.Row():
        relations_table = gr.Dataframe(
            label="Extracted Relationships",
            headers=list(RELATION_COLUMNS),
            interactive=False,
        )

    with gr.Row():
        summary = gr.Textbox(
            label="Summary",
            lines=2,
            placeholder="Summary will appear here.",
        )

    gr.Markdown("""
    ### 📖 How to Use:
    1. **Input Text:** Enter or modify the text you want to analyze
    2. **Question:** Ask a question about the content
    3. **Analyze:** Click the button to process
    4. **Results:** View the answer, graph visualization, relationships, and summary

    💡 **Tip:** The example text is pre-loaded. Try it out!
    """)

    run_btn.click(
        fn=process_text,
        inputs=[input_text, question],
        outputs=[answer, graph_viz, relations_table, summary],
    )

    gr.Markdown("""
    ---
    **Created with LangChain + Groq + NetworkX** | [LinkedIn](https://www.linkedin.com/in/girish-wangikar/) | [Portfolio](https://girishwangikar.github.io/Girish_Wangikar_Portfolio.github.io/)
    """)


if __name__ == "__main__":
    # Make the INFO progress messages from process_text visible when run as a script.
    logging.basicConfig(level=logging.INFO)
    iface.launch(server_name="0.0.0.0", server_port=7860)