| """ | |
| Multi-agent survey analysis system using LangGraph. | |
| Simplified implementation with integrated research brief and response synthesizer. | |
| """ | |
| import os | |
| from typing import TypedDict, Literal, Annotated, List, Dict, Any, Optional | |
| from pathlib import Path | |
| import operator | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage | |
| from pydantic import BaseModel, Field, ConfigDict | |
| # Local RAG system imports | |
| from questionnaire_rag import QuestionnaireRAG | |
| from toplines_rag import ToplinesRAG | |
| from crosstab_rag import CrosstabsRAG, CrosstabSummarizer | |
| from relevance_checker import ConversationRelevanceChecker | |
| from langchain.schema import Document | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass | |
def _load_prompt_file(filename: str) -> str:
    """Load a prompt file from the prompts directory"""
    prompt_dir = Path(__file__).parent / "prompts"
    prompt_path = prompt_dir / filename
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
    return prompt_path.read_text(encoding="utf-8")

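# Expected layout (inferred from the filenames loaded later in this module):
# a `prompts/` directory next to this file containing
# research_brief_prompt.txt, synthesis_prompt_user.txt, and
# synthesis_prompt_system.txt.
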
# ============================================================================
# STATE DEFINITIONS
# ============================================================================

class QuestionInfo(BaseModel):
    """Structured question information"""
    variable_name: str
    year: Optional[int] = None
    month: Optional[str] = None
    poll_date: Optional[str] = None
    question_id: Optional[str] = None

class QueryFilters(BaseModel):
    """Filters for data source queries"""
    model_config = ConfigDict(extra="forbid")
    year: Optional[int] = None
    month: Optional[str] = None
    poll_date: Optional[str] = None
    survey_name: Optional[str] = None
    topic: Optional[str] = None
    question_ids: Optional[List[str]] = None


class DataSource(BaseModel):
    """Single data source pipeline to query"""
    model_config = ConfigDict(extra="forbid")
    source_type: Literal["questionnaire", "toplines", "crosstabs"]
    query_description: str
    filters: QueryFilters = Field(default_factory=QueryFilters)
    result_label: Optional[str] = None


class ResearchStage(BaseModel):
    """Single stage in multi-stage research"""
    model_config = ConfigDict(extra="forbid")
    stage_number: int
    description: str
    data_sources: List[DataSource]
    depends_on_stages: List[int] = Field(default_factory=list)
    use_previous_results_for: Optional[str] = None


class ResearchBrief(BaseModel):
    """Research brief defining how to answer the question"""
    model_config = ConfigDict(extra="forbid")
    action: Literal["answer", "followup", "route_to_sources", "execute_stages"]
    followup_question: Optional[str] = None
    reasoning: str
    data_sources: List[DataSource] = Field(default_factory=list)
    stages: List[ResearchStage] = Field(default_factory=list)

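# Illustrative sketch (hypothetical values, not executed): a two-stage brief
# of the kind the LLM is asked to emit for a question like "How did approval
# break down by party in June 2025?". Stage 1 locates the question; Stage 2
# pulls crosstabs using Stage 1's results.
#
#   _example_brief = ResearchBrief(
#       action="execute_stages",
#       reasoning="Question metadata is needed before querying crosstabs.",
#       stages=[
#           ResearchStage(
#               stage_number=1,
#               description="Locate the approval question",
#               data_sources=[DataSource(
#                   source_type="questionnaire",
#                   query_description="presidential approval",
#                   filters=QueryFilters(year=2025, month="June"),
#               )],
#           ),
#           ResearchStage(
#               stage_number=2,
#               description="Retrieve crosstabs for the matched question",
#               data_sources=[DataSource(
#                   source_type="crosstabs",
#                   query_description="approval by party identification",
#               )],
#               depends_on_stages=[1],
#               use_previous_results_for="Extract question_info from stage 1",
#           ),
#       ],
#   )
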
class StageResult(BaseModel):
    """Results from executing a single research stage"""
    model_config = ConfigDict(extra="forbid")
    stage_number: int
    status: Literal["success", "partial", "failed"]
    questionnaire_results: Optional[Dict[str, Any]] = None
    toplines_results: Optional[Dict[str, Any]] = None
    crosstabs_results: Optional[Dict[str, Any]] = None
    extracted_context: Optional[Dict[str, Any]] = None
    # Metadata for better reusability tracking
    query_type: Optional[str] = None  # "questionnaire", "toplines", "crosstabs"
    time_period: Optional[Dict[str, Any]] = None  # e.g. {"year": 2025, "month": "June"}
    topics: Optional[List[str]] = None  # Topics queried

class SurveyAnalysisState(TypedDict):
    """State dictionary for the survey analysis agent workflow"""
    messages: Annotated[List, operator.add]
    user_question: str
    research_brief: Optional[ResearchBrief]
    current_stage: int
    stage_results: List[StageResult]
    final_answer: Optional[str]

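# Note on reducers: `Annotated[List, operator.add]` tells LangGraph to merge
# the `messages` lists returned by successive nodes by concatenation, so a
# node can return {"messages": [AIMessage(...)]} without clobbering history.
# All other keys are plain overwrites. Effective merge (illustrative only):
#
#   new_state["messages"] = old_state["messages"] + node_update["messages"]
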
# ============================================================================
# SURVEY ANALYSIS AGENT
# ============================================================================

class SurveyAnalysisAgent:
    """Multi-agent system for analyzing survey data using LangGraph."""

    def __init__(
        self,
        openai_api_key: str,
        pinecone_api_key: str,
        questionnaire_persist_dir: Optional[str] = None,
        verbose: bool = True
    ):
        self.openai_api_key = openai_api_key
        self.pinecone_api_key = pinecone_api_key
        self.verbose = verbose

        # Initialize LLM (pass the key explicitly so the agent does not depend
        # on OPENAI_API_KEY being set in the environment)
        self.llm = ChatOpenAI(
            model=os.getenv("OPENAI_MODEL", "gpt-4o"),
            temperature=0,
            api_key=openai_api_key
        )

        # Initialize RAG systems
        if self.verbose:
            print("Initializing RAG systems...")

        # Default to parent directory if not specified
        if questionnaire_persist_dir is None:
            # Try current directory first, then parent
            current_dir = Path("./questionnaire_vectorstores")
            parent_dir = Path("../questionnaire_vectorstores")
            if current_dir.exists() and (current_dir / "poll_catalog.json").exists():
                questionnaire_persist_dir = str(current_dir)
            elif parent_dir.exists() and (parent_dir / "poll_catalog.json").exists():
                questionnaire_persist_dir = str(parent_dir)
            else:
                questionnaire_persist_dir = "./questionnaire_vectorstores"

        self.questionnaire_rag = QuestionnaireRAG(
            openai_api_key=openai_api_key,
            pinecone_api_key=pinecone_api_key,
            persist_directory=questionnaire_persist_dir,
            verbose=verbose
        )
        self.toplines_rag = ToplinesRAG(verbose=verbose)
        self.crosstab_rag = CrosstabsRAG(
            questionnaire_rag=self.questionnaire_rag,
            verbose=self.verbose
        )

        # Initialize relevance checker
        self.relevance_checker = ConversationRelevanceChecker(
            llm=self.llm,
            verbose=self.verbose
        )

        # Build the graph
        self.graph = self._build_graph()
        if self.verbose:
            print("✓ Survey analysis agent initialized")

    def _build_graph(self) -> StateGraph:
        """Build the LangGraph workflow"""
        workflow = StateGraph(SurveyAnalysisState)

        # Add workflow nodes
        workflow.add_node("generate_research_brief", self._generate_research_brief)
        workflow.add_node("execute_stage", self._execute_stage)
        workflow.add_node("extract_stage_context", self._extract_stage_context)
        workflow.add_node("synthesize_response", self._synthesize_response)

        # Start: always begin with research brief generation
        workflow.add_edge(START, "generate_research_brief")

        # Route after research brief
        workflow.add_conditional_edges(
            "generate_research_brief",
            self._route_after_brief,
            {
                "followup": "synthesize_response",
                "answer": "synthesize_response",
                "execute_stage": "execute_stage"
            }
        )

        # After stage execution, extract context
        workflow.add_edge("execute_stage", "extract_stage_context")

        # Route after context extraction
        workflow.add_conditional_edges(
            "extract_stage_context",
            self._route_after_stage,
            {
                "next_stage": "execute_stage",
                "synthesize": "synthesize_response"
            }
        )

        # End: always end after synthesis
        workflow.add_edge("synthesize_response", END)

        # Compile with memory checkpointing
        memory = MemorySaver()
        return workflow.compile(checkpointer=memory)

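    # Resulting topology, for reference (sketch of the compiled graph):
    #
    #   START -> generate_research_brief
    #       -> synthesize_response                  (action "answer"/"followup")
    #       -> execute_stage -> extract_stage_context
    #            -> execute_stage                   (stages remain)
    #            -> synthesize_response             (all stages done)
    #   synthesize_response -> END
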
    def _get_full_question_context(self, state: SurveyAnalysisState) -> str:
        """Extract current question from conversation history"""
        messages = state.get("messages", [])
        human_messages = [msg.content for msg in messages if isinstance(msg, HumanMessage)]
        if not human_messages:
            return state["user_question"]
        return human_messages[-1]

    def _combine_query_with_context(self, state: SurveyAnalysisState) -> str:
        """Combine user query with follow-up context for semantic search"""
        messages = state.get("messages", [])
        human_messages = [msg.content for msg in messages if isinstance(msg, HumanMessage)]
        if len(human_messages) <= 1:
            return state["user_question"]

        # Combine the last two messages if the latest is a short answer
        latest = human_messages[-1]
        previous = human_messages[-2]

        # Check if latest is a short answer (month, year, etc.)
        is_short_answer = len(latest.split()) <= 2 and any(
            word.lower() in ['january', 'february', 'march', 'april', 'may', 'june',
                             'july', 'august', 'september', 'october', 'november', 'december']
            or word.isdigit()
            for word in latest.split()
        )
        if is_short_answer and previous:
            # Combine: "Biden approval" + "June 2025" -> "Biden approval June 2025"
            return f"{previous} {latest}"
        return latest

    def _get_available_surveys_description(self) -> str:
        """Get formatted description of available surveys"""
        survey_names = self.questionnaire_rag.get_available_survey_names()
        if not survey_names:
            return "No surveys currently loaded."
        lines = ["Available survey names in the system:"]
        for name in survey_names:
            lines.append(f"  - '{name}'")
        return "\n".join(lines)

    def _get_available_months_description(self) -> str:
        """Get formatted description of available months by year"""
        month_order = [
            "January", "February", "March", "April", "May", "June",
            "July", "August", "September", "October", "November", "December"
        ]
        catalog = self.questionnaire_rag.poll_catalog
        years = {}
        for poll_date, info in catalog.items():
            year = info.get("year")
            month = info.get("month")
            survey = info.get("survey_name")
            if year and month and survey == "Vanderbilt_Unity_Poll":
                if year not in years:
                    years[year] = []
                if month not in years[year]:
                    years[year].append(month)
        lines = ["Available polls by year (Vanderbilt Unity Poll):"]
        for year in sorted(years.keys()):
            months_sorted = sorted(
                years[year],
                key=lambda m: month_order.index(m) if m in month_order else 999
            )
            months_str = ", ".join(months_sorted)
            lines.append(f"  {year}: {months_str}")
        return "\n".join(lines)

    def _generate_research_brief(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Generate research brief"""
        if self.verbose:
            print("\n=== GENERATING RESEARCH BRIEF ===")

        question = self._get_full_question_context(state)
        messages = state.get("messages", [])
        human_messages = [msg.content for msg in messages if isinstance(msg, HumanMessage)]

        # Check for previously retrieved questions in conversation history
        previous_stage_results = state.get("stage_results", [])
        previously_retrieved_questions = []
        if previous_stage_results:
            for prev_result in previous_stage_results:
                if prev_result.questionnaire_results:
                    q_info = prev_result.questionnaire_results.get("question_info", [])
                    source_questions = prev_result.questionnaire_results.get("source_questions", [])
                    if q_info or source_questions:
                        previously_retrieved_questions.append({
                            "question_info": q_info,
                            "source_questions": source_questions,
                            "num_questions": len(source_questions) if source_questions else len(q_info)
                        })

        # Check relevance if there's conversation history
        relevance_result = None
        if len(human_messages) > 1 and previous_stage_results:
            relevance_result = self.relevance_checker.check_relevance(
                current_question=question,
                conversation_history=messages[-6:],  # Last 3 turns (user + assistant)
                previous_stage_results=previous_stage_results
            )

        # Build conversation context
        conversation_context = ""
        if len(human_messages) > 1:
            conversation_context = "\n\nCONVERSATION HISTORY (for context):\n"
            recent_messages = human_messages[-4:-1] if len(human_messages) > 4 else human_messages[:-1]
            for i, msg in enumerate(recent_messages, 1):
                conversation_context += f"  {i}. {msg}\n"
            conversation_context += f"\nCurrent question: {question}\n"

        # Add relevance analysis if available
        if relevance_result:
            conversation_context += f"\n{'='*80}\n"
            conversation_context += "🔍 RELEVANCE ANALYSIS:\n"
            conversation_context += f"{'='*80}\n"
            conversation_context += f"Related to previous: {relevance_result['is_related']}\n"
            conversation_context += f"Relation type: {relevance_result['relation_type']}\n"
            conversation_context += f"Time period changed: {relevance_result['time_period_changed']}\n"
            conversation_context += f"Reasoning: {relevance_result['reasoning']}\n\n"
            if relevance_result['is_related']:
                reusable = relevance_result['reusable_data']
                conversation_context += "REUSABLE DATA:\n"
                if reusable.get('questions') and previously_retrieved_questions:
                    num_questions = previously_retrieved_questions[0]['num_questions']
                    conversation_context += f"  ✅ Questions: {num_questions} question(s) available from previous turn\n"
                    conversation_context += "  → Use route_to_sources (single-stage) for TOPLINES or CROSSTABS\n"
                    conversation_context += "  → DO NOT query QUESTIONNAIRE again\n"
                if relevance_result['relation_type'] == 'trend_analysis':
                    conversation_context += "\n⚠️ CRITICAL: This is an ANALYSIS-ONLY query.\n"
                    conversation_context += "  - User wants analysis of ALREADY RETRIEVED data\n"
                    conversation_context += "  - Use action='answer' to synthesize from conversation history\n"
                    conversation_context += "  - DO NOT retrieve any new data\n"
            if relevance_result['time_period_changed']:
                conversation_context += "\n⚠️ Time period changed - treat as NEW QUERY\n"
                conversation_context += "  - Previous questions may not exist in new time period\n"
                conversation_context += "  - Must query QUESTIONNAIRE for new time period\n"
            conversation_context += "\n"

        # Add information about previously retrieved questions
        if previously_retrieved_questions:
            conversation_context += "\n🚨 CRITICAL: Questions were ALREADY RETRIEVED in previous conversation turns:\n"
            for i, prev_q in enumerate(previously_retrieved_questions, 1):
                num_q = prev_q["num_questions"]
                q_info = prev_q["question_info"]
                if q_info:
                    # Show sample of question info
                    sample_vars = [q.get("variable_name", "unknown") for q in q_info[:3]]
                    sample_vars_str = ", ".join(sample_vars)
                    if len(q_info) > 3:
                        sample_vars_str += f" ... and {len(q_info) - 3} more"
                    conversation_context += f"  Previous turn {i}: {num_q} question(s) retrieved (variables: {sample_vars_str})\n"
                else:
                    conversation_context += f"  Previous turn {i}: {num_q} question(s) retrieved\n"
            conversation_context += "\n⚠️ IMPORTANT: If the current question references 'these questions' or asks about responses to previously shown questions:\n"
            conversation_context += "  - DO NOT query QUESTIONNAIRE pipeline again - questions are already available\n"
            conversation_context += "  - Go DIRECTLY to TOPLINES or CROSSTABS using question_info from previous results\n"
            conversation_context += "  - Use execute_stages with Stage 1 extracting question_info, Stage 2 querying TOPLINES/CROSSTABS\n"

        # Check for short answers
        latest = human_messages[-1]
        is_short_answer = len(latest.split()) <= 2 and any(
            word.lower() in ['january', 'february', 'march', 'april', 'may', 'june',
                             'july', 'august', 'september', 'october', 'november', 'december']
            or word.isdigit()
            for word in latest.split()
        )
        if is_short_answer and len(human_messages) > 1:
            original_question = human_messages[-2]
            conversation_context += f"\n🚨 IMPORTANT: The current question '{latest}' is a SHORT ANSWER.\n"
            conversation_context += f"Original question was: '{original_question}'\n"
            conversation_context += f"Combine '{latest}' with the original intent from '{original_question}'.\n"

        # Load research brief prompt
        system_prompt_template = _load_prompt_file("research_brief_prompt.txt")
        system_prompt = system_prompt_template.format(
            available_pipelines="Questionnaire: Survey questions\nToplines: Response frequencies\nCrosstabs: Demographic breakdowns",
            available_surveys=self._get_available_surveys_description(),
            available_months=self._get_available_months_description()
        )

        brief_generator = self.llm.with_structured_output(ResearchBrief)
        user_prompt = f"User question: {question}\n\nGenerate a research brief."
        if conversation_context:
            user_prompt = conversation_context + "\n\n" + user_prompt

        brief = brief_generator.invoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt)
        ])

        if self.verbose:
            print(f"Action: {brief.action}")
            print(f"Reasoning: {brief.reasoning}")

        # Clear stage_results if this is a new topic, otherwise preserve for follow-ups
        existing_stage_results = state.get("stage_results", [])
        # If relevance checker determined this is a new/unrelated topic, clear previous results
        if relevance_result and not relevance_result.get('is_related', True):
            if self.verbose:
                print("🔄 Clearing previous stage results (new topic detected)")
            existing_stage_results = []

        return {
            "research_brief": brief,
            "current_stage": 0,
            "stage_results": existing_stage_results,  # Clear for new topics, preserve for follow-ups
            "messages": [AIMessage(content=f"[Research plan: {brief.action}]")]
        }

    def _route_after_brief(self, state: SurveyAnalysisState) -> str:
        """Route after research brief"""
        brief = state["research_brief"]
        if brief.action in ("followup", "answer"):
            return brief.action
        return "execute_stage"

    def _execute_stage(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Execute one stage of research"""
        brief = state["research_brief"]
        current_stage_idx = state.get("current_stage", 0)
        previous_stage_results = state.get("stage_results", [])

        # Determine execution mode
        if brief.action == "route_to_sources":
            stage_data_sources = brief.data_sources
            stage_desc = "Single-stage retrieval"

            # CRITICAL: Enrich with previous results to avoid redundant Questionnaire queries
            if previous_stage_results:
                # Try to extract question_info from previous results
                question_info_available = False
                for prev_result in previous_stage_results:
                    if prev_result.questionnaire_results:
                        q_info = prev_result.questionnaire_results.get("question_info", [])
                        if q_info:
                            question_info_available = True
                            if self.verbose:
                                print(f"  ✅ Found {len(q_info)} questions from previous turn - will reuse")
                            break

                # If question_info is available and we're querying toplines/crosstabs, enrich
                if question_info_available:
                    for ds in stage_data_sources:
                        if ds.source_type in ("toplines", "crosstabs"):
                            if self.verbose:
                                print(f"  🔄 Enriching {ds.source_type} data source with previous question_info")
                            stage_data_sources = self._enrich_data_sources_with_context(
                                stage_data_sources, previous_stage_results,
                                "Extract question_info from previous conversation results"
                            )
                            break
                else:
                    if self.verbose:
                        print("  ⚠️ No question_info available from previous turns")

        elif brief.action == "execute_stages":
            stage = brief.stages[current_stage_idx]
            stage_data_sources = stage.data_sources
            stage_desc = stage.description

            # Check if Stage 1 is just extraction (no data sources to execute)
            if (current_stage_idx == 0 and
                    stage.use_previous_results_for and
                    "extract" in stage.use_previous_results_for.lower() and
                    not stage.data_sources and
                    previous_stage_results):
                # This is an extraction-only stage - skip execution, just extract context
                if self.verbose:
                    print(f"\n=== EXECUTING STAGE {stage.stage_number} (EXTRACTION ONLY) ===")
                    print(f"Description: {stage.description}")
                    print("  Extracting question_info from previous conversation results")
                # Return immediately (context extraction happens in _extract_stage_context)
                return {
                    "stage_results": previous_stage_results,  # Keep previous results
                    "current_stage": current_stage_idx + 1
                }

            # Enrich with context from previous stages
            if stage.use_previous_results_for and previous_stage_results:
                stage_data_sources = self._enrich_data_sources_with_context(
                    stage_data_sources, previous_stage_results, stage.use_previous_results_for
                )
        else:
            return {}

        if self.verbose:
            print("\n=== EXECUTING STAGE ===")
            print(f"Description: {stage_desc}")

        # Initialize stage result
        stage_result = StageResult(
            stage_number=current_stage_idx + 1,
            status="success"
        )

        # Get combined query for semantic search fallback
        combined_query = self._combine_query_with_context(state)

        # Execute each data source
        for ds in stage_data_sources:
            filters_dict = {k: v for k, v in ds.filters.model_dump().items() if v is not None}

            if ds.source_type == "questionnaire":
                if self.verbose:
                    print(f"\n📊 [Questionnaire] {ds.query_description}")
                result = self.questionnaire_rag.retrieve_raw_data(
                    question=ds.query_description,
                    filters=filters_dict if filters_dict else None
                )
                stage_result.questionnaire_results = result
                # Populate metadata
                stage_result.query_type = "questionnaire"
                if filters_dict:
                    stage_result.time_period = {
                        "year": filters_dict.get("year"),
                        "month": filters_dict.get("month")
                    }
                    if filters_dict.get("topic"):
                        stage_result.topics = [filters_dict.get("topic")]

            elif ds.source_type == "toplines":
                if self.verbose:
                    print(f"\n📊 [Toplines] {ds.query_description}")

                # Extract question_info and source_questions.
                # PRIORITY ORDER:
                #   1. First check enriched data on the data source (from the current execution's Stage 1)
                #   2. Then fall back to previous stage results (from conversation history)
                question_info = None
                source_questions = None

                # Check enriched data first (most recent Stage 1 execution)
                if hasattr(ds, '_question_info'):
                    question_info = ds._question_info  # type: ignore
                    if self.verbose and question_info:
                        print(f"  📋 Using {len(question_info)} questions from enriched data source")

                # Fall back to previous questionnaire results if no enriched data
                if not question_info and previous_stage_results:
                    for prev_result in previous_stage_results:
                        if prev_result.questionnaire_results:
                            question_info = prev_result.questionnaire_results.get("question_info", [])
                            source_questions = prev_result.questionnaire_results.get("source_questions", [])
                            if question_info and self.verbose:
                                print(f"  📋 Using {len(question_info)} questions from conversation history")
                            break

                result = self.toplines_rag.retrieve_raw_data(
                    query=combined_query if not question_info else ds.query_description,
                    question_info=question_info,
                    source_questions=source_questions,
                    filters=filters_dict if not question_info else None,
                    top_k=10
                )
                stage_result.toplines_results = result
                # Populate metadata
                stage_result.query_type = "toplines"
                if filters_dict:
                    stage_result.time_period = {
                        "year": filters_dict.get("year"),
                        "month": filters_dict.get("month")
                    }

            elif ds.source_type == "crosstabs":
                if self.verbose:
                    print(f"\n📊 [Crosstabs] {ds.query_description}")

                # Extract question_info and source_questions (same priority order as toplines)
                question_info = None
                source_questions = None

                # Check enriched data first (most recent Stage 1 execution)
                if hasattr(ds, '_question_info'):
                    question_info = ds._question_info  # type: ignore
                    if self.verbose and question_info:
                        print(f"  📋 Using {len(question_info)} questions from enriched data source")

                # Fall back to previous questionnaire results if no enriched data
                if not question_info and previous_stage_results:
                    for prev_result in previous_stage_results:
                        if prev_result.questionnaire_results:
                            question_info = prev_result.questionnaire_results.get("question_info", [])
                            source_questions = prev_result.questionnaire_results.get("source_questions", [])
                            if question_info:
                                if self.verbose:
                                    print(f"  📋 Using {len(question_info)} questions from conversation history")
                                break

                if self.verbose:
                    if question_info:
                        print("  ✅ Will use question_info (skipping Questionnaire)")
                    else:
                        print("  ⚠️ No question_info available, will query Questionnaire")

                result = self.crosstab_rag.retrieve_raw_data(
                    user_query=combined_query if not question_info else ds.query_description,
                    question_info=question_info,
                    source_questions=source_questions,
                    filters=filters_dict  # Always pass filters for time period filtering
                )
                stage_result.crosstabs_results = result
                # Populate metadata
                stage_result.query_type = "crosstabs"
                if filters_dict:
                    stage_result.time_period = {
                        "year": filters_dict.get("year"),
                        "month": filters_dict.get("month")
                    }

        # Check if stage has data
        has_data = (
            (stage_result.questionnaire_results and stage_result.questionnaire_results.get("source_questions")) or
            (stage_result.toplines_results and stage_result.toplines_results.get("retrieved_docs")) or
            (stage_result.crosstabs_results and stage_result.crosstabs_results.get("matched_variables"))
        )
        updated_stage_results = previous_stage_results
        if has_data:
            updated_stage_results = previous_stage_results + [stage_result]

        return {
            "stage_results": updated_stage_results,
            "current_stage": current_stage_idx + 1
        }

    def _enrich_data_sources_with_context(
        self,
        data_sources: List[DataSource],
        previous_results: List[StageResult],
        use_instruction: str
    ) -> List[DataSource]:
        """Enrich data sources with context from previous stages"""
        if self.verbose:
            print(f"  Enriching with context: {use_instruction}")

        # Extract question_info from previous results.
        # IMPORTANT: Prioritize most recent questionnaire results (from current execution)
        # over older conversation results to avoid using stale data when the time period changes.
        question_info_list = []

        # First, check if the most recent stage has questionnaire results.
        # If so, use ONLY those (this handles time period changes correctly).
        has_recent_questionnaire = False
        if previous_results:
            most_recent = previous_results[-1]
            if most_recent.questionnaire_results:
                q_info = most_recent.questionnaire_results.get("question_info", [])
                if q_info:
                    question_info_list.extend(q_info)
                    has_recent_questionnaire = True
                    if self.verbose:
                        print(f"  📋 Using {len(q_info)} questions from most recent stage (Stage {most_recent.stage_number})")

        # If no recent questionnaire results, fall back to collecting from all previous results
        if not has_recent_questionnaire:
            for prev_result in previous_results:
                if prev_result.questionnaire_results:
                    q_info = prev_result.questionnaire_results.get("question_info", [])
                    question_info_list.extend(q_info)
                if prev_result.toplines_results:
                    # Extract from toplines metadata.
                    # Handle both Document objects and dicts (from checkpoint deserialization).
                    docs = prev_result.toplines_results.get("retrieved_docs", [])
                    for doc in docs:
                        if hasattr(doc, 'metadata'):
                            metadata = doc.metadata or {}
                        elif isinstance(doc, dict):
                            metadata = doc.get("metadata", {})
                        else:
                            metadata = {}
                        var_name = metadata.get("variable") or metadata.get("variable_name")
                        if var_name:
                            question_info_list.append({
                                "variable_name": var_name,
                                "year": metadata.get("year"),
                                "month": metadata.get("month", ""),
                                "poll_date": metadata.get("poll_date", ""),
                                "question_id": metadata.get("question_id")
                            })
            if self.verbose and question_info_list:
                print(f"  📋 Using {len(question_info_list)} questions from conversation history")

        if not question_info_list:
            return data_sources

        # Store question_info in a way that can be accessed during execution.
        # It is passed directly to the RAG methods, not through filters.
        enriched_sources = []
        for ds in data_sources:
            # Create a copy and store question_info as an attribute (not in filters)
            enriched_ds = ds.model_copy()
            # Stored separately - will be extracted in _execute_stage
            enriched_ds._question_info = question_info_list  # type: ignore
            enriched_sources.append(enriched_ds)
        return enriched_sources

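    # Handoff sketch (illustrative; mirrors the private-attribute convention
    # used above and read back in _execute_stage via hasattr):
    #
    #   ds = DataSource(source_type="toplines", query_description="approval")
    #   ds2 = ds.model_copy()
    #   ds2._question_info = [{"variable_name": "Q5", "year": 2025}]  # hypothetical
    #   assert hasattr(ds2, "_question_info")
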
    def _extract_stage_context(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Extract context from completed stage"""
        stage_results = state.get("stage_results", [])
        if not stage_results:
            return {}

        current_result = stage_results[-1]
        extracted_context = {}

        # Extract question_info
        question_info_list = []
        if current_result.questionnaire_results:
            q_info = current_result.questionnaire_results.get("question_info", [])
            question_info_list.extend(q_info)
        if current_result.toplines_results:
            docs = current_result.toplines_results.get("retrieved_docs", [])
            for doc in docs:
                # Handle both Document objects and dicts (from checkpoint deserialization)
                if hasattr(doc, 'metadata'):
                    metadata = doc.metadata or {}
                elif isinstance(doc, dict):
                    metadata = doc.get("metadata", {})
                else:
                    metadata = {}
                var_name = metadata.get("variable") or metadata.get("variable_name")
                if var_name:
                    question_info_list.append({
                        "variable_name": var_name,
                        "year": metadata.get("year"),
                        "month": metadata.get("month", ""),
                        "poll_date": metadata.get("poll_date", ""),
                        "question_id": metadata.get("question_id")
                    })
        if question_info_list:
            extracted_context["question_info"] = question_info_list

        # Mutating current_result is sufficient: stage_results holds object
        # references, so later nodes see the update without a state write.
        current_result.extracted_context = extracted_context
        return {}

    def _route_after_stage(self, state: SurveyAnalysisState) -> str:
        """Route after stage execution"""
        brief = state["research_brief"]
        current_stage_idx = state.get("current_stage", 0)
        if brief.action == "route_to_sources":
            return "synthesize"
        total_stages = len(brief.stages)
        if current_stage_idx < total_stages:
            return "next_stage"
        return "synthesize"

    def _summarize_crosstab_data(
        self,
        crosstab_docs: List[Document],
        user_query: str,
        question_text: str,
        var_name: str
    ) -> str:
        """
        Summarize crosstab data by combining multi-part crosstabs and extracting
        relevant demographics. This reduces token usage significantly compared
        to including all raw chunks.
        """
        if not crosstab_docs:
            return "No crosstab data available."

        # Sort chunks by chunk_index to maintain correct order.
        # Handle both Document objects and dicts (from checkpoint deserialization).
        def get_chunk_index(d):
            if hasattr(d, 'metadata') and d.metadata:
                return d.metadata.get("chunk_index", 999)
            elif isinstance(d, dict):
                return d.get("metadata", {}).get("chunk_index", 999)
            return 999

        sorted_docs = sorted(crosstab_docs, key=get_chunk_index)

        try:
            # Initialize summarizer
            summarizer = CrosstabSummarizer(
                llm_model=os.getenv("OPENAI_MODEL", "gpt-4o"),
                openai_api_key=self.openai_api_key
            )
            # Use summarizer to combine parts and extract relevant demographic breakdown
            result = summarizer.summarize(
                user_query=user_query,
                retrieved_docs=sorted_docs,
                question_text=question_text,
                top_n_sources=10
            )
            summary = result.get("answer", "")
            # Truncate summary if too long
            if len(summary) > 2000:
                summary = summary[:1500] + "\n... (summary truncated for length)"
            return summary
        except Exception as e:
            # Fallback: return truncated raw content if summarization fails
            if self.verbose:
                print(f"Error summarizing crosstab for {var_name}: {e}")
            # Use only first 3 chunks to limit length
            combined_parts = []
            for doc in sorted_docs[:3]:
                if hasattr(doc, 'page_content'):
                    combined_parts.append(doc.page_content)
                elif isinstance(doc, dict):
                    combined_parts.append(doc.get('page_content', ''))
            combined = "\n\n".join(combined_parts)
            if len(combined) > 1500:
                combined = combined[:1500] + "... (truncated)"
            return f"Crosstab data for {var_name}:\n{combined}"

    def _synthesize_response(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Synthesize final answer"""
        if self.verbose:
            print("\n=== SYNTHESIZING RESPONSE ===")

        brief = state["research_brief"]
        full_question = self._get_full_question_context(state)

        # Handle follow-up
        if brief.action == "followup":
            return {
                "final_answer": brief.followup_question,
                "messages": [AIMessage(content=brief.followup_question)]
            }

        # Handle direct answer
        if brief.action == "answer":
            answer = self.llm.invoke([
                SystemMessage(content="Answer the user's question directly."),
                HumanMessage(content=full_question)
            ]).content
            return {
                "final_answer": answer,
                "messages": [AIMessage(content=answer)]
            }

        # Get stage results
        stage_results = state.get("stage_results", [])
        if not stage_results:
            no_data_msg = "I was unable to retrieve any data to answer your question."
            return {
                "final_answer": no_data_msg,
                "messages": [AIMessage(content=no_data_msg)]
            }

        # Build context from stage results
        context_parts = []
        for i, stage_result in enumerate(stage_results, 1):
            # Questionnaire data
            if stage_result.questionnaire_results:
                q_res = stage_result.questionnaire_results
                source_questions = q_res.get("source_questions", [])
                context_parts.append(f"\n=== STAGE {i} (QUESTIONNAIRE DATA) ===")
                if source_questions:
                    context_parts.append(f"Retrieved {len(source_questions)} question(s):\n")
                    for j, q in enumerate(source_questions, 1):
                        context_parts.append(f"Question {j}: {q.get('question_text', 'N/A')}")
                        context_parts.append(f"Variable: {q.get('variable_name', 'N/A')}")
                        context_parts.append(f"Poll: {q.get('poll_date', 'N/A')}")
                        context_parts.append("")

            # Toplines data
            if stage_result.toplines_results:
                t_res = stage_result.toplines_results
                retrieved_docs = t_res.get("retrieved_docs", [])
                context_parts.append(f"\n=== STAGE {i} (TOPLINES DATA) ===")
                if retrieved_docs:
                    context_parts.append(f"Retrieved {len(retrieved_docs)} topline document(s):\n")
                    for j, doc in enumerate(retrieved_docs, 1):
                        # Handle both Document objects and dicts (from checkpoint deserialization)
                        if hasattr(doc, 'metadata'):
                            metadata = doc.metadata or {}
                            content = doc.page_content or ""
                        elif isinstance(doc, dict):
                            metadata = doc.get("metadata", {})
                            content = doc.get("page_content", "")
                        else:
                            metadata = {}
                            content = ""
                        context_parts.append(f"--- Topline Document {j} ---")
                        context_parts.append(f"Survey: {metadata.get('survey_name', 'Vanderbilt Unity Poll')} ({metadata.get('month', '')} {metadata.get('year', '')})")
                        context_parts.append(f"Variable: {metadata.get('variable_name', 'N/A')}")
                        context_parts.append(f"Response: {metadata.get('response_label', 'N/A')}")
                        context_parts.append(f"Percentage: {metadata.get('pct', 'N/A')}%")
                        if content:
                            context_parts.append(f"\nContent:\n{content}")
                        context_parts.append("")

            # Crosstabs data
            if stage_result.crosstabs_results:
                c_res = stage_result.crosstabs_results
                if "error" not in c_res:
                    crosstab_docs_by_var = c_res.get("crosstab_docs_by_variable", {})
                    context_parts.append(f"\n=== STAGE {i} (CROSSTABS DATA) ===")
                    for var_name, var_data in crosstab_docs_by_var.items():
                        crosstab_docs = var_data.get("crosstab_docs", [])
                        question_text = var_data.get("question_text", "")
                        matched_question = var_data.get("matched_question", {})

                        # Extract metadata from matched_question
                        year = matched_question.get("year", "Unknown")
                        month = matched_question.get("month", "Unknown")
                        poll_date = matched_question.get("poll_date", f"{year}-{month}")

                        # Extract sample size from the first crosstab document if available.
                        # Note: sample size is typically embedded in the crosstab content.
                        sample_size = "See crosstab data below"
                        if crosstab_docs:
                            first_doc = crosstab_docs[0]
                            if hasattr(first_doc, 'metadata'):
                                sample_size = first_doc.metadata.get("total_n", sample_size)
                            elif isinstance(first_doc, dict):
                                sample_size = first_doc.get("metadata", {}).get("total_n", sample_size)

                        context_parts.append(f"\n{'='*80}")
                        context_parts.append(f"Variable: {var_name}")
                        context_parts.append(f"Question: {question_text}")
                        context_parts.append(f"Poll Date: {poll_date}")
                        context_parts.append(f"Year: {year}, Month: {month}")
                        context_parts.append(f"Sample Size (N): {sample_size}")
                        context_parts.append(f"{'='*80}\n")

                        # Summarize crosstab data (reduces token usage significantly)
                        summary = self._summarize_crosstab_data(
                            crosstab_docs=crosstab_docs,
                            user_query=full_question,
                            question_text=question_text,
                            var_name=var_name
                        )
                        context_parts.append(summary)
                        context_parts.append("")  # Extra blank line between variables

        # Synthesize
        synthesis_prompt_template = _load_prompt_file("synthesis_prompt_user.txt")
        synthesis_prompt = synthesis_prompt_template.format(
            stage_count='multiple stages' if len(stage_results) > 1 else 'the research',
            full_question=full_question,
            reasoning=brief.reasoning,
            context_parts="\n".join(context_parts),
            unavailable_note=""
        )
        synthesis_system_prompt = _load_prompt_file("synthesis_prompt_system.txt")

        final_answer = self.llm.invoke([
            SystemMessage(content=synthesis_system_prompt),
            HumanMessage(content=synthesis_prompt)
        ]).content

        if self.verbose:
            print("Synthesis complete")

        return {
            "final_answer": final_answer,
            "messages": [AIMessage(content=final_answer)]
        }

    def query(self, question: str, thread_id: str = "default") -> str:
        """Query the survey analysis system"""
        # Try to load previous state from the checkpoint to preserve stage_results across turns
        previous_stage_results = []
        try:
            checkpoint = self.graph.get_state(config={"configurable": {"thread_id": thread_id}})
            if checkpoint and checkpoint.values:
                previous_stage_results = checkpoint.values.get("stage_results", [])
        except Exception:
            previous_stage_results = []

        initial_state = {
            "messages": [HumanMessage(content=question)],
            "user_question": question,
            "research_brief": None,
            "current_stage": 0,
            "stage_results": previous_stage_results,  # Preserve previous results for multi-turn
            "final_answer": None
        }
        config = {
            "configurable": {"thread_id": thread_id},
            "recursion_limit": 50
        }
        if self.verbose:
            print(f"\n🧵 Thread ID: {thread_id}")

        final_state = self.graph.invoke(initial_state, config)
        return final_state["final_answer"]

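# Minimal usage sketch (assumes valid OpenAI and Pinecone keys; the question
# and thread_id values here are hypothetical):
#
#   agent = SurveyAnalysisAgent(
#       openai_api_key=os.environ["OPENAI_API_KEY"],
#       pinecone_api_key=os.environ["PINECONE_API_KEY"],
#       verbose=False,
#   )
#   print(agent.query("What did the June 2025 poll ask about unity?",
#                     thread_id="demo"))
#   # A follow-up on the same thread_id reuses checkpointed stage_results:
#   print(agent.query("Break those down by party.", thread_id="demo"))
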
# ============================================================================
# CLI INTERFACE
# ============================================================================

def main():
    """Interactive CLI"""
    import sys

    openai_api_key = os.getenv("OPENAI_API_KEY")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    if not openai_api_key or not pinecone_api_key:
        print("Error: Missing API keys")
        print("Set OPENAI_API_KEY and PINECONE_API_KEY environment variables")
        sys.exit(1)

    print("Initializing survey analysis agent...")
    agent = SurveyAnalysisAgent(
        openai_api_key=openai_api_key,
        pinecone_api_key=pinecone_api_key,
        verbose=True
    )

    print("\n" + "=" * 80)
    print("SURVEY ANALYSIS AGENT")
    print("=" * 80)
    print("\nType 'quit' to exit\n")

    thread_id = "cli_session"
    while True:
        try:
            question = input("\nYour question: ").strip()
            if not question or question.lower() in ['quit', 'exit', 'q']:
                print("\nGoodbye!")
                break
            print("\n" + "-" * 80)
            answer = agent.query(question, thread_id=thread_id)
            print("\n" + "=" * 80)
            print("ANSWER:")
            print("=" * 80)
            print(answer)
            print("=" * 80)
        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except Exception as e:
            print(f"\nError: {e}")
            if os.getenv("DEBUG"):
                raise


if __name__ == "__main__":
    main()