Spaces:
Running
Running
"""
Crosstab RAG Module
-------------------
Retrieves crosstab demographic breakdown data from Pinecone vectorstore.
Uses question_info for precise namespace matching and metadata filtering.
Retrieval returns raw data only - no synthesis. ``CrosstabSummarizer`` is a
separate LLM summarization step that callers may apply to retrieved documents.
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone

# Populate os.environ from a local .env file before any os.getenv() lookups below.
load_dotenv()

# Import QuestionnaireRAG to reuse question matching when needed
try:
    from questionnaire_rag import QuestionnaireRAG
except ImportError:
    # Handle case where running as module
    from .questionnaire_rag import QuestionnaireRAG

# Default ``k`` for CrosstabRetriever.retrieve_parts_for_question_info.
PINECONE_RETRIEVE_K = 100
# Maximum number of chunks kept per variable after sorting by chunk_index.
MAX_CROSSTAB_CHUNKS = 50
class CrosstabSummarizer:
    """Summarizes crosstab data with a chat LLM to reduce token usage."""

    def __init__(self, llm_model: Optional[str] = None, openai_api_key: Optional[str] = None):
        """Create the chat model used for summarization.

        Args:
            llm_model: Chat model name. Falls back to the ``OPENAI_MODEL``
                environment variable, then to ``"gpt-4o"``.
            openai_api_key: API key. Falls back to ``OPENAI_API_KEY``.
        """
        # Function-scope import: ChatOpenAI is only loaded when a summarizer
        # is actually constructed.
        from langchain_openai import ChatOpenAI

        llm_model = llm_model or os.getenv("OPENAI_MODEL", "gpt-4o")
        openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self.llm = ChatOpenAI(model=llm_model, openai_api_key=openai_api_key, temperature=0.0)

    @staticmethod
    def _unpack_doc(doc: Any) -> tuple:
        """Return ``(metadata, content)`` for a Document-like object or a dict."""
        if hasattr(doc, "metadata"):
            metadata = doc.metadata
            content = getattr(doc, "page_content", "")
        elif isinstance(doc, dict):
            # Dicts show up when documents were round-tripped through
            # checkpoint (de)serialization.
            metadata = doc.get("metadata")
            content = doc.get("page_content")
        else:
            metadata, content = None, None
        # A stored ``None`` (or any non-dict) must not break the .get() calls
        # made on the metadata by the caller.
        if not isinstance(metadata, dict):
            metadata = {}
        return metadata, content or ""

    @staticmethod
    def _read_prompt(path: Path, default: str) -> str:
        """Read a prompt file as UTF-8, or return ``default`` if it is not a file."""
        return path.read_text(encoding="utf-8") if path.is_file() else default

    def summarize(
        self,
        user_query: str,
        retrieved_docs: List[Any],
        question_text: Optional[str] = None,
        top_n_sources: int = 6
    ) -> Dict:
        """Summarize crosstab data, extracting relevant demographic breakdowns.

        Args:
            user_query: The user's question; inserted into the user prompt.
            retrieved_docs: Retrieved chunks. Each item is either an object
                with ``metadata`` / ``page_content`` attributes (e.g. a
                LangChain ``Document``) or a dict with those two keys. Items
                of any other type contribute an empty part.
            question_text: Text of the retrieved survey question. When given,
                it is added to the prompt together with a relevance reminder.
            top_n_sources: Maximum number of source identifiers returned.

        Returns:
            ``{"answer": str, "sources": list}``. If the LLM call raises, the
            answer is an ``"Error generating summary: ..."`` message rather
            than an exception.
        """
        if not retrieved_docs:
            return {"answer": "No relevant crosstab data found for that query.", "sources": []}

        context_parts, sources = [], []
        for position, doc in enumerate(retrieved_docs, start=1):
            metadata, content = self._unpack_doc(doc)
            id_hint = metadata.get("question_id") or metadata.get("variable_name") or f"part_{position}"
            context_parts.append(f"--- Part {position} | {id_hint} ---\n{content}")
            sources.append(id_hint)
        context_text = "\n\n".join(context_parts)

        # Load prompts from the "prompts" directory next to this module.
        prompt_dir = Path(__file__).parent / "prompts"
        system_prompt = self._read_prompt(prompt_dir / "crosstab_rag_prompt_system.txt", "")
        user_prompt_template = self._read_prompt(
            prompt_dir / "crosstab_rag_prompt_user.txt",
            "{user_query}\n\n{context_text}",
        )

        question_context = f"\n\nSURVEY QUESTION THAT WAS RETRIEVED: {question_text}" if question_text else ""
        # NOTE(review): the glyphs before "RELEVANCE" look mis-encoded
        # (mojibake); confirm the intended characters against the upstream file.
        relevance_check = (
            "\n\nโ ๏ธ RELEVANCE: The retrieved question IS relevant to the user's query. "
            "Remember: ALL subtopics, specific examples, and related aspects ARE relevant:\n"
            "- 'personal financial situation' IS about economy\n"
            "- 'tariffs' IS about economy\n"
            "- 'stock market' IS about economy\n"
            "- 'gender-affirming healthcare' IS about healthcare\n"
            "- 'Biden approval' IS about presidential approval\n"
            "Only flag as irrelevant if about a COMPLETELY UNRELATED topic (e.g., user asked 'economy' but question is about 'sports teams'). "
            "When in doubt, ANALYZE THE DATA - do not reject it."
        ) if question_text else ""

        # Retrieved content is passed as a format *argument*, so braces inside
        # the crosstab text are safe; only the template itself is interpreted.
        user_prompt = user_prompt_template.format(
            user_query=user_query,
            question_context=question_context,
            relevance_check=relevance_check,
            context_text=context_text
        )

        from langchain.schema import HumanMessage, SystemMessage
        messages = [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]

        try:
            result = self.llm.invoke(messages)
            answer = result.content if hasattr(result, 'content') else str(result)
        except Exception as e:
            # Best-effort: surface the failure in the answer instead of raising.
            answer = f"Error generating summary: {e}"

        # ``content`` is not guaranteed to be a plain string; coerce before strip().
        if not isinstance(answer, str):
            answer = str(answer)
        return {"answer": answer.strip(), "sources": sources[:top_n_sources]}
class CrosstabRetriever:
    """Retrieves crosstab chunks from Pinecone using metadata filtering."""

    # Largest ``top_k`` accepted by a single Pinecone query.
    _MAX_TOP_K = 10_000
    # Suffix that may be appended to variable names in stored metadata.
    _CROSSTAB_SUFFIX = "_crosstab"
    # Sort position for chunks whose ``chunk_index`` is missing or non-numeric.
    _DEFAULT_CHUNK_ORDER = 999

    def __init__(
        self,
        pinecone_api_key: str,
        index_name: str,
        embed_model: str,
        openai_api_key: str,
        verbose: bool = False
    ):
        """Store the Pinecone client, index name and embedder.

        Args:
            pinecone_api_key: API key passed to ``Pinecone``.
            index_name: Name of the index queried by the retrieval method.
            embed_model: Embedding model name passed to ``OpenAIEmbeddings``.
            openai_api_key: API key passed to ``OpenAIEmbeddings``.
            verbose: When true, progress and error messages are printed.
        """
        self.pc = Pinecone(api_key=pinecone_api_key)
        self.index_name = index_name
        self.embedder = OpenAIEmbeddings(model=embed_model, openai_api_key=openai_api_key)
        self.verbose = verbose

    @staticmethod
    def _namespace_for(year: Any, month: Any) -> str:
        """Format the crosstab namespace name for a poll year and month."""
        return f"Vanderbilt_Unity_Poll_{year}_{month}_cleaned_data_crosstabs".replace(" ", "_")

    def _build_namespace_from_question_info(self, question_info: Dict[str, Any]) -> Optional[str]:
        """Build namespace from question_info (year + month).

        Uses ``year`` and ``month`` when both are set. Otherwise falls back to
        ``poll_date``, accepted as ``"<year>-<month>"`` (month as a name such
        as ``"June"`` or as a number such as ``"06"``) or as ``"%Y-%m-%d"``.

        Returns:
            The namespace string, or ``None`` if it cannot be determined.
        """
        year = question_info.get("year")
        month = question_info.get("month", "")
        if year and month:
            return self._namespace_for(year, month)

        # Try to extract from poll_date
        poll_date = question_info.get("poll_date", "")
        if not poll_date:
            return None
        try:
            from datetime import datetime

            date_text = str(poll_date).strip()
            parts = date_text.split("-")
            if len(parts) == 2:
                # Handle format like "2025-June"
                year_str, month_str = parts[0].strip(), parts[1].strip()
                if month_str.isdigit():
                    # "2025-06" -> "June", matching the month-name form used
                    # by every other path that builds a namespace.
                    month_str = datetime.strptime(month_str, "%m").strftime("%B")
                return self._namespace_for(year_str, month_str)
            date_obj = datetime.strptime(date_text, "%Y-%m-%d")
            return self._namespace_for(date_obj.year, date_obj.strftime("%B"))
        except (TypeError, ValueError) as e:
            if self.verbose:
                print(f" โ ๏ธ Failed to parse poll_date '{poll_date}': {e}")
            return None

    def _embedding_dimension(self) -> int:
        """Guess the embedding size from the embedder's model name.

        Names containing "large" (and not "small") map to 3072; everything
        else maps to 1536.
        """
        model_name = str(getattr(self.embedder, "model", "")).lower()
        if "small" not in model_name and "large" in model_name:
            return 3072
        return 1536

    def _query_dimension(self, stats: Any) -> int:
        """Return the vector size for queries.

        Prefers the ``dimension`` reported in the index stats and falls back
        to the model-name heuristic when that value is unavailable.
        """
        try:
            dimension = int(stats.get("dimension") or 0)
        except (AttributeError, KeyError, TypeError, ValueError):
            dimension = 0
        return dimension if dimension > 0 else self._embedding_dimension()

    def _namespace_from_filters(self, filters: Optional[Dict[str, Any]], available: set) -> Optional[str]:
        """Return the namespace named by ``filters`` (year + month) if it exists in the index."""
        if not filters:
            return None
        year = filters.get("year")
        month = filters.get("month", "")
        if not (year and month):
            return None
        target_namespace = self._namespace_for(year, month)
        if target_namespace in available:
            return target_namespace
        if self.verbose:
            print(f" โ ๏ธ Target namespace {target_namespace} not found in available namespaces")
        return None

    def _group_by_namespace(
        self,
        question_info_list: List[Dict[str, Any]],
        available: set,
        fallback_namespace: Optional[str]
    ) -> Dict[str, Dict[str, Any]]:
        """Group requested variables by the namespace they should be searched in.

        Returns:
            ``{namespace: {variable_name: question_id_or_None}}``. The
            question_id is tracked per namespace so that the same variable
            name in two polls does not share one question_id.
        """
        groups: Dict[str, Dict[str, Any]] = {}
        for q_info in question_info_list or []:
            var_name = q_info.get("variable_name")
            if not var_name:
                continue

            # Try to build namespace from question_info first
            namespace = self._build_namespace_from_question_info(q_info)
            if not (namespace and namespace in available):
                # Use target namespace from filters
                namespace = fallback_namespace
            if not namespace:
                # Skip this question rather than searching all namespaces;
                # this prevents broad searches when question_info is provided.
                if self.verbose:
                    print(f" โ ๏ธ Could not determine namespace for {var_name} (year={q_info.get('year')}, month={q_info.get('month')})")
                continue

            bucket = groups.setdefault(namespace, {})
            question_id = q_info.get("question_id")
            if question_id:
                bucket[var_name] = question_id
            else:
                bucket.setdefault(var_name, None)
        return groups

    @classmethod
    def _build_metadata_filter(cls, var_to_qid: Dict[str, Any]) -> Dict[str, Any]:
        """Build the Pinecone metadata filter for one namespace.

        Each variable matches on ``variable_name`` (with and without the
        ``_crosstab`` suffix) or on ``question_id`` when one is known; the
        per-variable clauses are combined with ``$or``.
        """
        per_variable = []
        for var, question_id in var_to_qid.items():
            conditions = [
                {"variable_name": {"$eq": var}},
                {"variable_name": {"$eq": f"{var}{cls._CROSSTAB_SUFFIX}"}},
            ]
            if question_id:
                conditions.append({"question_id": {"$eq": question_id}})
            per_variable.append({"$or": conditions})
        if len(per_variable) == 1:
            return per_variable[0]
        return {"$or": per_variable}

    @classmethod
    def _resolve_variable(cls, metadata: Dict[str, Any], wanted: Dict[str, Any]) -> Optional[str]:
        """Map a match's metadata to one of the requested variable names.

        Returns ``None`` when the match does not belong to a requested variable.
        """
        var_name = metadata.get("variable_name", "")
        question_id = metadata.get("question_id", "")

        # Handle question_id format like "VAND10_part1"
        if isinstance(question_id, str) and "_part" in question_id:
            base_var = question_id.split("_part")[0].replace(cls._CROSSTAB_SUFFIX, "")
            if base_var in wanted:
                var_name = base_var

        # Check if variable_name has _crosstab suffix (strip the suffix only,
        # not interior occurrences).
        if isinstance(var_name, str) and var_name.endswith(cls._CROSSTAB_SUFFIX):
            base_var = var_name[: -len(cls._CROSSTAB_SUFFIX)]
            if base_var in wanted:
                var_name = base_var

        # NOTE(review): a match selected only through the question_id clause
        # of the filter is still dropped here unless its variable resolves to
        # a requested name - confirm that this is intended.
        if not var_name or var_name not in wanted:
            return None
        return var_name

    @classmethod
    def _chunk_order(cls, doc: Any) -> float:
        """Sort key: numeric ``chunk_index``, or the default position if unusable."""
        try:
            return float(doc.metadata.get("chunk_index", cls._DEFAULT_CHUNK_ORDER))
        except (TypeError, ValueError):
            return float(cls._DEFAULT_CHUNK_ORDER)

    def _query_namespace(
        self,
        index: Any,
        namespace: str,
        var_to_qid: Dict[str, Any],
        query_vector: List[float],
        k: int,
        docs_by_variable: Dict[str, List[Document]]
    ) -> None:
        """Query one namespace and append matching documents to ``docs_by_variable``."""
        if self.verbose:
            print(f" ๐ Searching namespace: {namespace}")
            print(f" Looking for variables: {', '.join(sorted(var_to_qid))}")
            with_question_id = sorted(v for v, qid in var_to_qid.items() if qid)
            if with_question_id:
                print(f" ๐ Using question_id filter for: {', '.join(with_question_id)}")

        try:
            result = index.query(
                vector=query_vector,
                # One query covers every variable of the namespace, so the
                # budget scales with their count, capped at the service limit.
                top_k=min(k * len(var_to_qid), self._MAX_TOP_K),
                namespace=namespace,
                filter=self._build_metadata_filter(var_to_qid),
                include_metadata=True
            )
            if self.verbose:
                print(f" ๐ Found {len(result.matches)} matches in {namespace}")

            for match in result.matches:
                # Copy so that popping the text below does not mutate the match.
                metadata = dict(match.metadata or {})
                var_name = self._resolve_variable(metadata, var_to_qid)
                if var_name is None:
                    continue
                content = metadata.pop('text', '') or metadata.pop('page_content', '') or ''
                if not content:
                    continue
                docs_by_variable.setdefault(var_name, []).append(
                    Document(page_content=content, metadata=metadata)
                )
        except Exception as e:
            # Best-effort per namespace: report and move on to the next one.
            if self.verbose:
                print(f" โ ๏ธ Error querying namespace {namespace}: {e}")

    def retrieve_parts_for_question_info(
        self,
        question_info_list: List[Dict[str, Any]],
        k: int = PINECONE_RETRIEVE_K,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, List[Document]]:
        """
        Retrieve crosstab chunks for question_info list.

        Groups by namespace (year/month) and filters by variable_name and question_id.
        The query vector is all zeros: selection is done by the metadata
        filter, not by similarity.

        Args:
            question_info_list: List of question info dicts with variable_name, year, month, question_id
            k: Result budget per variable; each namespace is queried once with
               ``top_k = k * number_of_variables`` (capped at the Pinecone maximum)
            filters: Optional filters with year/month; used as the namespace for
               questions whose own year/month do not resolve to an existing namespace

        Returns:
            Dict mapping variable_name to list of Document objects, sorted by
            ``chunk_index`` and truncated to ``MAX_CROSSTAB_CHUNKS`` per variable.
            An empty dict is returned when nothing is found or when an
            unexpected error occurs (errors are printed only in verbose mode).
        """
        try:
            index = self.pc.Index(self.index_name)
            stats = index.describe_index_stats()
            available_namespaces = set((stats.get('namespaces', {}) or {}).keys())
            if not available_namespaces:
                if self.verbose:
                    print(" โ ๏ธ No namespaces found in index")
                return {}

            # Build target namespace from filters if provided
            target_namespace = self._namespace_from_filters(filters, available_namespaces)

            # Group questions by namespace
            questions_by_namespace = self._group_by_namespace(
                question_info_list, available_namespaces, target_namespace
            )

            dummy_vector = [0.0] * self._query_dimension(stats)
            all_docs_by_variable: Dict[str, List[Document]] = {}

            # Search each namespace
            for namespace, var_to_qid in questions_by_namespace.items():
                self._query_namespace(index, namespace, var_to_qid, dummy_vector, k, all_docs_by_variable)

            # Sort documents by chunk_index
            for var_name, docs in all_docs_by_variable.items():
                docs.sort(key=self._chunk_order)
                all_docs_by_variable[var_name] = docs[:MAX_CROSSTAB_CHUNKS]

            if self.verbose:
                total_docs = sum(len(docs) for docs in all_docs_by_variable.values())
                print(f" โ Retrieved {total_docs} total document(s) for {len(all_docs_by_variable)} variable(s)")
            return all_docs_by_variable
        except Exception as e:
            if self.verbose:
                print(f" โ Error in retrieve_parts_for_question_info: {e}")
            return {}
class CrosstabsRAG:
    """Crosstabs RAG with question_info-based retrieval."""

    def __init__(
        self,
        questionnaire_rag: QuestionnaireRAG,
        verbose: bool = False
    ):
        """Build the Pinecone-backed retriever from environment configuration.

        Reads ``PINECONE_API_KEY``, ``OPENAI_API_KEY``,
        ``PINECONE_INDEX_NAME_CROSSTABS`` (default ``"crosstab-index"``) and
        ``OPENAI_EMBED_MODEL`` (default ``"text-embedding-3-small"``).

        Args:
            questionnaire_rag: Used to find questions for a free-text query and
                to look up question metadata via ``questions_by_id``.
            verbose: When true, progress messages are printed.
        """
        self.questionnaire_rag = questionnaire_rag
        self.verbose = verbose

        pinecone_api_key = os.getenv("PINECONE_API_KEY")
        openai_api_key = os.getenv("OPENAI_API_KEY")
        index_name = os.getenv("PINECONE_INDEX_NAME_CROSSTABS", "crosstab-index")
        embed_model = os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")

        self.retriever = CrosstabRetriever(
            pinecone_api_key=pinecone_api_key,
            index_name=index_name,
            embed_model=embed_model,
            openai_api_key=openai_api_key,
            verbose=verbose
        )

    def _lookup_source_questions(self, question_info: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Resolve question_info entries to full question dicts.

        Looks each entry up by ``question_id`` first; otherwise takes the first
        question with the same ``variable_name``, ``year`` and ``month``.
        Entries that cannot be resolved are skipped.
        """
        questions_by_id = self.questionnaire_rag.questions_by_id
        resolved = []
        for q_info in question_info:
            question_id = q_info.get("question_id")
            if question_id and question_id in questions_by_id:
                resolved.append(questions_by_id[question_id])
                continue

            # Fallback: try to find by variable_name and year/month
            var_name = q_info.get("variable_name")
            if not var_name:
                continue
            year = q_info.get("year")
            month = q_info.get("month", "")
            candidate = next(
                (
                    q_data for q_data in questions_by_id.values()
                    if q_data.get("variable_name") == var_name
                    and q_data.get("year") == year
                    and q_data.get("month", "") == month
                ),
                None,
            )
            if candidate is not None:
                resolved.append(candidate)
        return resolved

    @staticmethod
    def _survey_name(docs: List[Document]) -> str:
        """Return ``survey_name`` from the first document's metadata, or ""."""
        if not docs:
            return ""
        return docs[0].metadata.get("survey_name", "") or ""

    @staticmethod
    def _build_response(
        formatted_results: Dict[str, Any],
        source_questions: List[Dict[str, Any]],
        matched_variables: List[str],
        survey_names: set
    ) -> Dict:
        """Assemble the success payload returned by ``retrieve_raw_data``."""
        return {
            "crosstab_docs_by_variable": formatted_results,
            "matched_questions": source_questions,
            "matched_variables": matched_variables,
            # Sorted for a deterministic order.
            "namespace_used": sorted(survey_names),
            "survey_info": {"poll": "Vanderbilt_Unity_Poll", "year": None, "month": None}
        }

    def retrieve_raw_data(
        self,
        user_query: str,
        question_info: Optional[List[Dict[str, Any]]] = None,
        source_questions: Optional[List[Dict[str, Any]]] = None,
        filters: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Retrieve raw crosstab data.

        Uses question_info if provided (skips QuestionnaireRAG).
        Otherwise uses QuestionnaireRAG to find questions, then retrieves crosstabs.
        There is no further fallback: if the metadata-filtered retrieval finds
        nothing, an error dict is returned.

        Args:
            user_query: User's query (used for QuestionnaireRAG if question_info not provided)
            question_info: List of question info dicts (preferred - skips QuestionnaireRAG)
            source_questions: Optional list of full question dicts from previous stage (avoids lookup)
            filters: Optional filters; passed to QuestionnaireRAG and to the retriever
                (where year/month select the namespace for questions that lack their own)

        Returns:
            On success, a dict with crosstab_docs_by_variable, matched_questions,
            matched_variables, namespace_used, survey_info.
            On failure, ``{"error": <message>}``.
        """
        if self.verbose:
            print(f"\n๐ [Crosstabs] Query: {user_query}")
            if question_info:
                print(f"๐ Question info: {len(question_info)} question(s) provided")
            if filters:
                print(f"๐ Filters: {filters}")

        formatted_results = {}
        matched_variables = []
        survey_names = set()

        # If question_info provided, skip QuestionnaireRAG
        if question_info:
            if self.verbose:
                print("โ Using provided question_info, skipping QuestionnaireRAG")

            # Retrieve crosstab data directly
            crosstab_docs_by_variable = self.retriever.retrieve_parts_for_question_info(
                question_info_list=question_info,
                k=PINECONE_RETRIEVE_K,
                filters=filters
            )
            if not crosstab_docs_by_variable:
                return {"error": f"No crosstab data found for {len(question_info)} question(s)."}

            # Get question metadata - use provided source_questions if available, otherwise lookup
            if not source_questions:
                source_questions = self._lookup_source_questions(question_info)

            # Format results: one entry per variable that has documents.
            for var_name, docs in crosstab_docs_by_variable.items():
                question_metadata = next(
                    (q for q in source_questions if q.get("variable_name") == var_name),
                    {}
                )
                question_text = question_metadata.get("question_text", "")
                if not question_text and docs:
                    # Fall back to the text stored with the crosstab chunk.
                    question_text = docs[0].metadata.get("question_text", "")

                survey_name = self._survey_name(docs)
                if survey_name:
                    survey_names.add(survey_name)

                formatted_results[var_name] = {
                    "crosstab_docs": docs,
                    "question_text": question_text,
                    "matched_question": question_metadata
                }
                matched_variables.append(var_name)

            return self._build_response(formatted_results, source_questions, matched_variables, survey_names)

        # Otherwise, use QuestionnaireRAG to find questions first
        if self.verbose:
            print("๐ Using QuestionnaireRAG to find questions")
        try:
            q_result = self.questionnaire_rag.retrieve_raw_data(
                question=user_query,
                filters=filters or {},
                k=10
            )
        except Exception as e:
            return {"error": f"Error querying questionnaire: {e}"}

        source_questions = q_result.get("source_questions", [])
        question_info_from_questions = q_result.get("question_info", [])
        if not source_questions:
            return {"error": "No matching questions found in questionnaire for that query."}
        if self.verbose:
            print(f"โ Found {len(source_questions)} question(s) from QuestionnaireRAG")

        # Retrieve crosstab data using question_info; filters are forwarded so
        # both branches resolve namespaces the same way.
        crosstab_docs_by_variable = self.retriever.retrieve_parts_for_question_info(
            question_info_list=question_info_from_questions,
            k=PINECONE_RETRIEVE_K,
            filters=filters
        )
        if not crosstab_docs_by_variable:
            return {"error": f"No crosstab data found for any of the {len(source_questions)} matched questions."}

        # Format results: one entry per matched question that has documents.
        for matched_question in source_questions:
            variable_name = matched_question.get("variable_name")
            if not variable_name or variable_name not in crosstab_docs_by_variable:
                continue
            docs = crosstab_docs_by_variable[variable_name]
            formatted_results[variable_name] = {
                "crosstab_docs": docs,
                "question_text": matched_question.get("question_text", ""),
                "matched_question": matched_question
            }
            matched_variables.append(variable_name)
            survey_name = self._survey_name(docs)
            if survey_name:
                survey_names.add(survey_name)

        return self._build_response(formatted_results, source_questions, matched_variables, survey_names)