"""
Crosstab RAG Module
-------------------
Retrieves crosstab demographic breakdown data from Pinecone vectorstore.
Uses question_info for precise namespace matching and metadata filtering.

CrosstabsRAG returns raw data only - no synthesis. CrosstabSummarizer is a
separate, optional LLM step that condenses already-retrieved chunks.
"""

import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore  # noqa: F401  (unused here; kept for importers)
from pinecone import Pinecone

load_dotenv()

# Import QuestionnaireRAG to reuse question matching when needed
try:
    from questionnaire_rag import QuestionnaireRAG
except ImportError:
    # Handle case where running as module
    from .questionnaire_rag import QuestionnaireRAG

PINECONE_RETRIEVE_K = 100
MAX_CROSSTAB_CHUNKS = 50

# Suffix that crosstab records may carry on their variable_name metadata.
_CROSSTAB_SUFFIX = "_crosstab"

# Fallback vector sizes used only when the index does not report a dimension.
_DEFAULT_EMBED_DIM = 1536  # Default for text-embedding-3-small
_LARGE_EMBED_DIM = 3072

_RELEVANCE_NOTE = (
    "\n\n⚠️ RELEVANCE: The retrieved question IS relevant to the user's query. "
    "Remember: ALL subtopics, specific examples, and related aspects ARE relevant:\n"
    "- 'personal financial situation' IS about economy\n"
    "- 'tariffs' IS about economy\n"
    "- 'stock market' IS about economy\n"
    "- 'gender-affirming healthcare' IS about healthcare\n"
    "- 'Biden approval' IS about presidential approval\n"
    "Only flag as irrelevant if about a COMPLETELY UNRELATED topic (e.g., user asked 'economy' but question is about 'sports teams'). "
    "When in doubt, ANALYZE THE DATA - do not reject it."
)


def _format_namespace(year: Any, month: Any) -> str:
    """Return the Pinecone namespace name for one poll wave (year + month)."""
    return f"Vanderbilt_Unity_Poll_{year}_{month}_cleaned_data_crosstabs".replace(" ", "_")


class CrosstabSummarizer:
    """Summarizes crosstab data to reduce token usage."""

    def __init__(self, llm_model: Optional[str] = None, openai_api_key: Optional[str] = None):
        """
        Create the chat model used for summarization.

        Args:
            llm_model: Chat model name; defaults to $OPENAI_MODEL, then "gpt-4o".
            openai_api_key: API key; defaults to $OPENAI_API_KEY.
        """
        from langchain_openai import ChatOpenAI

        llm_model = llm_model or os.getenv("OPENAI_MODEL", "gpt-4o")
        openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self.llm = ChatOpenAI(model=llm_model, openai_api_key=openai_api_key, temperature=0.0)

    @staticmethod
    def _unpack_doc(doc: Any) -> tuple:
        """Return (metadata, content) for a Document or its dict form."""
        # Handle both Document objects and dicts (from checkpoint deserialization).
        # `or` guards cover values that are present but None.
        if hasattr(doc, "metadata"):
            return doc.metadata or {}, doc.page_content or ""
        if isinstance(doc, dict):
            return doc.get("metadata") or {}, doc.get("page_content") or ""
        return {}, ""

    @staticmethod
    def _read_prompt(path: Path, default: str) -> str:
        """Return the prompt file's text, or `default` when the file is absent."""
        return path.read_text(encoding="utf-8") if path.exists() else default

    def summarize(
        self,
        user_query: str,
        retrieved_docs: List[Document],
        question_text: Optional[str] = None,
        top_n_sources: int = 6
    ) -> Dict:
        """
        Summarize crosstab data, extracting relevant demographic breakdowns.

        Args:
            user_query: The user's question, substituted into the user prompt.
            retrieved_docs: Crosstab chunks (Document objects or equivalent dicts).
            question_text: Survey question text; when given, extra question
                context and a relevance note are added to the prompt.
            top_n_sources: Maximum number of source identifiers returned.

        Returns:
            Dict with "answer" (LLM text, or an error message if the LLM call
            raised) and "sources" (identifiers of the first chunks used).
        """
        if not retrieved_docs:
            return {"answer": "No relevant crosstab data found for that query.", "sources": []}

        context_parts: List[str] = []
        sources: List[str] = []
        for i, d in enumerate(retrieved_docs):
            md, content = self._unpack_doc(d)
            id_hint = md.get("question_id") or md.get("variable_name") or f"part_{i+1}"
            context_parts.append(f"--- Part {i+1} | {id_hint} ---\n{content}")
            sources.append(id_hint)
        context_text = "\n\n".join(context_parts)

        # Load prompts; both files are optional and have in-code fallbacks.
        prompt_dir = Path(__file__).parent / "prompts"
        system_prompt = self._read_prompt(prompt_dir / "crosstab_rag_prompt_system.txt", "")
        user_prompt_template = self._read_prompt(
            prompt_dir / "crosstab_rag_prompt_user.txt",
            "{user_query}\n\n{context_text}",
        )

        question_context = (
            f"\n\nSURVEY QUESTION THAT WAS RETRIEVED: {question_text}" if question_text else ""
        )
        relevance_check = _RELEVANCE_NOTE if question_text else ""

        user_prompt = user_prompt_template.format(
            user_query=user_query,
            question_context=question_context,
            relevance_check=relevance_check,
            context_text=context_text
        )

        from langchain.schema import HumanMessage, SystemMessage

        messages = [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
        # Best effort by design: an LLM failure is reported in the answer text
        # rather than raised to the caller.
        try:
            result = self.llm.invoke(messages)
            answer = result.content if hasattr(result, 'content') else str(result)
        except Exception as e:
            answer = f"Error generating summary: {e}"

        return {"answer": answer.strip(), "sources": sources[:top_n_sources]}


class CrosstabRetriever:
    """Retrieves crosstab chunks from Pinecone using metadata filtering."""

    def __init__(
        self,
        pinecone_api_key: str,
        index_name: str,
        embed_model: str,
        openai_api_key: str,
        verbose: bool = False
    ):
        """
        Args:
            pinecone_api_key: Pinecone API key.
            index_name: Name of the Pinecone index holding crosstab chunks.
            embed_model: OpenAI embedding model name.
            openai_api_key: OpenAI API key.
            verbose: When True, progress and errors are printed.
        """
        self.pc = Pinecone(api_key=pinecone_api_key)
        self.index_name = index_name
        self.embedder = OpenAIEmbeddings(model=embed_model, openai_api_key=openai_api_key)
        self.verbose = verbose

    def _log(self, message: str) -> None:
        """Print `message` only in verbose mode."""
        if self.verbose:
            print(message)

    def _build_namespace_from_question_info(self, question_info: Dict[str, Any]) -> Optional[str]:
        """
        Build namespace from question_info (year + month).

        Falls back to `poll_date`, accepted either as "<year>-<month>" (e.g.
        "2025-June") or as "%Y-%m-%d". Returns None when neither source yields
        a namespace.
        """
        year = question_info.get("year")
        month = question_info.get("month", "")
        if year and month:
            return _format_namespace(year, month)

        # Try to extract from poll_date
        poll_date = question_info.get("poll_date", "")
        if not poll_date:
            return None
        try:
            pieces = poll_date.split("-")
            if len(pieces) == 2:
                # Handle format like "2025-June"
                year_str, month_str = pieces
            else:
                date_obj = datetime.strptime(poll_date, "%Y-%m-%d")
                year_str = str(date_obj.year)
                # NOTE: %B is locale-dependent; namespaces use English month names.
                month_str = date_obj.strftime("%B")
            return _format_namespace(year_str, month_str)
        except (AttributeError, TypeError, ValueError) as e:
            self._log(f" ⚠️ Failed to parse poll_date '{poll_date}': {e}")
            return None

    def _resolve_target_namespace(
        self,
        filters: Optional[Dict[str, Any]],
        available_namespaces: set
    ) -> Optional[str]:
        """Return the namespace named by filters' year/month if it exists in the index."""
        if not filters:
            return None
        year = filters.get("year")
        month = filters.get("month", "")
        if not (year and month):
            return None
        target_namespace = _format_namespace(year, month)
        if target_namespace not in available_namespaces:
            self._log(f" ⚠️ Target namespace {target_namespace} not found in available namespaces")
            return None
        return target_namespace

    def _group_questions_by_namespace(
        self,
        question_info_list: List[Dict[str, Any]],
        available_namespaces: set,
        target_namespace: Optional[str]
    ) -> Dict[str, Dict[str, Optional[str]]]:
        """
        Group requested variables by namespace.

        Returns:
            {namespace: {variable_name: question_id or None}}. The question_id
            is kept per namespace so an id from one poll wave is never used to
            filter another wave that happens to share the variable name.
        """
        grouped: Dict[str, Dict[str, Optional[str]]] = {}
        for q_info in question_info_list:
            var_name = q_info.get("variable_name")
            if not var_name:
                continue

            # Try to build namespace from question_info first
            namespace = self._build_namespace_from_question_info(q_info)
            if not (namespace and namespace in available_namespaces):
                # Fall back to the namespace from filters, if any
                namespace = target_namespace
            if not namespace:
                # Skip this question rather than searching all namespaces;
                # this prevents broad searches when question_info is provided.
                self._log(
                    f" ⚠️ Could not determine namespace for {var_name} (year={q_info.get('year')}, month={q_info.get('month')})"
                )
                continue

            bucket = grouped.setdefault(namespace, {})
            bucket.setdefault(var_name, None)
            question_id = q_info.get("question_id")
            if question_id:
                bucket[var_name] = question_id
        return grouped

    def _embedding_dimension(self, stats: Any) -> int:
        """
        Return the vector size to use for the filter-only dummy query.

        Prefers the dimension reported by the index stats; otherwise guesses
        from the embedding model name ("large" -> 3072, anything else -> 1536).
        """
        reported = stats.get("dimension")
        if isinstance(reported, int) and reported > 0:
            return reported
        model_name = str(getattr(self.embedder, "model", "")).lower()
        if "small" not in model_name and "large" in model_name:
            return _LARGE_EMBED_DIM
        return _DEFAULT_EMBED_DIM

    @staticmethod
    def _build_metadata_filter(var_to_question_id: Dict[str, Optional[str]]) -> Dict[str, Any]:
        """Build a Pinecone filter matching any requested variable by name or question_id."""
        per_variable = []
        for var, question_id in var_to_question_id.items():
            # variable_name conditions (with and without _crosstab suffix)
            conditions = [
                {"variable_name": {"$eq": var}},
                {"variable_name": {"$eq": f"{var}{_CROSSTAB_SUFFIX}"}},
            ]
            # question_id is matched exactly; chunked records whose stored id
            # carries a "_part" suffix are expected to match via variable_name.
            if question_id:
                conditions.append({"question_id": {"$eq": question_id}})
            per_variable.append({"$or": conditions})
        if len(per_variable) == 1:
            return per_variable[0]
        return {"$or": per_variable}

    @staticmethod
    def _resolve_variable(
        metadata: Dict[str, Any],
        var_to_question_id: Dict[str, Optional[str]],
        question_id_to_var: Dict[str, str]
    ) -> Optional[str]:
        """Map a match's metadata back to one of the requested variable names, or None."""
        var_name = metadata.get("variable_name", "")
        question_id = metadata.get("question_id", "")
        base_question_id = question_id

        # Handle question_id format like "VAND10_part1"
        if question_id and "_part" in question_id:
            base_question_id = question_id.split("_part")[0]
            base_var = base_question_id.replace(_CROSSTAB_SUFFIX, "")
            if base_var in var_to_question_id:
                var_name = base_var

        # Strip only the trailing _crosstab suffix (the inverse of the filter).
        if var_name and var_name.endswith(_CROSSTAB_SUFFIX):
            base_var = var_name[: -len(_CROSSTAB_SUFFIX)]
            if base_var in var_to_question_id:
                var_name = base_var

        if var_name and var_name in var_to_question_id:
            return var_name

        # The record matched the filter through question_id only; attribute it
        # to the variable that requested that id instead of dropping it.
        for candidate in (question_id, base_question_id):
            if candidate and candidate in question_id_to_var:
                return question_id_to_var[candidate]
        return None

    def retrieve_parts_for_question_info(
        self,
        question_info_list: List[Dict[str, Any]],
        k: int = PINECONE_RETRIEVE_K,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, List[Document]]:
        """
        Retrieve crosstab chunks for question_info list.
        Groups by namespace (year/month) and filters by variable_name and question_id.

        Questions whose namespace cannot be determined (from their own
        year/month/poll_date or from `filters`) are skipped. Any error is
        swallowed (printed in verbose mode) and yields an empty result.

        Args:
            question_info_list: List of question info dicts with variable_name, year, month, question_id
            k: Number of results to retrieve per variable
            filters: Optional filters with year/month to constrain namespace search

        Returns:
            Dict mapping variable_name to list of Document objects, sorted by
            chunk_index and capped at MAX_CROSSTAB_CHUNKS per variable. Chunks
            for the same variable_name from different namespaces share one list.
        """
        try:
            index = self.pc.Index(self.index_name)
            stats = index.describe_index_stats()
            available_namespaces = set(stats.get('namespaces', {}).keys())

            if not available_namespaces:
                self._log(" ⚠️ No namespaces found in index")
                return {}

            target_namespace = self._resolve_target_namespace(filters, available_namespaces)
            questions_by_namespace = self._group_questions_by_namespace(
                question_info_list, available_namespaces, target_namespace
            )

            # The query is filter-driven; the all-zero vector is only a
            # placeholder of the right size.
            dummy_vector = [0.0] * self._embedding_dimension(stats)

            all_docs_by_variable: Dict[str, List[Document]] = {}

            # Search each namespace
            for namespace, var_to_question_id in questions_by_namespace.items():
                unique_vars = list(var_to_question_id)
                question_id_to_var = {
                    qid: var for var, qid in var_to_question_id.items() if qid
                }

                if self.verbose:
                    print(f" 🔍 Searching namespace: {namespace}")
                    print(f" Looking for variables: {', '.join(sorted(unique_vars))}")
                    if question_id_to_var:
                        print(
                            f" 🔑 Using question_id filter for: {', '.join(sorted(question_id_to_var.values()))}"
                        )

                var_filter = self._build_metadata_filter(var_to_question_id)

                try:
                    result = index.query(
                        vector=dummy_vector,
                        top_k=k * len(unique_vars),
                        namespace=namespace,
                        filter=var_filter,
                        include_metadata=True
                    )

                    self._log(f" 📊 Found {len(result.matches)} matches in {namespace}")

                    for match in result.matches:
                        metadata = match.metadata or {}
                        var_name = self._resolve_variable(
                            metadata, var_to_question_id, question_id_to_var
                        )
                        if var_name is None:
                            continue

                        # Chunk text lives in metadata; move it into page_content.
                        content = metadata.pop('text', '') or metadata.pop('page_content', '') or ''
                        if not content:
                            continue

                        all_docs_by_variable.setdefault(var_name, []).append(
                            Document(page_content=content, metadata=metadata)
                        )
                except Exception as e:
                    # One failing namespace must not discard the others.
                    self._log(f" ⚠️ Error querying namespace {namespace}: {e}")
                    continue

            # Sort documents by chunk_index (chunks without one go last), then cap
            for var_name, docs in all_docs_by_variable.items():
                docs.sort(key=lambda d: d.metadata.get("chunk_index", 999))
                all_docs_by_variable[var_name] = docs[:MAX_CROSSTAB_CHUNKS]

            if self.verbose:
                total_docs = sum(len(docs) for docs in all_docs_by_variable.values())
                print(f" ✅ Retrieved {total_docs} total document(s) for {len(all_docs_by_variable)} variable(s)")

            return all_docs_by_variable

        except Exception as e:
            self._log(f" ❌ Error in retrieve_parts_for_question_info: {e}")
            return {}


class CrosstabsRAG:
    """Crosstabs RAG with question_info-based retrieval."""

    def __init__(
        self,
        questionnaire_rag: QuestionnaireRAG,
        verbose: bool = False
    ):
        """
        Args:
            questionnaire_rag: Used to look up question metadata and, when no
                question_info is supplied, to find questions for a query.
            verbose: When True, progress is printed.

        Connection settings are read from the environment: PINECONE_API_KEY,
        OPENAI_API_KEY, PINECONE_INDEX_NAME_CROSSTABS (default "crosstab-index")
        and OPENAI_EMBED_MODEL (default "text-embedding-3-small").
        """
        self.questionnaire_rag = questionnaire_rag
        self.verbose = verbose

        pinecone_api_key = os.getenv("PINECONE_API_KEY")
        openai_api_key = os.getenv("OPENAI_API_KEY")
        index_name = os.getenv("PINECONE_INDEX_NAME_CROSSTABS", "crosstab-index")
        embed_model = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")

        self.retriever = CrosstabRetriever(
            pinecone_api_key=pinecone_api_key,
            index_name=index_name,
            embed_model=embed_model,
            openai_api_key=openai_api_key,
            verbose=verbose
        )

    def _log(self, message: str) -> None:
        """Print `message` only in verbose mode."""
        if self.verbose:
            print(message)

    def _lookup_source_questions(self, question_info: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Resolve question_info entries to full question dicts via QuestionnaireRAG."""
        questions_by_id = self.questionnaire_rag.questions_by_id
        found: List[Dict[str, Any]] = []
        for q_info in question_info:
            question_id = q_info.get("question_id")
            if question_id and question_id in questions_by_id:
                found.append(questions_by_id[question_id])
                continue

            # Fallback: try to find by variable_name and year/month
            var_name = q_info.get("variable_name")
            if not var_name:
                continue
            year = q_info.get("year")
            month = q_info.get("month", "")
            match = next(
                (
                    q_data for q_data in questions_by_id.values()
                    if q_data.get("variable_name") == var_name
                    and q_data.get("year") == year
                    and q_data.get("month", "") == month
                ),
                None,
            )
            if match is not None:
                found.append(match)
        return found

    @staticmethod
    def _build_result(
        formatted_results: Dict[str, Any],
        source_questions: List[Dict[str, Any]],
        matched_variables: List[str],
        namespaces: set
    ) -> Dict:
        """Assemble the success payload shared by both retrieval paths."""
        return {
            "crosstab_docs_by_variable": formatted_results,
            "matched_questions": source_questions,
            "matched_variables": matched_variables,
            "namespace_used": sorted(namespaces),
            "survey_info": {"poll": "Vanderbilt_Unity_Poll", "year": None, "month": None}
        }

    def retrieve_raw_data(
        self,
        user_query: str,
        question_info: Optional[List[Dict[str, Any]]] = None,
        source_questions: Optional[List[Dict[str, Any]]] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Retrieve raw crosstab data.

        Uses question_info if provided (skips QuestionnaireRAG).
        Otherwise uses QuestionnaireRAG to find questions, then retrieves crosstabs.

        Args:
            user_query: User's query (used for QuestionnaireRAG if question_info not provided)
            question_info: List of question info dicts (preferred - skips QuestionnaireRAG)
            source_questions: Optional list of full question dicts from previous stage (avoids lookup)
            filters: Optional filters; passed to QuestionnaireRAG and, via their
                year/month, used as the fallback namespace for crosstab retrieval

        Returns:
            On success, a dict with crosstab_docs_by_variable, matched_questions,
            matched_variables, namespace_used and survey_info. On failure, a
            dict with a single "error" message.
        """
        if self.verbose:
            print(f"\n📊 [Crosstabs] Query: {user_query}")
            if question_info:
                print(f"🔍 Question info: {len(question_info)} question(s) provided")
            if filters:
                print(f"🔍 Filters: {filters}")

        formatted_results: Dict[str, Any] = {}
        matched_variables: List[str] = []
        all_namespaces: set = set()

        # If question_info provided, skip QuestionnaireRAG
        if question_info:
            self._log("✅ Using provided question_info, skipping QuestionnaireRAG")

            # Retrieve crosstab data directly
            crosstab_docs_by_variable = self.retriever.retrieve_parts_for_question_info(
                question_info_list=question_info,
                k=PINECONE_RETRIEVE_K,
                filters=filters
            )
            if not crosstab_docs_by_variable:
                return {"error": f"No crosstab data found for {len(question_info)} question(s)."}

            # Get question metadata - use provided source_questions if available, otherwise lookup
            if not source_questions:
                source_questions = self._lookup_source_questions(question_info)

            for var_name, docs in crosstab_docs_by_variable.items():
                question_metadata = next(
                    (q for q in source_questions if q.get("variable_name") == var_name),
                    {}
                )
                first_doc_meta = docs[0].metadata if docs else {}
                survey_name = first_doc_meta.get("survey_name", "")
                if survey_name:
                    all_namespaces.add(survey_name)

                formatted_results[var_name] = {
                    "crosstab_docs": docs,
                    # Prefer the questionnaire's text; fall back to the chunk's own.
                    "question_text": (
                        question_metadata.get("question_text", "")
                        or first_doc_meta.get("question_text", "")
                    ),
                    "matched_question": question_metadata
                }
                matched_variables.append(var_name)

            return self._build_result(
                formatted_results, source_questions, matched_variables, all_namespaces
            )

        # Otherwise, use QuestionnaireRAG to find questions first
        self._log("🔍 Using QuestionnaireRAG to find questions")
        try:
            q_result = self.questionnaire_rag.retrieve_raw_data(
                question=user_query,
                filters=filters or {},
                k=10
            )
        except Exception as e:
            return {"error": f"Error querying questionnaire: {e}"}

        source_questions = q_result.get("source_questions", [])
        question_info_from_questions = q_result.get("question_info", [])
        if not source_questions:
            return {"error": "No matching questions found in questionnaire for that query."}

        self._log(f"✅ Found {len(source_questions)} question(s) from QuestionnaireRAG")

        # Retrieve crosstab data using question_info; filters supply the
        # fallback namespace exactly as in the question_info path above.
        crosstab_docs_by_variable = self.retriever.retrieve_parts_for_question_info(
            question_info_list=question_info_from_questions,
            k=PINECONE_RETRIEVE_K,
            filters=filters
        )
        if not crosstab_docs_by_variable:
            return {"error": f"No crosstab data found for any of the {len(source_questions)} matched questions."}

        for matched_question in source_questions:
            # .get: a question dict lacking these keys must not abort the request.
            variable_name = matched_question.get("variable_name")
            if not variable_name or variable_name not in crosstab_docs_by_variable:
                continue

            docs = crosstab_docs_by_variable[variable_name]
            formatted_results[variable_name] = {
                "crosstab_docs": docs,
                "question_text": matched_question.get("question_text", ""),
                "matched_question": matched_question
            }
            matched_variables.append(variable_name)

            if docs:
                survey_name = docs[0].metadata.get("survey_name", "")
                if survey_name:
                    all_namespaces.add(survey_name)

        return self._build_result(
            formatted_results, source_questions, matched_variables, all_namespaces
        )