| """Core RAG system implementation""" | |
| import os | |
| import glob | |
| from typing import List, Tuple, Optional | |
| import PyPDF2 | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import InferenceClient | |
| import spaces | |


class RAGSystem:
    def __init__(self):
        self.chunks = []
        self.embeddings = None
        self.index = None
        self.embedding_model = None
        self.embedding_model_name = None
        self.llm_client = None
        self.llm_model_name = None
        self.ready = False

    def is_ready(self) -> bool:
        """Check if the system is ready to process queries"""
        return self.ready and self.index is not None

    def load_default_corpus(self, chunk_size: int = 500, chunk_overlap: int = 50):
        """Load the default corpus from documents folder"""
        documents_dir = "documents"
        if not os.path.exists(documents_dir):
            return "Documents folder not found. Please upload a PDF.", "", ""

        # Get all PDFs in documents folder
        pdf_files = glob.glob(os.path.join(documents_dir, "*.pdf"))
        if not pdf_files:
            return "No PDF files found in documents folder. Please upload a PDF.", "", ""

        try:
            # Extract text from all PDFs
            all_text = ""
            corpus_summary = f"📚 **Loading {len(pdf_files)} documents:**\n\n"
            for pdf_path in pdf_files:
                filename = os.path.basename(pdf_path)
                corpus_summary += f"- {filename}\n"
                text = self.extract_text_from_pdf(pdf_path)
                all_text += f"\n\n=== {filename} ===\n\n{text}"
            corpus_summary += f"\n**Total text length:** {len(all_text)} characters\n"

            # Chunk the combined text
            self.chunks = self.chunk_text(all_text, chunk_size, chunk_overlap)
            if not self.chunks:
                return "Error: No valid chunks created from the documents.", "", ""

            # Create embeddings
            self.embeddings = self.create_embeddings(self.chunks)

            # Build index
            self.build_index(self.embeddings)
            self.ready = True

            # Format chunks for display
            chunks_display = "### Processed Chunks\n\n"
            for i, chunk in enumerate(self.chunks, 1):
                chunks_display += f"**Chunk {i}** ({len(chunk)} chars)\n```\n{chunk[:200]}{'...' if len(chunk) > 200 else ''}\n```\n\n"

            status = f"✅ Success! Processed {len(pdf_files)} documents into {len(self.chunks)} chunks."
            return status, chunks_display, corpus_summary
        except Exception as e:
            self.ready = False
            return f"Error loading default corpus: {str(e)}", "", ""

    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from PDF file"""
        text = ""
        with open(pdf_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            for page in pdf_reader.pages:
                text += page.extract_text() + "\n"
        return text

    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks"""
        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = start + chunk_size
            chunk = text[start:end]

            # Try to break at sentence boundary
            if end < text_length:
                # Look for sentence endings
                last_period = chunk.rfind('.')
                last_newline = chunk.rfind('\n')
                break_point = max(last_period, last_newline)
                if break_point > chunk_size * 0.5:  # Only break if we're past halfway
                    chunk = chunk[:break_point + 1]
                    end = start + break_point + 1

            chunks.append(chunk.strip())
            start = end - overlap

        return [c for c in chunks if len(c) > 50]  # Filter out very small chunks

    def create_embeddings(self, texts: List[str]) -> np.ndarray:
        """Create embeddings for text chunks"""
        if self.embedding_model is None:
            self.set_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
        embeddings = self.embedding_model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True
        )
        return embeddings

    def build_index(self, embeddings: np.ndarray):
        """Build FAISS index from embeddings"""
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity
        # Normalize embeddings for cosine similarity
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

    def process_document(self, pdf_path: str, chunk_size: int = 500, chunk_overlap: int = 50):
        """Process a PDF document and create searchable index"""
        try:
            # Extract text
            text = self.extract_text_from_pdf(pdf_path)
            if not text.strip():
                return "Error: No text could be extracted from the PDF.", "", ""

            # Chunk text
            self.chunks = self.chunk_text(text, chunk_size, chunk_overlap)
            if not self.chunks:
                return "Error: No valid chunks created from the document.", "", ""

            # Create embeddings
            self.embeddings = self.create_embeddings(self.chunks)

            # Build index
            self.build_index(self.embeddings)
            self.ready = True

            # Format chunks for display
            chunks_display = "### Processed Chunks\n\n"
            for i, chunk in enumerate(self.chunks, 1):
                chunks_display += f"**Chunk {i}** ({len(chunk)} chars)\n```\n{chunk}\n```\n\n"

            status = f"✅ Success! Processed {len(self.chunks)} chunks from the document."
            return status, chunks_display, text[:5000]  # Return first 5000 chars of original text
        except Exception as e:
            self.ready = False
            return f"Error processing document: {str(e)}", "", ""

    def set_embedding_model(self, model_name: str):
        """Set or change the embedding model"""
        if self.embedding_model_name != model_name:
            self.embedding_model_name = model_name
            self.embedding_model = SentenceTransformer(model_name)
            # If we have chunks, re-create embeddings and index
            if self.chunks:
                self.embeddings = self.create_embeddings(self.chunks)
                self.build_index(self.embeddings)

    def set_llm_model(self, model_name: str):
        """Set or change the LLM model"""
        if self.llm_model_name != model_name:
            self.llm_model_name = model_name
            # Use HF_TOKEN from environment if available
            hf_token = os.environ.get("HF_TOKEN", None)
            self.llm_client = InferenceClient(model_name, token=hf_token)

    def retrieve(
        self,
        query: str,
        top_k: int = 3,
        similarity_threshold: float = 0.0
    ) -> List[Tuple[str, float]]:
        """Retrieve relevant chunks for a query"""
        if not self.is_ready():
            return []

        # Encode query
        query_embedding = self.embedding_model.encode(
            [query],
            convert_to_numpy=True
        )

        # Normalize for cosine similarity
        faiss.normalize_L2(query_embedding)

        # Search
        scores, indices = self.index.search(query_embedding, top_k)

        # Filter by threshold and return results; FAISS pads missing hits with -1
        # when top_k exceeds the number of indexed chunks, so skip those slots.
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx >= 0 and score >= similarity_threshold:
                results.append((self.chunks[idx], float(score)))
        return results

    def generate(
        self,
        query: str,
        retrieved_chunks: List[Tuple[str, float]],
        temperature: float = 0.7,
        max_tokens: int = 300
    ) -> Tuple[str, str]:
        """Generate answer using LLM"""
        if self.llm_client is None:
            self.set_llm_model("meta-llama/Llama-3.2-1B-Instruct")

        # Build context from retrieved chunks
        context = "\n\n".join([chunk for chunk, _ in retrieved_chunks])

        # Create prompt
        prompt = f"""Use the following context to answer the question. If you cannot answer based on the context, say so.

Context:
{context}

Question: {query}

Answer:"""

        # Generate response using chat completion
        try:
            messages = [
                {
                    "role": "user",
                    "content": prompt
                }
            ]
            response = self.llm_client.chat_completion(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            # Extract answer from response
            if hasattr(response, 'choices') and len(response.choices) > 0:
                answer = response.choices[0].message.content.strip()
            elif isinstance(response, dict) and 'choices' in response:
                answer = response['choices'][0]['message']['content'].strip()
            else:
                answer = str(response).strip()

            return answer, prompt
        except Exception as e:
            return f"Error generating response: {str(e)}", prompt