# survey-analytics/crosstab_rag.py
"""
Crosstab RAG Module
-------------------
Retrieves crosstab demographic breakdown data from the Pinecone vectorstore.
Uses question_info for precise namespace matching and metadata filtering.
Returns raw data only; no synthesis is performed.
"""
import os
from typing import List, Dict, Optional, Any
from pathlib import Path
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from langchain.schema import Document
from pinecone import Pinecone
load_dotenv()
# Import QuestionnaireRAG to reuse question matching when needed
try:
from questionnaire_rag import QuestionnaireRAG
except ImportError:
    # Fall back to a relative import when running as part of a package
from .questionnaire_rag import QuestionnaireRAG
PINECONE_RETRIEVE_K = 100  # per-variable top_k budget for Pinecone queries
MAX_CROSSTAB_CHUNKS = 50   # cap on chunks retained per variable after sorting
class CrosstabSummarizer:
"""Summarizes crosstab data to reduce token usage."""
    def __init__(self, llm_model: Optional[str] = None, openai_api_key: Optional[str] = None):
from langchain_openai import ChatOpenAI
llm_model = llm_model or os.getenv("OPENAI_MODEL", "gpt-4o")
openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
self.llm = ChatOpenAI(model=llm_model, openai_api_key=openai_api_key, temperature=0.0)
def summarize(
self,
user_query: str,
retrieved_docs: List[Document],
question_text: Optional[str] = None,
top_n_sources: int = 6
) -> Dict:
"""Summarize crosstab data, extracting relevant demographic breakdowns."""
if not retrieved_docs:
return {"answer": "No relevant crosstab data found for that query.", "sources": []}
context_parts, sources = [], []
for i, d in enumerate(retrieved_docs):
# Handle both Document objects and dicts (from checkpoint deserialization)
if hasattr(d, 'metadata'):
md = d.metadata or {}
content = d.page_content or ""
elif isinstance(d, dict):
md = d.get("metadata", {})
content = d.get("page_content", "")
else:
md = {}
content = ""
id_hint = md.get("question_id") or md.get("variable_name") or f"part_{i+1}"
context_parts.append(f"--- Part {i+1} | {id_hint} ---\n{content}")
sources.append(id_hint)
context_text = "\n\n".join(context_parts)
# Load prompts
prompt_dir = Path(__file__).parent / "prompts"
system_prompt_path = prompt_dir / "crosstab_rag_prompt_system.txt"
user_prompt_path = prompt_dir / "crosstab_rag_prompt_user.txt"
system_prompt = system_prompt_path.read_text(encoding="utf-8") if system_prompt_path.exists() else ""
question_context = f"\n\nSURVEY QUESTION THAT WAS RETRIEVED: {question_text}" if question_text else ""
relevance_check = (
"\n\nโš ๏ธ RELEVANCE: The retrieved question IS relevant to the user's query. "
"Remember: ALL subtopics, specific examples, and related aspects ARE relevant:\n"
"- 'personal financial situation' IS about economy\n"
"- 'tariffs' IS about economy\n"
"- 'stock market' IS about economy\n"
"- 'gender-affirming healthcare' IS about healthcare\n"
"- 'Biden approval' IS about presidential approval\n"
"Only flag as irrelevant if about a COMPLETELY UNRELATED topic (e.g., user asked 'economy' but question is about 'sports teams'). "
"When in doubt, ANALYZE THE DATA - do not reject it."
) if question_text else ""
user_prompt_template = user_prompt_path.read_text(encoding="utf-8") if user_prompt_path.exists() else "{user_query}\n\n{context_text}"
user_prompt = user_prompt_template.format(
user_query=user_query,
question_context=question_context,
relevance_check=relevance_check,
context_text=context_text
)
from langchain.schema import HumanMessage, SystemMessage
messages = [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
try:
result = self.llm.invoke(messages)
answer = result.content if hasattr(result, 'content') else str(result)
except Exception as e:
answer = f"Error generating summary: {e}"
return {"answer": answer.strip(), "sources": sources[:top_n_sources]}
class CrosstabRetriever:
"""Retrieves crosstab chunks from Pinecone using metadata filtering."""
def __init__(
self,
pinecone_api_key: str,
index_name: str,
embed_model: str,
openai_api_key: str,
verbose: bool = False
):
self.pc = Pinecone(api_key=pinecone_api_key)
self.index_name = index_name
self.embedder = OpenAIEmbeddings(model=embed_model, openai_api_key=openai_api_key)
self.verbose = verbose
def _build_namespace_from_question_info(self, question_info: Dict[str, Any]) -> Optional[str]:
"""Build namespace from question_info (year + month)"""
year = question_info.get("year")
month = question_info.get("month", "")
if year and month:
return f"Vanderbilt_Unity_Poll_{year}_{month}_cleaned_data_crosstabs".replace(" ", "_")
# Try to extract from poll_date
poll_date = question_info.get("poll_date", "")
if poll_date:
try:
from datetime import datetime
# Handle format like "2025-June"
if "-" in poll_date and len(poll_date.split("-")) == 2:
year_str, month_str = poll_date.split("-")
return f"Vanderbilt_Unity_Poll_{year_str}_{month_str}_cleaned_data_crosstabs".replace(" ", "_")
else:
date_obj = datetime.strptime(poll_date, "%Y-%m-%d")
year_str = str(date_obj.year)
month_str = date_obj.strftime("%B")
return f"Vanderbilt_Unity_Poll_{year_str}_{month_str}_cleaned_data_crosstabs".replace(" ", "_")
except Exception as e:
if self.verbose:
print(f" โš ๏ธ Failed to parse poll_date '{poll_date}': {e}")
return None
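    # Illustration (not executed): for question_info = {"year": 2025, "month": "June"}
    # or poll_date = "2025-June", the builder above yields the namespace
    #   "Vanderbilt_Unity_Poll_2025_June_cleaned_data_crosstabs"
    # (any spaces in the interpolated values are replaced with underscores).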
def retrieve_parts_for_question_info(
self,
question_info_list: List[Dict[str, Any]],
k: int = PINECONE_RETRIEVE_K,
filters: Optional[Dict[str, Any]] = None
) -> Dict[str, List[Document]]:
"""
Retrieve crosstab chunks for question_info list.
Groups by namespace (year/month) and filters by variable_name and question_id.
Args:
question_info_list: List of question info dicts with variable_name, year, month, question_id
k: Number of results to retrieve per variable
filters: Optional filters with year/month to constrain namespace search
Returns:
Dict mapping variable_name to list of Document objects
"""
try:
index = self.pc.Index(self.index_name)
stats = index.describe_index_stats()
available_namespaces = list(stats.get('namespaces', {}).keys())
if not available_namespaces:
if self.verbose:
print(" โš ๏ธ No namespaces found in index")
return {}
# Build target namespace from filters if provided
target_namespace = None
if filters:
year = filters.get("year")
month = filters.get("month", "")
if year and month:
target_namespace = f"Vanderbilt_Unity_Poll_{year}_{month}_cleaned_data_crosstabs".replace(" ", "_")
if target_namespace not in available_namespaces:
if self.verbose:
print(f" โš ๏ธ Target namespace {target_namespace} not found in available namespaces")
target_namespace = None
# Group questions by namespace
questions_by_namespace = {}
for q_info in question_info_list:
var_name = q_info.get("variable_name")
if not var_name:
continue
# Try to build namespace from question_info first
namespace = self._build_namespace_from_question_info(q_info)
                if namespace and namespace in available_namespaces:
                    questions_by_namespace.setdefault(namespace, []).append(var_name)
                elif target_namespace:
                    # Use the target namespace derived from filters
                    questions_by_namespace.setdefault(target_namespace, []).append(var_name)
else:
# Only search all namespaces if NO question metadata is available
# This prevents broad searches when question_info is provided
if self.verbose:
print(f" โš ๏ธ Could not determine namespace for {var_name} (year={q_info.get('year')}, month={q_info.get('month')})")
# Skip this question rather than searching all namespaces
continue
            # Infer the embedding dimension from the model name
            # (text-embedding-3-small -> 1536, text-embedding-3-large -> 3072)
            embed_dim = 1536  # default for text-embedding-3-small
            model_name = str(getattr(self.embedder, "model", "")).lower()
            if "large" in model_name:
                embed_dim = 3072
dummy_vector = [0.0] * embed_dim
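            # Note: the all-zeros query vector turns the query below into a pure
            # metadata-filtered fetch; similarity scores are not meaningful and
            # are ignored (this assumes the index metric accepts a zero vector).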
all_docs_by_variable = {}
# Build mapping from variable_name to question_id for filtering
var_to_question_id = {}
for q_info in question_info_list:
var_name = q_info.get("variable_name")
question_id = q_info.get("question_id")
if var_name and question_id:
var_to_question_id[var_name] = question_id
# Search each namespace
for namespace, var_names in questions_by_namespace.items():
if namespace not in available_namespaces:
continue
if self.verbose:
print(f" ๐Ÿ” Searching namespace: {namespace}")
print(f" Looking for variables: {', '.join(sorted(set(var_names)))}")
if var_to_question_id:
matched_vars = [v for v in var_names if v in var_to_question_id]
if matched_vars:
print(f" ๐Ÿ”‘ Using question_id filter for: {', '.join(sorted(set(matched_vars)))}")
# Build filter for variable names and question IDs
unique_vars = list(set(var_names))
# Build filter conditions - match on either variable_name OR question_id
filter_conditions = []
for var in unique_vars:
var_conditions = []
# Add variable_name conditions (with and without _crosstab suffix)
var_conditions.append({"variable_name": {"$eq": var}})
var_conditions.append({"variable_name": {"$eq": f"{var}_crosstab"}})
# Add question_id condition if available
# Note: question_id in Pinecone metadata might have _part suffix for chunked crosstabs
# but we match on base question_id and filter in post-processing
if var in var_to_question_id:
question_id = var_to_question_id[var]
var_conditions.append({"question_id": {"$eq": question_id}})
# Combine conditions for this variable with $or
if len(var_conditions) > 1:
filter_conditions.append({"$or": var_conditions})
else:
filter_conditions.append(var_conditions[0])
# Combine all variable filters with $or
if len(filter_conditions) == 1:
var_filter = filter_conditions[0]
else:
var_filter = {"$or": filter_conditions}
try:
result = index.query(
vector=dummy_vector,
top_k=k * len(unique_vars),
namespace=namespace,
filter=var_filter,
include_metadata=True
)
if self.verbose:
print(f" ๐Ÿ“Š Found {len(result.matches)} matches in {namespace}")
for match in result.matches:
metadata = match.metadata or {}
var_name = metadata.get("variable_name", "")
# Handle question_id format like "VAND10_part1"
question_id = metadata.get("question_id", "")
if question_id and "_part" in question_id:
base_var = question_id.split("_part")[0].replace("_crosstab", "")
if base_var in unique_vars:
var_name = base_var
# Check if variable_name has _crosstab suffix
if var_name and var_name.endswith("_crosstab"):
base_var = var_name.replace("_crosstab", "")
if base_var in unique_vars:
var_name = base_var
if not var_name or var_name not in unique_vars:
continue
content = metadata.pop('text', '') or metadata.pop('page_content', '') or ''
if not content:
continue
                        all_docs_by_variable.setdefault(var_name, []).append(
                            Document(page_content=content, metadata=metadata)
                        )
except Exception as e:
if self.verbose:
print(f" โš ๏ธ Error querying namespace {namespace}: {e}")
continue
# Sort documents by chunk_index
for var_name in all_docs_by_variable:
all_docs_by_variable[var_name].sort(key=lambda d: d.metadata.get("chunk_index", 999))
all_docs_by_variable[var_name] = all_docs_by_variable[var_name][:MAX_CROSSTAB_CHUNKS]
if self.verbose:
total_docs = sum(len(docs) for docs in all_docs_by_variable.values())
print(f" โœ… Retrieved {total_docs} total document(s) for {len(all_docs_by_variable)} variable(s)")
return all_docs_by_variable
except Exception as e:
if self.verbose:
print(f" โŒ Error in retrieve_parts_for_question_info: {e}")
return {}
class CrosstabsRAG:
"""Crosstabs RAG with question_info-based retrieval."""
def __init__(
self,
questionnaire_rag: QuestionnaireRAG,
verbose: bool = False
):
self.questionnaire_rag = questionnaire_rag
self.verbose = verbose
pinecone_api_key = os.getenv("PINECONE_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")
index_name = os.getenv("PINECONE_INDEX_NAME_CROSSTABS", "crosstab-index")
embed_model = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")
self.retriever = CrosstabRetriever(
pinecone_api_key=pinecone_api_key,
index_name=index_name,
embed_model=embed_model,
openai_api_key=openai_api_key,
verbose=verbose
)
def retrieve_raw_data(
self,
user_query: str,
question_info: Optional[List[Dict[str, Any]]] = None,
source_questions: Optional[List[Dict[str, Any]]] = None,
filters: Optional[Dict[str, Any]] = None
) -> Dict:
"""
Retrieve raw crosstab data.
Uses question_info if provided (skips QuestionnaireRAG).
Otherwise uses QuestionnaireRAG to find questions, then retrieves crosstabs.
        Returns an error dict if metadata filtering finds no crosstab data.
Args:
user_query: User's query (used for QuestionnaireRAG if question_info not provided)
question_info: List of question info dicts (preferred - skips QuestionnaireRAG)
source_questions: Optional list of full question dicts from previous stage (avoids lookup)
filters: Optional filters for QuestionnaireRAG
Returns:
Dict with crosstab_docs_by_variable, matched_questions, namespace_used, survey_info
"""
if self.verbose:
print(f"\n๐Ÿ“Š [Crosstabs] Query: {user_query}")
if question_info:
print(f"๐Ÿ” Question info: {len(question_info)} question(s) provided")
if filters:
print(f"๐Ÿ” Filters: {filters}")
# If question_info provided, skip QuestionnaireRAG
if question_info:
if self.verbose:
print(f"โœ… Using provided question_info, skipping QuestionnaireRAG")
# Retrieve crosstab data directly
crosstab_docs_by_variable = self.retriever.retrieve_parts_for_question_info(
question_info_list=question_info,
k=PINECONE_RETRIEVE_K,
filters=filters
)
if not crosstab_docs_by_variable:
return {"error": f"No crosstab data found for {len(question_info)} question(s)."}
# Get question metadata - use provided source_questions if available, otherwise lookup
if not source_questions:
source_questions = []
questions_by_id = self.questionnaire_rag.questions_by_id
for q_info in question_info:
question_id = q_info.get("question_id")
if question_id and question_id in questions_by_id:
source_questions.append(questions_by_id[question_id])
else:
# Fallback: try to find by variable_name and year/month
var_name = q_info.get("variable_name")
year = q_info.get("year")
month = q_info.get("month", "")
if var_name:
# Search through questions_by_id for matching variable
for qid, q_data in questions_by_id.items():
if (q_data.get("variable_name") == var_name and
q_data.get("year") == year and
q_data.get("month", "") == month):
source_questions.append(q_data)
break
# Format results
formatted_results = {}
matched_variables = []
all_namespaces = set()
for var_name, docs in crosstab_docs_by_variable.items():
question_metadata = next(
(q for q in source_questions if q.get("variable_name") == var_name),
{}
)
question_text = question_metadata.get("question_text", "")
if docs:
first_doc_meta = docs[0].metadata
survey_name = first_doc_meta.get("survey_name", "")
all_namespaces.add(survey_name)
formatted_results[var_name] = {
"crosstab_docs": docs,
"question_text": question_text or (docs[0].metadata.get("question_text", "") if docs else ""),
"matched_question": question_metadata
}
matched_variables.append(var_name)
return {
"crosstab_docs_by_variable": formatted_results,
"matched_questions": source_questions,
"matched_variables": matched_variables,
"namespace_used": list(all_namespaces),
"survey_info": {"poll": "Vanderbilt_Unity_Poll", "year": None, "month": None}
}
# Otherwise, use QuestionnaireRAG to find questions first
if self.verbose:
print(f"๐Ÿ” Using QuestionnaireRAG to find questions")
try:
q_result = self.questionnaire_rag.retrieve_raw_data(
question=user_query,
filters=filters or {},
k=10
)
except Exception as e:
return {"error": f"Error querying questionnaire: {e}"}
source_questions = q_result.get("source_questions", [])
question_info_from_questions = q_result.get("question_info", [])
if not source_questions:
return {"error": "No matching questions found in questionnaire for that query."}
if self.verbose:
print(f"โœ… Found {len(source_questions)} question(s) from QuestionnaireRAG")
# Retrieve crosstab data using question_info
crosstab_docs_by_variable = self.retriever.retrieve_parts_for_question_info(
question_info_list=question_info_from_questions,
k=PINECONE_RETRIEVE_K
)
if not crosstab_docs_by_variable:
return {"error": f"No crosstab data found for any of the {len(source_questions)} matched questions."}
# Format results
formatted_results = {}
matched_variables = []
all_namespaces = set()
        for matched_question in source_questions:
            variable_name = matched_question.get("variable_name")
            question_text = matched_question.get("question_text", "")
            if not variable_name:
                continue
if variable_name in crosstab_docs_by_variable:
formatted_results[variable_name] = {
"crosstab_docs": crosstab_docs_by_variable[variable_name],
"question_text": question_text,
"matched_question": matched_question
}
matched_variables.append(variable_name)
if crosstab_docs_by_variable[variable_name]:
first_doc = crosstab_docs_by_variable[variable_name][0]
survey_name = first_doc.metadata.get("survey_name", "")
all_namespaces.add(survey_name)
return {
"crosstab_docs_by_variable": formatted_results,
"matched_questions": source_questions,
"matched_variables": matched_variables,
"namespace_used": list(all_namespaces),
"survey_info": {"poll": "Vanderbilt_Unity_Poll", "year": None, "month": None}
}
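# End-to-end usage sketch (hypothetical; assumes PINECONE_API_KEY,
# OPENAI_API_KEY, and the index env vars are set, and that QuestionnaireRAG
# can be constructed with its own defaults):
#
#   q_rag = QuestionnaireRAG()
#   crosstabs = CrosstabsRAG(questionnaire_rag=q_rag, verbose=True)
#   result = crosstabs.retrieve_raw_data(
#       user_query="How do views on tariffs break down by age?",
#       filters={"year": 2025, "month": "June"},
#   )
#   for var, payload in result.get("crosstab_docs_by_variable", {}).items():
#       print(var, payload["question_text"], len(payload["crosstab_docs"]))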