Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Depends, HTTPException, UploadFile, File | |
| import pandas as pd | |
| import lancedb | |
| from functools import cached_property, lru_cache | |
| from pydantic import Field, BaseModel | |
| from typing import Optional, Dict, List, Annotated, Any | |
| from fastapi import APIRouter | |
| import uuid | |
| import io | |
| from io import BytesIO | |
| import csv | |
| import sqlite3 | |
| # LlamaIndex imports | |
| from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex | |
| from llama_index.vector_stores.lancedb import LanceDBVectorStore | |
| from llama_index.embeddings.fastembed import FastEmbedEmbedding | |
| from llama_index.core.schema import TextNode | |
| from llama_index.core import StorageContext, load_index_from_storage | |
| import json | |
| import os | |
| import shutil | |
# Router for this module: paths are prefixed with /rag and tagged "rag".
router = APIRouter(prefix="/rag", tags=["rag"])

# Configure global LlamaIndex settings: embeddings come from the FastEmbed
# "BAAI/bge-small-en-v1.5" model.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Database connection dependency
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Open a LanceDB connection.

    Args:
        db_path: Location passed to ``lancedb.connect``; defaults to
            "./lancedb/dev".

    Returns:
        The object returned by ``lancedb.connect(db_path)``.
    """
    connection = lancedb.connect(db_path)
    return connection
def get_db():
    """Open the SQLite database that stores table and file metadata.

    The parent directory is created first: ``sqlite3.connect`` creates a
    missing database file but not a missing directory, so a fresh checkout
    without ``./data`` would otherwise fail with "unable to open database file".

    Returns:
        sqlite3.Connection: connection to ``./data/tablesv2.db`` whose rows
        are ``sqlite3.Row`` objects (addressable by column name). The caller
        is responsible for closing it.
    """
    db_file = './data/tablesv2.db'
    os.makedirs(os.path.dirname(db_file), exist_ok=True)
    conn = sqlite3.connect(db_file)
    conn.row_factory = sqlite3.Row
    return conn
def init_db():
    """Create the ``tables`` and ``table_files`` metadata tables if missing.

    Both statements use ``CREATE TABLE IF NOT EXISTS``, so calling this
    repeatedly is harmless. The connection is closed before returning, even
    when a statement fails.
    """
    db = get_db()
    try:
        # One row per (user, knowledge-base table).
        db.execute('''
            CREATE TABLE IF NOT EXISTS tables (
                id INTEGER PRIMARY KEY,
                user_id TEXT NOT NULL,
                table_id TEXT NOT NULL,
                table_name TEXT NOT NULL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        # One row per uploaded file; (table_id, filename) is unique so a
        # re-upload can replace the earlier row.
        # NOTE(review): tables.table_id is not declared UNIQUE, so SQLite
        # would report a foreign-key mismatch if PRAGMA foreign_keys were
        # ever switched on for this connection.
        db.execute('''
            CREATE TABLE IF NOT EXISTS table_files (
                id INTEGER PRIMARY KEY,
                table_id TEXT NOT NULL,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                FOREIGN KEY (table_id) REFERENCES tables (table_id),
                UNIQUE(table_id, filename)
            )
        ''')
        db.commit()
    finally:
        # Previously the connection was left open after initialisation.
        db.close()
# Pydantic models
class CreateTableResponse(BaseModel):
    """Payload returned once a knowledge-base table has been created or rebuilt."""

    # Identifier of the table the files were indexed under.
    table_id: str
    # Short outcome text (create_embedding_table sends "Success").
    message: str
    # Outcome flag (create_embedding_table sends "success").
    status: str
    # Display name of the table.
    table_name: str
class QueryTableResponse(BaseModel):
    """Payload returned by a retrieval query."""

    # Result container; query_table stores the hit list under the 'data' key.
    results: Dict[str, Any]
    # Number of hits reported alongside the results.
    total_results: int
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Store uploaded files, index them, and record the table in SQLite.

    Args:
        user_id: Owner recorded in the ``tables`` row.
        files: Uploaded documents; only .pdf, .docx, .csv, .txt and .md
            files are accepted.
        table_id: Existing or desired table identifier; a UUID4 string is
            generated when omitted.
        table_name: Display name; defaults to "knowledge-base-" plus four
            characters of a fresh UUID.

    Returns:
        CreateTableResponse describing the created/updated table.

    Raises:
        HTTPException: 400 for an unusable ``table_id``, a missing filename
            or an unsupported extension; 500 for any other failure.
    """
    allowed_suffixes = {".pdf", ".docx", ".csv", ".txt", ".md"}
    db = None
    try:
        table_id = table_id or str(uuid.uuid4())
        table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"

        # table_id becomes a directory name below, so it must not be able to
        # point outside ./data or ./lancedb/index.
        if table_id in {".", ".."} or "/" in table_id or "\\" in table_id:
            raise HTTPException(status_code=400, detail="Invalid table_id")

        # Validate every upload before anything is written, so a bad file
        # later in the list no longer leaves earlier files on disk.
        safe_names = []
        for file in files:
            # Keep only the final path component of the client-supplied name
            # (guards against "../"-style traversal in the upload filename).
            safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
            if not safe_name:
                raise HTTPException(status_code=400, detail="Invalid filename")
            if os.path.splitext(safe_name)[1].lower() not in allowed_suffixes:
                raise HTTPException(status_code=400, detail="Unsupported file type")
            safe_names.append(safe_name)

        db = get_db()
        # Check if table exists
        existing = db.execute(
            'SELECT id FROM tables WHERE user_id = ? AND table_id = ?',
            (user_id, table_id)
        ).fetchone()

        directory_path = f"./data/{table_id}"
        os.makedirs(directory_path, exist_ok=True)
        for file, safe_name in zip(files, safe_names):
            file_path = os.path.join(directory_path, safe_name)
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_id,
            mode="overwrite",
            query_type="hybrid"
        )
        # The index is rebuilt from everything currently in the directory,
        # not only from the files in this request.
        documents = SimpleDirectoryReader(directory_path).load_data()
        # NOTE(review): confirm that from_documents honours a ``vector_store``
        # keyword; LlamaIndex documents attaching a store through
        # StorageContext.from_defaults(vector_store=...). Left unchanged
        # because query_table reloads the index from the persisted directory.
        index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

        if not existing:
            db.execute(
                'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                (user_id, table_id, table_name)
            )
        for safe_name in safe_names:
            db.execute(
                'INSERT OR REPLACE INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                (table_id, safe_name, f"./data/{table_id}/{safe_name}")
            )
        db.commit()
        return CreateTableResponse(
            table_id=table_id,
            message="Success",
            status="success",
            table_name=table_name
        )
    except HTTPException:
        # Let the deliberate 4xx responses through instead of turning them
        # into 500s in the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if db is not None:
            db.close()
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex.

    Args:
        table_id: Identifier of the table; its persisted index is read from
            ``./lancedb/index/<table_id>``.
        query: Text to retrieve similar nodes for.
        user_id: Accepted for interface compatibility; not used by the body.
        limit: Passed to the retriever as ``similarity_top_k``.

    Returns:
        QueryTableResponse with the hits under ``results['data']`` (each a
        dict with 'text' and 'score') and their count in ``total_results``.

    Raises:
        HTTPException: 404 when no persisted index exists for ``table_id``;
            500 ("Query failed: ...") for any other failure.
    """
    try:
        table_name = table_id
        persist_dir = f"./lancedb/index/{table_name}"
        # An unknown table used to surface as a 500 carrying the loader's
        # internal error text; report it as a plain 404 instead.
        if not os.path.isdir(persist_dir):
            raise HTTPException(status_code=404, detail="Table not found")

        # load index and retriever
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=limit)

        # Get response
        response = retriever.retrieve(query)

        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]
        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )
    except HTTPException:
        # Keep the 404 above from being rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
async def get_tables(user_id: str):
    """List the tables recorded for a user together with their file names.

    Args:
        user_id: Only rows of ``tables`` with this ``user_id`` are returned.

    Returns:
        list[dict]: one dict per table_id with the keys 'table_id',
        'table_name', 'created_at' and 'documents' (list of file names,
        empty when the table has no files).
    """
    db = get_db()
    try:
        # One row per (table, file). Grouping happens in Python because
        # GROUP_CONCAT + split(',') breaks file names that contain a comma.
        rows = db.execute('''
            SELECT
                t.table_id,
                t.table_name,
                t.created_time as created_at,
                tf.filename
            FROM tables t
            LEFT JOIN table_files tf ON t.table_id = tf.table_id
            WHERE t.user_id = ?
            ORDER BY t.table_id, tf.id
        ''', (user_id,)).fetchall()

        grouped = {}
        for row in rows:
            entry = grouped.setdefault(row['table_id'], {
                'table_id': row['table_id'],
                'table_name': row['table_name'],
                'created_at': row['created_at'],
                'documents': []
            })
            # The LEFT JOIN yields NULL for tables without files.
            if row['filename']:
                entry['documents'].append(row['filename'])
        return list(grouped.values())
    finally:
        # Previously the connection was never closed.
        db.close()
async def delete_table(table_id: str, user_id: str):
    """Delete a table's files, persisted index and metadata rows.

    Args:
        table_id: Identifier of the table to remove.
        user_id: Must match the ``user_id`` stored with the table.

    Returns:
        dict: ``{"message": "Table deleted successfully"}``.

    Raises:
        HTTPException: 404 when no row matches (table_id, user_id); 500 for
            any other failure.
    """
    db = get_db()
    try:
        # Verify user owns the table
        table = db.execute(
            'SELECT * FROM tables WHERE table_id = ? AND user_id = ?',
            (table_id, user_id)
        ).fetchone()
        if not table:
            raise HTTPException(status_code=404, detail="Table not found or unauthorized")

        # Delete files from filesystem
        table_path = f"./data/{table_id}"
        index_path = f"./lancedb/index/{table_id}"
        if os.path.exists(table_path):
            shutil.rmtree(table_path)
        if os.path.exists(index_path):
            shutil.rmtree(index_path)

        # Delete from database
        db.execute('DELETE FROM table_files WHERE table_id = ?', (table_id,))
        db.execute('DELETE FROM tables WHERE table_id = ?', (table_id,))
        db.commit()
        return {"message": "Table deleted successfully"}
    except HTTPException:
        # Without this clause the 404 above was caught by the generic
        # handler and reported to the client as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db.close()
async def health_check():
    """Return a fixed payload signalling that the service is responding."""
    payload = {"status": "healthy"}
    return payload
async def startup():
    """Initialise the metadata database and seed the "digiyatra" table once.

    When no ``tables`` row with table_id "digiyatra" exists, every data row
    of ``combined_digi_yatra.csv`` (header skipped) becomes a text node, the
    resulting index is persisted under ``./lancedb/index/digiyatra`` and the
    table plus its source file are recorded in SQLite.
    """
    init_db()
    print("RAG Router started")
    table_name = "digiyatra"
    user_id = "digiyatra"
    db = get_db()
    try:
        # Check if table already exists
        existing = db.execute('SELECT id FROM tables WHERE table_id = ?', (table_name,)).fetchone()
        if not existing:
            vector_store = LanceDBVectorStore(
                uri="./lancedb/dev",
                table_name=table_name,
                mode="overwrite",
                query_type="hybrid"
            )
            with open('combined_digi_yatra.csv', newline='') as f:
                reader = csv.reader(f)
                # Skip the header row without loading the whole file first;
                # the default keeps an empty file from raising StopIteration.
                next(reader, None)
                nodes = [TextNode(text=str(row), id_=str(uuid.uuid4()))
                         for row in reader]
            # NOTE(review): confirm that VectorStoreIndex honours a
            # ``vector_store`` keyword; LlamaIndex documents attaching a store
            # through StorageContext.from_defaults(vector_store=...).
            index = VectorStoreIndex(nodes, vector_store=vector_store)
            index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
            db.execute(
                'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                (user_id, table_name, table_name)
            )
            db.execute(
                'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                (table_name, 'combined_digi_yatra.csv', 'combined_digi_yatra.csv')
            )
            db.commit()
    finally:
        # Previously the connection stayed open for the life of the process.
        db.close()
async def shutdown():
    """Announce on stdout that the RAG router is shutting down."""
    notice = "RAG Router shutdown"
    print(notice)