# survey-analytics / survey_agent.py
"""
Multi-agent survey analysis system using LangGraph.
Simplified implementation with integrated research brief and response synthesizer.
"""
import os
from typing import TypedDict, Literal, Annotated, List, Dict, Any, Optional
from pathlib import Path
import operator
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from pydantic import BaseModel, Field, ConfigDict
# Local RAG system imports
from questionnaire_rag import QuestionnaireRAG
from toplines_rag import ToplinesRAG
from crosstab_rag import CrosstabsRAG, CrosstabSummarizer
from relevance_checker import ConversationRelevanceChecker
from langchain_core.documents import Document
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

def _load_prompt_file(filename: str) -> str:
    """Load a prompt file from the prompts directory"""
    prompt_dir = Path(__file__).parent / "prompts"
    prompt_path = prompt_dir / filename
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt file not found: {prompt_path}")
    return prompt_path.read_text(encoding="utf-8")
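
# Prompt files this module expects to find under ./prompts/:
#   research_brief_prompt.txt, synthesis_prompt_user.txt, synthesis_prompt_system.txt
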
# ============================================================================
# STATE DEFINITIONS
# ============================================================================
class QuestionInfo(BaseModel):
    """Structured question information"""
    variable_name: str
    year: Optional[int] = None
    month: Optional[str] = None
    poll_date: Optional[str] = None
    question_id: Optional[str] = None

class QueryFilters(BaseModel):
    """Filters for data source queries"""
    model_config = ConfigDict(extra="forbid")
    year: Optional[int] = None
    month: Optional[str] = None
    poll_date: Optional[str] = None
    survey_name: Optional[str] = None
    topic: Optional[str] = None
    question_ids: Optional[List[str]] = None

class DataSource(BaseModel):
    """Single data source pipeline to query"""
    model_config = ConfigDict(extra="forbid")
    source_type: Literal["questionnaire", "toplines", "crosstabs"]
    query_description: str
    filters: QueryFilters = Field(default_factory=QueryFilters)
    result_label: Optional[str] = None

class ResearchStage(BaseModel):
    """Single stage in multi-stage research"""
    model_config = ConfigDict(extra="forbid")
    stage_number: int
    description: str
    data_sources: List[DataSource]
    depends_on_stages: List[int] = Field(default_factory=list)
    use_previous_results_for: Optional[str] = None

class ResearchBrief(BaseModel):
    """Research brief defining how to answer the question"""
    model_config = ConfigDict(extra="forbid")
    action: Literal["answer", "followup", "route_to_sources", "execute_stages"]
    followup_question: Optional[str] = None
    reasoning: str
    data_sources: List[DataSource] = Field(default_factory=list)
    stages: List[ResearchStage] = Field(default_factory=list)
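
# A hypothetical two-stage brief, for illustration only (all field values made up):
#   ResearchBrief(
#       action="execute_stages",
#       reasoning="Find the matching questions first, then pull their toplines.",
#       stages=[
#           ResearchStage(stage_number=1, description="Locate questions",
#                         data_sources=[DataSource(source_type="questionnaire",
#                                                  query_description="questions about unity")]),
#           ResearchStage(stage_number=2, description="Retrieve toplines",
#                         data_sources=[DataSource(source_type="toplines",
#                                                  query_description="toplines for stage 1 questions")],
#                         depends_on_stages=[1],
#                         use_previous_results_for="question_info from stage 1"),
#       ],
#   )
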
class StageResult(BaseModel):
    """Results from executing a single research stage"""
    model_config = ConfigDict(extra="forbid")
    stage_number: int
    status: Literal["success", "partial", "failed"]
    questionnaire_results: Optional[Dict[str, Any]] = None
    toplines_results: Optional[Dict[str, Any]] = None
    crosstabs_results: Optional[Dict[str, Any]] = None
    extracted_context: Optional[Dict[str, Any]] = None
    # Metadata for better reusability tracking
    query_type: Optional[str] = None  # "questionnaire", "toplines", "crosstabs"
    time_period: Optional[Dict[str, Any]] = None  # {"year": 2025, "month": "June"}
    topics: Optional[List[str]] = None  # Topics queried

class SurveyAnalysisState(TypedDict):
    """State dictionary for the survey analysis agent workflow"""
    messages: Annotated[List, operator.add]
    user_question: str
    research_brief: Optional[ResearchBrief]
    current_stage: int
    stage_results: List[StageResult]
    final_answer: Optional[str]
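
# Note: Annotated[List, operator.add] marks `messages` as an accumulating channel,
# so LangGraph merges node updates by list concatenation instead of overwriting;
# each node returns only its new messages, not the full history.
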
# ============================================================================
# SURVEY ANALYSIS AGENT
# ============================================================================
class SurveyAnalysisAgent:
    """Multi-agent system for analyzing survey data using LangGraph."""

    def __init__(
        self,
        openai_api_key: str,
        pinecone_api_key: str,
        questionnaire_persist_dir: Optional[str] = None,
        verbose: bool = True
    ):
        self.openai_api_key = openai_api_key
        self.pinecone_api_key = pinecone_api_key
        self.verbose = verbose

        # Initialize LLM (pass the provided key explicitly rather than relying on the env var)
        self.llm = ChatOpenAI(
            model=os.getenv("OPENAI_MODEL", "gpt-4o"),
            temperature=0,
            api_key=openai_api_key
        )

        # Initialize RAG systems
        if self.verbose:
            print("Initializing RAG systems...")

        # Default to parent directory if not specified
        if questionnaire_persist_dir is None:
            # Try current directory first, then parent
            current_dir = Path("./questionnaire_vectorstores")
            parent_dir = Path("../questionnaire_vectorstores")
            if current_dir.exists() and (current_dir / "poll_catalog.json").exists():
                questionnaire_persist_dir = str(current_dir)
            elif parent_dir.exists() and (parent_dir / "poll_catalog.json").exists():
                questionnaire_persist_dir = str(parent_dir)
            else:
                questionnaire_persist_dir = "./questionnaire_vectorstores"

        self.questionnaire_rag = QuestionnaireRAG(
            openai_api_key=openai_api_key,
            pinecone_api_key=pinecone_api_key,
            persist_directory=questionnaire_persist_dir,
            verbose=verbose
        )
        self.toplines_rag = ToplinesRAG(verbose=verbose)
        self.crosstab_rag = CrosstabsRAG(
            questionnaire_rag=self.questionnaire_rag,
            verbose=self.verbose
        )

        # Initialize relevance checker
        self.relevance_checker = ConversationRelevanceChecker(
            llm=self.llm,
            verbose=self.verbose
        )

        # Build the graph
        self.graph = self._build_graph()
        if self.verbose:
            print("✓ Survey analysis agent initialized")

    def _build_graph(self):
        """Build and compile the LangGraph workflow"""
        workflow = StateGraph(SurveyAnalysisState)

        # Add workflow nodes
        workflow.add_node("generate_research_brief", self._generate_research_brief)
        workflow.add_node("execute_stage", self._execute_stage)
        workflow.add_node("extract_stage_context", self._extract_stage_context)
        workflow.add_node("synthesize_response", self._synthesize_response)

        # Start: Always begin with research brief generation
        workflow.add_edge(START, "generate_research_brief")

        # Route after research brief
        workflow.add_conditional_edges(
            "generate_research_brief",
            self._route_after_brief,
            {
                "followup": "synthesize_response",
                "answer": "synthesize_response",
                "execute_stage": "execute_stage"
            }
        )

        # After stage execution, extract context
        workflow.add_edge("execute_stage", "extract_stage_context")

        # Route after context extraction
        workflow.add_conditional_edges(
            "extract_stage_context",
            self._route_after_stage,
            {
                "next_stage": "execute_stage",
                "synthesize": "synthesize_response"
            }
        )

        # End: Always end after synthesis
        workflow.add_edge("synthesize_response", END)

        # Compile with memory checkpointing
        memory = MemorySaver()
        return workflow.compile(checkpointer=memory)
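
    # Workflow topology, for reference:
    #   START -> generate_research_brief
    #     -> "followup" / "answer" -> synthesize_response -> END
    #     -> "execute_stage" -> execute_stage -> extract_stage_context
    #          -> "next_stage"  -> execute_stage (loop)
    #          -> "synthesize"  -> synthesize_response -> END
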
    def _get_full_question_context(self, state: SurveyAnalysisState) -> str:
        """Extract current question from conversation history"""
        messages = state.get("messages", [])
        human_messages = [msg.content for msg in messages if isinstance(msg, HumanMessage)]
        if not human_messages:
            return state["user_question"]
        return human_messages[-1]

    def _combine_query_with_context(self, state: SurveyAnalysisState) -> str:
        """Combine user query with follow-up context for semantic search"""
        messages = state.get("messages", [])
        human_messages = [msg.content for msg in messages if isinstance(msg, HumanMessage)]
        if len(human_messages) <= 1:
            return state["user_question"]

        # Combine last 2 messages if the latest is a short answer
        latest = human_messages[-1]
        previous = human_messages[-2] if len(human_messages) > 1 else ""

        # Check if latest is a short answer (month, year, etc.)
        is_short_answer = len(latest.split()) <= 2 and any(
            word.lower() in ['june', 'february', 'march', 'april', 'may', 'july',
                             'august', 'september', 'october', 'november', 'december', 'january']
            or word.isdigit()
            for word in latest.split()
        )
        if is_short_answer and previous:
            # Combine: "Biden approval" + "June 2025" -> "Biden approval rating in June 2025"
            return f"{previous} {latest}"
        return latest

    def _get_available_surveys_description(self) -> str:
        """Get formatted description of available surveys"""
        survey_names = self.questionnaire_rag.get_available_survey_names()
        if not survey_names:
            return "No surveys currently loaded."
        lines = ["Available survey names in the system:"]
        for name in survey_names:
            lines.append(f" - '{name}'")
        return "\n".join(lines)

    def _get_available_months_description(self) -> str:
        """Get formatted description of available months by year"""
        month_order = [
            "January", "February", "March", "April", "May", "June",
            "July", "August", "September", "October", "November", "December"
        ]
        catalog = self.questionnaire_rag.poll_catalog
        years = {}
        for poll_date, info in catalog.items():
            year = info.get("year")
            month = info.get("month")
            survey = info.get("survey_name")
            if year and month and survey == "Vanderbilt_Unity_Poll":
                if year not in years:
                    years[year] = []
                if month not in years[year]:
                    years[year].append(month)
        lines = ["Available polls by year (Vanderbilt Unity Poll):"]
        for year in sorted(years.keys()):
            months_sorted = sorted(
                years[year],
                key=lambda m: month_order.index(m) if m in month_order else 999
            )
            months_str = ", ".join(months_sorted)
            lines.append(f" {year}: {months_str}")
        return "\n".join(lines)
    def _generate_research_brief(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Generate research brief"""
        if self.verbose:
            print("\n=== GENERATING RESEARCH BRIEF ===")
        question = self._get_full_question_context(state)
        messages = state.get("messages", [])
        human_messages = [msg.content for msg in messages if isinstance(msg, HumanMessage)]

        # Check for previously retrieved questions in conversation history
        previous_stage_results = state.get("stage_results", [])
        previously_retrieved_questions = []
        if previous_stage_results:
            for prev_result in previous_stage_results:
                if prev_result.questionnaire_results:
                    q_info = prev_result.questionnaire_results.get("question_info", [])
                    source_questions = prev_result.questionnaire_results.get("source_questions", [])
                    if q_info or source_questions:
                        previously_retrieved_questions.append({
                            "question_info": q_info,
                            "source_questions": source_questions,
                            "num_questions": len(source_questions) if source_questions else len(q_info)
                        })

        # Check relevance if there's conversation history
        relevance_result = None
        if len(human_messages) > 1 and previous_stage_results:
            relevance_result = self.relevance_checker.check_relevance(
                current_question=question,
                conversation_history=messages[-6:],  # Last 3 turns (user + assistant)
                previous_stage_results=previous_stage_results
            )

        # Build conversation context
        conversation_context = ""
        if len(human_messages) > 1:
            conversation_context = "\n\nCONVERSATION HISTORY (for context):\n"
            recent_messages = human_messages[-4:-1] if len(human_messages) > 4 else human_messages[:-1]
            for i, msg in enumerate(recent_messages, 1):
                conversation_context += f" {i}. {msg}\n"
            conversation_context += f"\nCurrent question: {question}\n"

        # Add relevance analysis if available
        if relevance_result:
            conversation_context += f"\n{'='*80}\n"
            conversation_context += f"🔍 RELEVANCE ANALYSIS:\n"
            conversation_context += f"{'='*80}\n"
            conversation_context += f"Related to previous: {relevance_result['is_related']}\n"
            conversation_context += f"Relation type: {relevance_result['relation_type']}\n"
            conversation_context += f"Time period changed: {relevance_result['time_period_changed']}\n"
            conversation_context += f"Reasoning: {relevance_result['reasoning']}\n\n"
            if relevance_result['is_related']:
                reusable = relevance_result['reusable_data']
                conversation_context += f"REUSABLE DATA:\n"
                if reusable.get('questions') and previously_retrieved_questions:
                    num_questions = previously_retrieved_questions[0]['num_questions'] if previously_retrieved_questions else 0
                    conversation_context += f" ✅ Questions: {num_questions} question(s) available from previous turn\n"
                    conversation_context += f" → Use route_to_sources (single-stage) for TOPLINES or CROSSTABS\n"
                    conversation_context += f" → DO NOT query QUESTIONNAIRE again\n"
                if relevance_result['relation_type'] == 'trend_analysis':
                    conversation_context += f"\n⚠️ CRITICAL: This is an ANALYSIS-ONLY query.\n"
                    conversation_context += f" - User wants analysis of ALREADY RETRIEVED data\n"
                    conversation_context += f" - Use action='answer' to synthesize from conversation history\n"
                    conversation_context += f" - DO NOT retrieve any new data\n"
            if relevance_result['time_period_changed']:
                conversation_context += f"\n⚠️ Time period changed - treat as NEW QUERY\n"
                conversation_context += f" - Previous questions may not exist in new time period\n"
                conversation_context += f" - Must query QUESTIONNAIRE for new time period\n"
            conversation_context += f"\n"

        # Add information about previously retrieved questions
        if previously_retrieved_questions:
            conversation_context += f"\n🚨 CRITICAL: Questions were ALREADY RETRIEVED in previous conversation turns:\n"
            for i, prev_q in enumerate(previously_retrieved_questions, 1):
                num_q = prev_q["num_questions"]
                q_info = prev_q["question_info"]
                if q_info:
                    # Show sample of question info
                    sample_vars = [q.get("variable_name", "unknown") for q in q_info[:3]]
                    sample_vars_str = ", ".join(sample_vars)
                    if len(q_info) > 3:
                        sample_vars_str += f" ... and {len(q_info) - 3} more"
                    conversation_context += f" Previous turn {i}: {num_q} question(s) retrieved (variables: {sample_vars_str})\n"
                else:
                    conversation_context += f" Previous turn {i}: {num_q} question(s) retrieved\n"
            conversation_context += f"\n⚠️ IMPORTANT: If the current question references 'these questions' or asks about responses to previously shown questions:\n"
            conversation_context += f" - DO NOT query QUESTIONNAIRE pipeline again - questions are already available\n"
            conversation_context += f" - Go DIRECTLY to TOPLINES or CROSSTABS using question_info from previous results\n"
            conversation_context += f" - Use execute_stages with Stage 1 extracting question_info, Stage 2 querying TOPLINES/CROSSTABS\n"

        # Check for short answers
        latest = human_messages[-1]
        is_short_answer = len(latest.split()) <= 2 and any(
            word.lower() in ['june', 'february', 'march', 'april', 'may', 'july',
                             'august', 'september', 'october', 'november', 'december', 'january']
            or word.isdigit()
            for word in latest.split()
        )
        if is_short_answer and len(human_messages) > 1:
            original_question = human_messages[-2]
            conversation_context += f"\n🚨 IMPORTANT: The current question '{latest}' is a SHORT ANSWER.\n"
            conversation_context += f"Original question was: '{original_question}'\n"
            conversation_context += f"Combine '{latest}' with the original intent from '{original_question}'.\n"

        # Load research brief prompt
        system_prompt_template = _load_prompt_file("research_brief_prompt.txt")
        system_prompt = system_prompt_template.format(
            available_pipelines="Questionnaire: Survey questions\nToplines: Response frequencies\nCrosstabs: Demographic breakdowns",
            available_surveys=self._get_available_surveys_description(),
            available_months=self._get_available_months_description()
        )
        brief_generator = self.llm.with_structured_output(ResearchBrief)
        user_prompt = f"User question: {question}\n\nGenerate a research brief."
        if conversation_context:
            user_prompt = conversation_context + "\n\n" + user_prompt
        brief = brief_generator.invoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_prompt)
        ])
        if self.verbose:
            print(f"Action: {brief.action}")
            print(f"Reasoning: {brief.reasoning}")

        # Clear stage_results if this is a new topic, otherwise preserve for follow-ups
        existing_stage_results = state.get("stage_results", [])
        # If relevance checker determined this is a new/unrelated topic, clear previous results
        if relevance_result and not relevance_result.get('is_related', True):
            if self.verbose:
                print("🔄 Clearing previous stage results (new topic detected)")
            existing_stage_results = []
        return {
            "research_brief": brief,
            "current_stage": 0,
            "stage_results": existing_stage_results,  # Clear for new topics, preserve for follow-ups
            "messages": [AIMessage(content=f"[Research plan: {brief.action}]")]
        }

    def _route_after_brief(self, state: SurveyAnalysisState) -> str:
        """Route after research brief"""
        brief = state["research_brief"]
        if brief.action in ("followup", "answer"):
            return brief.action
        return "execute_stage"

    def _execute_stage(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Execute one stage of research"""
        brief = state["research_brief"]
        current_stage_idx = state.get("current_stage", 0)
        previous_stage_results = state.get("stage_results", [])

        # Determine execution mode
        if brief.action == "route_to_sources":
            stage_data_sources = brief.data_sources
            stage_desc = "Single-stage retrieval"
            # CRITICAL: Enrich with previous results to avoid redundant Questionnaire queries
            if previous_stage_results:
                # Try to extract question_info from previous results
                question_info_available = False
                for prev_result in previous_stage_results:
                    if prev_result.questionnaire_results:
                        q_info = prev_result.questionnaire_results.get("question_info", [])
                        if q_info:
                            question_info_available = True
                            if self.verbose:
                                print(f" ✅ Found {len(q_info)} questions from previous turn - will reuse")
                            break
                # If question_info is available and we're querying toplines/crosstabs, enrich
                if question_info_available:
                    for ds in stage_data_sources:
                        if ds.source_type in ("toplines", "crosstabs"):
                            if self.verbose:
                                print(f" 🔄 Enriching {ds.source_type} data source with previous question_info")
                            stage_data_sources = self._enrich_data_sources_with_context(
                                stage_data_sources, previous_stage_results,
                                "Extract question_info from previous conversation results"
                            )
                            break
                else:
                    if self.verbose:
                        print(" ⚠️ No question_info available from previous turns")
        elif brief.action == "execute_stages":
            stage = brief.stages[current_stage_idx]
            stage_data_sources = stage.data_sources
            stage_desc = stage.description
            # Check if Stage 1 is just extraction (no data sources to execute)
            if (current_stage_idx == 0 and
                    stage.use_previous_results_for and
                    "extract" in stage.use_previous_results_for.lower() and
                    not stage.data_sources and
                    previous_stage_results):
                # This is an extraction-only stage - skip execution, just extract context
                if self.verbose:
                    print(f"\n=== EXECUTING STAGE {stage.stage_number} (EXTRACTION ONLY) ===")
                    print(f"Description: {stage.description}")
                    print(" Extracting question_info from previous conversation results")
                # Extract and return immediately (context extraction happens in _extract_stage_context)
                return {
                    "stage_results": previous_stage_results,  # Keep previous results
                    "current_stage": current_stage_idx + 1
                }
            # Enrich with context from previous stages
            if stage.use_previous_results_for and previous_stage_results:
                stage_data_sources = self._enrich_data_sources_with_context(
                    stage_data_sources, previous_stage_results, stage.use_previous_results_for
                )
        else:
            return {}

        if self.verbose:
            print("\n=== EXECUTING STAGE ===")
            print(f"Description: {stage_desc}")

        # Initialize stage result
        stage_result = StageResult(
            stage_number=current_stage_idx + 1,
            status="success"
        )

        # Get combined query for semantic search fallback
        combined_query = self._combine_query_with_context(state)

        # Execute each data source
        for ds in stage_data_sources:
            filters_dict = {k: v for k, v in ds.filters.model_dump().items() if v is not None}

            if ds.source_type == "questionnaire":
                if self.verbose:
                    print(f"\n📊 [Questionnaire] {ds.query_description}")
                result = self.questionnaire_rag.retrieve_raw_data(
                    question=ds.query_description,
                    filters=filters_dict if filters_dict else None
                )
                stage_result.questionnaire_results = result
                # Populate metadata
                stage_result.query_type = "questionnaire"
                if filters_dict:
                    stage_result.time_period = {
                        "year": filters_dict.get("year"),
                        "month": filters_dict.get("month")
                    }
                    if filters_dict.get("topic"):
                        stage_result.topics = [filters_dict.get("topic")]

            elif ds.source_type == "toplines":
                if self.verbose:
                    print(f"\n📊 [Toplines] {ds.query_description}")
                # Extract question_info and source_questions.
                # PRIORITY ORDER:
                #   1. First check enriched data on data source (from current execution's Stage 1)
                #   2. Then fall back to previous stage results (from conversation history)
                question_info = None
                source_questions = None
                # Check enriched data first (most recent Stage 1 execution)
                if hasattr(ds, '_question_info'):
                    question_info = ds._question_info  # type: ignore
                    if self.verbose and question_info:
                        print(f" 📋 Using {len(question_info)} questions from enriched data source")
                # Fall back to previous questionnaire results if no enriched data
                if not question_info and previous_stage_results:
                    for prev_result in previous_stage_results:
                        if prev_result.questionnaire_results:
                            question_info = prev_result.questionnaire_results.get("question_info", [])
                            source_questions = prev_result.questionnaire_results.get("source_questions", [])
                            if question_info and self.verbose:
                                print(f" 📋 Using {len(question_info)} questions from conversation history")
                            break
                result = self.toplines_rag.retrieve_raw_data(
                    query=combined_query if not question_info else ds.query_description,
                    question_info=question_info,
                    source_questions=source_questions,
                    filters=filters_dict if not question_info else None,
                    top_k=10
                )
                stage_result.toplines_results = result
                # Populate metadata
                stage_result.query_type = "toplines"
                if filters_dict:
                    stage_result.time_period = {
                        "year": filters_dict.get("year"),
                        "month": filters_dict.get("month")
                    }

            elif ds.source_type == "crosstabs":
                if self.verbose:
                    print(f"\n📊 [Crosstabs] {ds.query_description}")
                # Extract question_info and source_questions.
                # PRIORITY ORDER:
                #   1. First check enriched data on data source (from current execution's Stage 1)
                #   2. Then fall back to previous stage results (from conversation history)
                question_info = None
                source_questions = None
                # Check enriched data first (most recent Stage 1 execution)
                if hasattr(ds, '_question_info'):
                    question_info = ds._question_info  # type: ignore
                    if self.verbose and question_info:
                        print(f" 📋 Using {len(question_info)} questions from enriched data source")
                # Fall back to previous questionnaire results if no enriched data
                if not question_info and previous_stage_results:
                    for prev_result in previous_stage_results:
                        if prev_result.questionnaire_results:
                            question_info = prev_result.questionnaire_results.get("question_info", [])
                            source_questions = prev_result.questionnaire_results.get("source_questions", [])
                            if question_info:
                                if self.verbose:
                                    print(f" 📋 Using {len(question_info)} questions from conversation history")
                                break
                if self.verbose:
                    if question_info:
                        print(" ✅ Will use question_info (skipping Questionnaire)")
                    else:
                        print(" ⚠️ No question_info available, will query Questionnaire")
                result = self.crosstab_rag.retrieve_raw_data(
                    user_query=combined_query if not question_info else ds.query_description,
                    question_info=question_info,
                    source_questions=source_questions,
                    filters=filters_dict  # Always pass filters for time period filtering
                )
                stage_result.crosstabs_results = result
                # Populate metadata
                stage_result.query_type = "crosstabs"
                if filters_dict:
                    stage_result.time_period = {
                        "year": filters_dict.get("year"),
                        "month": filters_dict.get("month")
                    }

        # Check if stage has data
        has_data = (
            (stage_result.questionnaire_results and stage_result.questionnaire_results.get("source_questions")) or
            (stage_result.toplines_results and stage_result.toplines_results.get("retrieved_docs")) or
            (stage_result.crosstabs_results and stage_result.crosstabs_results.get("matched_variables"))
        )
        updated_stage_results = previous_stage_results
        if has_data:
            updated_stage_results = previous_stage_results + [stage_result]
        return {
            "stage_results": updated_stage_results,
            "current_stage": current_stage_idx + 1
        }

    def _enrich_data_sources_with_context(
        self,
        data_sources: List[DataSource],
        previous_results: List[StageResult],
        use_instruction: str
    ) -> List[DataSource]:
        """Enrich data sources with context from previous stages"""
        if self.verbose:
            print(f" Enriching with context: {use_instruction}")

        # Extract question_info from previous results.
        # IMPORTANT: Prioritize most recent questionnaire results (from current execution)
        # over older conversation results to avoid using stale data when time period changes.
        question_info_list = []

        # First, check if the most recent stage has questionnaire results.
        # If so, use ONLY those (this handles time period changes correctly).
        has_recent_questionnaire = False
        if previous_results:
            most_recent = previous_results[-1]
            if most_recent.questionnaire_results:
                q_info = most_recent.questionnaire_results.get("question_info", [])
                if q_info:
                    question_info_list.extend(q_info)
                    has_recent_questionnaire = True
                    if self.verbose:
                        print(f" 📋 Using {len(q_info)} questions from most recent stage (Stage {most_recent.stage_number})")

        # If no recent questionnaire results, fall back to collecting from all previous results
        if not has_recent_questionnaire:
            for prev_result in previous_results:
                if prev_result.questionnaire_results:
                    q_info = prev_result.questionnaire_results.get("question_info", [])
                    question_info_list.extend(q_info)
                if prev_result.toplines_results:
                    # Extract from toplines metadata.
                    # Handle both Document objects and dicts (from checkpoint deserialization).
                    docs = prev_result.toplines_results.get("retrieved_docs", [])
                    for doc in docs:
                        if hasattr(doc, 'metadata'):
                            metadata = doc.metadata or {}
                        elif isinstance(doc, dict):
                            metadata = doc.get("metadata", {})
                        else:
                            metadata = {}
                        var_name = metadata.get("variable") or metadata.get("variable_name")
                        if var_name:
                            question_info_list.append({
                                "variable_name": var_name,
                                "year": metadata.get("year"),
                                "month": metadata.get("month", ""),
                                "poll_date": metadata.get("poll_date", ""),
                                "question_id": metadata.get("question_id")
                            })
            if self.verbose and question_info_list:
                print(f" 📋 Using {len(question_info_list)} questions from conversation history")

        if not question_info_list:
            return data_sources

        # Store question_info in a way that can be accessed during execution.
        # We'll pass it directly to RAG methods, not through filters.
        enriched_sources = []
        for ds in data_sources:
            # Create a copy and store question_info as an attribute (not in filters)
            enriched_ds = ds.model_copy()
            # Store question_info separately - will be extracted in _execute_stage
            enriched_ds._question_info = question_info_list  # type: ignore
            enriched_sources.append(enriched_ds)
        return enriched_sources
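
    # Design note: question_info rides along as an underscore attribute rather than
    # inside QueryFilters, since QueryFilters uses extra="forbid" and is part of the
    # structured-output schema; _execute_stage reads it back via hasattr(ds, '_question_info').
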
    def _extract_stage_context(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Extract context from completed stage"""
        stage_results = state.get("stage_results", [])
        if not stage_results:
            return {}
        current_result = stage_results[-1]
        extracted_context = {}

        # Extract question_info
        question_info_list = []
        if current_result.questionnaire_results:
            q_info = current_result.questionnaire_results.get("question_info", [])
            question_info_list.extend(q_info)
        if current_result.toplines_results:
            docs = current_result.toplines_results.get("retrieved_docs", [])
            for doc in docs:
                # Handle both Document objects and dicts (from checkpoint deserialization)
                if hasattr(doc, 'metadata'):
                    metadata = doc.metadata or {}
                elif isinstance(doc, dict):
                    metadata = doc.get("metadata", {})
                else:
                    metadata = {}
                var_name = metadata.get("variable") or metadata.get("variable_name")
                if var_name:
                    question_info_list.append({
                        "variable_name": var_name,
                        "year": metadata.get("year"),
                        "month": metadata.get("month", ""),
                        "poll_date": metadata.get("poll_date", ""),
                        "question_id": metadata.get("question_id")
                    })
        if question_info_list:
            extracted_context["question_info"] = question_info_list
        current_result.extracted_context = extracted_context
        return {}

    def _route_after_stage(self, state: SurveyAnalysisState) -> str:
        """Route after stage execution"""
        brief = state["research_brief"]
        current_stage_idx = state.get("current_stage", 0)
        if brief.action == "route_to_sources":
            return "synthesize"
        total_stages = len(brief.stages)
        if current_stage_idx < total_stages:
            return "next_stage"
        else:
            return "synthesize"

    def _summarize_crosstab_data(
        self,
        crosstab_docs: List[Document],
        user_query: str,
        question_text: str,
        var_name: str
    ) -> str:
        """
        Summarize crosstab data by combining multi-part crosstabs and extracting relevant demographics.
        This reduces token usage significantly compared to including all raw chunks.
        """
        if not crosstab_docs:
            return "No crosstab data available."

        # Sort chunks by chunk_index to maintain correct order.
        # Handle both Document objects and dicts (from checkpoint deserialization).
        def get_chunk_index(d):
            if hasattr(d, 'metadata') and d.metadata:
                return d.metadata.get("chunk_index", 999)
            elif isinstance(d, dict):
                return d.get("metadata", {}).get("chunk_index", 999)
            return 999

        sorted_docs = sorted(crosstab_docs, key=get_chunk_index)
        try:
            # Initialize summarizer
            summarizer = CrosstabSummarizer(
                llm_model=os.getenv("OPENAI_MODEL", "gpt-4o"),
                openai_api_key=self.openai_api_key
            )
            # Use summarizer to combine parts and extract relevant demographic breakdown
            result = summarizer.summarize(
                user_query=user_query,
                retrieved_docs=sorted_docs,
                question_text=question_text,
                top_n_sources=10
            )
            summary = result.get("answer", "")
            # Truncate summary if too long
            if len(summary) > 2000:
                summary = summary[:1500] + "\n... (summary truncated for length)"
            return summary
        except Exception as e:
            # Fallback: return truncated raw content if summarization fails
            if self.verbose:
                print(f"Error summarizing crosstab for {var_name}: {e}")
            # Use only first 3 chunks to limit length
            combined_parts = []
            for doc in sorted_docs[:3]:
                if hasattr(doc, 'page_content'):
                    combined_parts.append(doc.page_content)
                elif isinstance(doc, dict):
                    combined_parts.append(doc.get('page_content', ''))
            combined = "\n\n".join(combined_parts)
            if len(combined) > 1500:
                combined = combined[:1500] + "... (truncated)"
            return f"Crosstab data for {var_name}:\n{combined}"

    def _synthesize_response(self, state: SurveyAnalysisState) -> Dict[str, Any]:
        """Synthesize final answer"""
        if self.verbose:
            print("\n=== SYNTHESIZING RESPONSE ===")
        brief = state["research_brief"]
        full_question = self._get_full_question_context(state)

        # Handle follow-up
        if brief.action == "followup":
            return {
                "final_answer": brief.followup_question,
                "messages": [AIMessage(content=brief.followup_question)]
            }

        # Handle direct answer
        if brief.action == "answer":
            answer = self.llm.invoke([
                SystemMessage(content="Answer the user's question directly."),
                HumanMessage(content=full_question)
            ]).content
            return {
                "final_answer": answer,
                "messages": [AIMessage(content=answer)]
            }

        # Get stage results
        stage_results = state.get("stage_results", [])
        if not stage_results:
            no_data_msg = "I was unable to retrieve any data to answer your question."
            return {
                "final_answer": no_data_msg,
                "messages": [AIMessage(content=no_data_msg)]
            }

        # Build context from stage results
        context_parts = []
        for i, stage_result in enumerate(stage_results, 1):
            # Questionnaire data
            if stage_result.questionnaire_results:
                q_res = stage_result.questionnaire_results
                source_questions = q_res.get("source_questions", [])
                context_parts.append(f"\n=== STAGE {i} (QUESTIONNAIRE DATA) ===")
                if source_questions:
                    context_parts.append(f"Retrieved {len(source_questions)} question(s):\n")
                    for j, q in enumerate(source_questions, 1):
                        context_parts.append(f"Question {j}: {q.get('question_text', 'N/A')}")
                        context_parts.append(f"Variable: {q.get('variable_name', 'N/A')}")
                        context_parts.append(f"Poll: {q.get('poll_date', 'N/A')}")
                        context_parts.append("")

            # Toplines data
            if stage_result.toplines_results:
                t_res = stage_result.toplines_results
                retrieved_docs = t_res.get("retrieved_docs", [])
                context_parts.append(f"\n=== STAGE {i} (TOPLINES DATA) ===")
                if retrieved_docs:
                    context_parts.append(f"Retrieved {len(retrieved_docs)} topline document(s):\n")
                    for j, doc in enumerate(retrieved_docs, 1):
                        # Handle both Document objects and dicts (from checkpoint deserialization)
                        if hasattr(doc, 'metadata'):
                            metadata = doc.metadata or {}
                            content = doc.page_content or ""
                        elif isinstance(doc, dict):
                            metadata = doc.get("metadata", {})
                            content = doc.get("page_content", "")
                        else:
                            metadata = {}
                            content = ""
                        context_parts.append(f"--- Topline Document {j} ---")
                        context_parts.append(f"Survey: {metadata.get('survey_name', 'Vanderbilt Unity Poll')} ({metadata.get('month', '')} {metadata.get('year', '')})")
                        context_parts.append(f"Variable: {metadata.get('variable_name', 'N/A')}")
                        context_parts.append(f"Response: {metadata.get('response_label', 'N/A')}")
                        context_parts.append(f"Percentage: {metadata.get('pct', 'N/A')}%")
                        if content:
                            context_parts.append(f"\nContent:\n{content}")
                        context_parts.append("")

            # Crosstabs data
            if stage_result.crosstabs_results:
                c_res = stage_result.crosstabs_results
                if "error" not in c_res:
                    crosstab_docs_by_var = c_res.get("crosstab_docs_by_variable", {})
                    context_parts.append(f"\n=== STAGE {i} (CROSSTABS DATA) ===")
                    for var_name, var_data in crosstab_docs_by_var.items():
                        crosstab_docs = var_data.get("crosstab_docs", [])
                        question_text = var_data.get("question_text", "")
                        matched_question = var_data.get("matched_question", {})
                        # Extract metadata from matched_question
                        year = matched_question.get("year", "Unknown")
                        month = matched_question.get("month", "Unknown")
                        poll_date = matched_question.get("poll_date", f"{year}-{month}")
                        # Extract sample size from first crosstab document if available.
                        # Note: Sample size is typically embedded in the crosstab content.
                        sample_size = "See crosstab data below"
                        if crosstab_docs:
                            first_doc = crosstab_docs[0]
                            if hasattr(first_doc, 'metadata'):
                                sample_size = first_doc.metadata.get("total_n", sample_size)
                            elif isinstance(first_doc, dict):
                                sample_size = first_doc.get("metadata", {}).get("total_n", sample_size)
                        context_parts.append(f"\n{'='*80}")
                        context_parts.append(f"Variable: {var_name}")
                        context_parts.append(f"Question: {question_text}")
                        context_parts.append(f"Poll Date: {poll_date}")
                        context_parts.append(f"Year: {year}, Month: {month}")
                        context_parts.append(f"Sample Size (N): {sample_size}")
                        context_parts.append(f"{'='*80}\n")
                        # Summarize crosstab data (reduces token usage significantly)
                        summary = self._summarize_crosstab_data(
                            crosstab_docs=crosstab_docs,
                            user_query=full_question,
                            question_text=question_text,
                            var_name=var_name
                        )
                        context_parts.append(summary)
                        context_parts.append("")  # Extra blank line between variables

        # Synthesize
        synthesis_prompt_template = _load_prompt_file("synthesis_prompt_user.txt")
        synthesis_prompt = synthesis_prompt_template.format(
            stage_count='multiple stages' if len(stage_results) > 1 else 'the research',
            full_question=full_question,
            reasoning=brief.reasoning,
            context_parts="\n".join(context_parts),
            unavailable_note=""
        )
        synthesis_system_prompt = _load_prompt_file("synthesis_prompt_system.txt")
        final_answer = self.llm.invoke([
            SystemMessage(content=synthesis_system_prompt),
            HumanMessage(content=synthesis_prompt)
        ]).content
        if self.verbose:
            print("Synthesis complete")
        return {
            "final_answer": final_answer,
            "messages": [AIMessage(content=final_answer)]
        }

    def query(self, question: str, thread_id: str = "default") -> str:
        """Query the survey analysis system"""
        # Try to load previous state from checkpoint to preserve stage_results across turns
        previous_stage_results = []
        try:
            checkpoint = self.graph.get_state(config={"configurable": {"thread_id": thread_id}})
            if checkpoint and checkpoint.values:
                previous_stage_results = checkpoint.values.get("stage_results", [])
        except Exception:
            previous_stage_results = []
        initial_state = {
            "messages": [HumanMessage(content=question)],
            "user_question": question,
            "research_brief": None,
            "current_stage": 0,
            "stage_results": previous_stage_results,  # Preserve previous results for multi-turn
            "final_answer": None
        }
        config = {
            "configurable": {"thread_id": thread_id},
            "recursion_limit": 50
        }
        if self.verbose:
            print(f"\n🧵 Thread ID: {thread_id}")
        final_state = self.graph.invoke(initial_state, config)
        return final_state["final_answer"]
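
# Multi-turn sketch (hypothetical questions): reusing the same thread_id keeps the
# checkpointed stage_results, so a follow-up turn can reuse previously retrieved
# questions instead of hitting the questionnaire pipeline again:
#   agent.query("Which questions asked about political unity?", thread_id="t1")
#   agent.query("Show the toplines for these questions", thread_id="t1")
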
# ============================================================================
# CLI INTERFACE
# ============================================================================
def main():
    """Interactive CLI"""
    import sys

    openai_api_key = os.getenv("OPENAI_API_KEY")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    if not openai_api_key or not pinecone_api_key:
        print("Error: Missing API keys")
        print("Set OPENAI_API_KEY and PINECONE_API_KEY environment variables")
        sys.exit(1)

    print("Initializing survey analysis agent...")
    agent = SurveyAnalysisAgent(
        openai_api_key=openai_api_key,
        pinecone_api_key=pinecone_api_key,
        verbose=True
    )
    print("\n" + "="*80)
    print("SURVEY ANALYSIS AGENT")
    print("="*80)
    print("\nType 'quit' to exit\n")

    thread_id = "cli_session"
    while True:
        try:
            question = input("\nYour question: ").strip()
            if not question or question.lower() in ['quit', 'exit', 'q']:
                print("\nGoodbye!")
                break
            print("\n" + "-"*80)
            answer = agent.query(question, thread_id=thread_id)
            print("\n" + "="*80)
            print("ANSWER:")
            print("="*80)
            print(answer)
            print("="*80)
        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except Exception as e:
            print(f"\nError: {e}")
            if os.getenv("DEBUG"):
                raise

if __name__ == "__main__":
main()