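"""Vector store backends for the RAG pipeline.

Defines a ``VectorStore`` protocol plus two interchangeable implementations:
an embedded (in-memory) Qdrant store and a Chroma store with optional
on-disk persistence. Both client libraries are optional dependencies and
are imported lazily.
"""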
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Protocol, Tuple

from typing_extensions import TypedDict

from .config import RagConfig

try:
    from qdrant_client import QdrantClient
    from qdrant_client.http import models as rest
except ImportError:
    QdrantClient = None
    rest = None

try:
    import chromadb
    from chromadb.config import Settings as ChromaSettings
except ImportError:
    chromadb = None
    ChromaSettings = None

logger = logging.getLogger(__name__)

class VectorHit(TypedDict, total=False):
    """A single search/retrieve result from a vector store."""

    id: str
    score: float
    payload: Dict[str, Any]


class VectorStore(Protocol):
    """Minimal interface a vector store backend must implement."""

    def ensure_collection(self, name: str, dim: int) -> None:
        ...

    def upsert(self, collection: str, points: List[Dict[str, Any]]) -> None:
        ...

    def search(self, collection: str, vector: List[float], limit: int) -> List[VectorHit]:
        ...

    def retrieve(self, collection: str, ids: List[str]) -> List[VectorHit]:
        ...

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
        ...

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
        ...

class QdrantVectorStore:
    """Qdrant-backed vector store; defaults to an embedded in-memory client."""

    def __init__(
        self,
        client: Any = None,
        *,
        collection_prefix: str = "",
    ):
        if client is None:
            if QdrantClient is None or rest is None:
                raise RuntimeError("Install qdrant-client to use QdrantVectorStore")
            # Use in-process embedded Qdrant for Hugging Face Spaces compatibility.
            client = QdrantClient(":memory:")
        self._client = client
        self._collection_prefix = collection_prefix

    @classmethod
    def from_config(cls, config: RagConfig) -> "QdrantVectorStore":
        if QdrantClient is None:
            raise RuntimeError("Install qdrant-client to use QdrantVectorStore")
        # Use embedded Qdrant (in-memory) instead of connecting to an external
        # server; this works on Hugging Face Spaces, where Docker containers
        # are unavailable.
        client = QdrantClient(":memory:")
        return cls(
            client=client,
            collection_prefix=config.collection_prefix,
        )

    def _full(self, name: str) -> str:
        return f"{self._collection_prefix}{name}"

    def ensure_collection(self, name: str, dim: int) -> None:
        collection_name = self._full(name)
        try:
            self._client.get_collection(collection_name=collection_name)
            return
        except Exception:
            pass
        if rest is None:
            raise RuntimeError("qdrant-client models not available")
        self._client.create_collection(
            collection_name=collection_name,
            vectors_config=rest.VectorParams(size=dim, distance=rest.Distance.COSINE),
        )

    def upsert(self, collection: str, points: List[Dict[str, Any]]) -> None:
        if rest is None:
            raise RuntimeError("qdrant-client models not available")
        collection_name = self._full(collection)
        payloads: List[rest.PointStruct] = []
        for point in points:
            payloads.append(
                rest.PointStruct(
                    id=point["id"],
                    vector=point["vector"],
                    payload=point["payload"],
                )
            )
        self._client.upsert(collection_name=collection_name, points=payloads)

    def search(self, collection: str, vector: List[float], limit: int) -> List[VectorHit]:
        collection_name = self._full(collection)
        results = self._client.query_points(
            collection_name=collection_name,
            query=vector,
            limit=limit,
            with_payload=True,
        )
        hits: List[VectorHit] = []
        for hit in results.points:
            hits.append(
                {
                    "id": str(hit.id),
                    "score": float(hit.score),
                    "payload": hit.payload or {},
                }
            )
        return hits

    def retrieve(self, collection: str, ids: List[str]) -> List[VectorHit]:
        if rest is None:
            raise RuntimeError("qdrant-client models not available")
        collection_name = self._full(collection)
        hits: List[VectorHit] = []
        # Chunk ids live in the payload (not as point ids), so look each one
        # up with a payload filter instead of a direct point lookup.
        for chunk_id in ids:
            points, _ = self._client.scroll(
                collection_name=collection_name,
                scroll_filter=rest.Filter(
                    must=[
                        rest.FieldCondition(
                            key="chunk_id",
                            match=rest.MatchValue(value=chunk_id),
                        )
                    ]
                ),
                limit=1,
                with_payload=True,
            )
            for point in points:
                hits.append(
                    {
                        "id": str(point.id),
                        "score": 1.0,
                        "payload": point.payload or {},
                    }
                )
        return hits

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
        collection_name = self._full(collection)
        ids: List[str] = []
        seen: set = set()
        next_offset = None
        while True:
            points, next_offset = self._client.scroll(
                collection_name=collection_name,
                limit=256,
                with_payload=True,
                offset=next_offset,
            )
            for point in points:
                payload = point.payload or {}
                doc_id = payload.get("doc_id")
                if doc_id and doc_id not in seen:
                    seen.add(doc_id)
                    ids.append(doc_id)
                    if len(ids) >= limit:
                        return ids
            if next_offset is None:
                break
        return ids

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
        if rest is None:
            raise RuntimeError("qdrant-client models not available")
        collection_name = self._full(collection)
        try:
            filter_payload = rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="doc_id",
                        match=rest.MatchValue(value=doc_id),
                    )
                ]
            )
            self._client.delete(
                collection_name=collection_name,
                points_selector=rest.FilterSelector(filter=filter_payload),
            )
            return True, None
        except Exception as exc:
            logger.exception("Failed to delete doc %s from %s", doc_id, collection_name)
            return False, str(exc)

class ChromaVectorStore:
    """Chroma-backed vector store with optional on-disk persistence."""

    def __init__(
        self,
        client: Any = None,
        *,
        persist_directory: Optional[str] = None,
        collection_prefix: str = "",
    ):
        if client is None:
            if chromadb is None:
                raise RuntimeError("Install chromadb to use ChromaVectorStore: pip install chromadb")
            if persist_directory:
                client = chromadb.PersistentClient(path=persist_directory)
            else:
                client = chromadb.Client()
        self._client = client
        self._collection_prefix = collection_prefix
        self._collections: Dict[str, Any] = {}

    @classmethod
    def from_config(cls, config: RagConfig) -> "ChromaVectorStore":
        persist_dir = getattr(config, "chroma_persist_directory", "./chroma_data")
        return cls(
            persist_directory=persist_dir,
            collection_prefix=config.collection_prefix,
        )

    def _full(self, name: str) -> str:
        return f"{self._collection_prefix}{name}"

    def _get(self, collection: str) -> Any:
        """Return the Chroma collection for a logical name, caching the handle."""
        collection_name = self._full(collection)
        coll = self._collections.get(collection_name)
        if coll is None:
            coll = self._client.get_collection(collection_name)
            self._collections[collection_name] = coll
        return coll

    def ensure_collection(self, name: str, dim: int) -> None:
        # Chroma infers the embedding dimension on first insert, so ``dim``
        # is accepted for interface compatibility but unused here.
        collection_name = self._full(name)
        if collection_name not in self._collections:
            self._collections[collection_name] = self._client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"},
            )

    def upsert(self, collection: str, points: List[Dict[str, Any]]) -> None:
        coll = self._get(collection)
        ids = [str(point["id"]) for point in points]
        embeddings = [point["vector"] for point in points]
        metadatas = [point["payload"] for point in points]
        coll.upsert(ids=ids, embeddings=embeddings, metadatas=metadatas)

    def search(self, collection: str, vector: List[float], limit: int) -> List[VectorHit]:
        coll = self._get(collection)
        results = coll.query(query_embeddings=[vector], n_results=limit)
        hits: List[VectorHit] = []
        if results["ids"]:
            for i, doc_id in enumerate(results["ids"][0]):
                distance = results["distances"][0][i] if results["distances"] else 0.0
                # Chroma returns cosine *distance*; convert to a similarity score.
                score = 1.0 - distance
                metadata = results["metadatas"][0][i] if results["metadatas"] else {}
                hits.append(
                    {
                        "id": str(doc_id),
                        "score": float(score),
                        "payload": metadata or {},
                    }
                )
        return hits

    def retrieve(self, collection: str, ids: List[str]) -> List[VectorHit]:
        coll = self._get(collection)
        results = coll.get(ids=ids, include=["metadatas"])
        hits: List[VectorHit] = []
        if results["ids"]:
            for i, doc_id in enumerate(results["ids"]):
                metadata = results["metadatas"][i] if results["metadatas"] else {}
                hits.append(
                    {
                        "id": str(doc_id),
                        "score": 1.0,
                        "payload": metadata or {},
                    }
                )
        return hits

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
        coll = self._get(collection)
        results = coll.get(limit=limit, include=["metadatas"])
        doc_ids: List[str] = []
        seen = set()
        if results["metadatas"]:
            for metadata in results["metadatas"]:
                # Individual metadata entries may be None; guard before .get().
                doc_id = (metadata or {}).get("doc_id")
                if doc_id and doc_id not in seen:
                    doc_ids.append(doc_id)
                    seen.add(doc_id)
        return doc_ids

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
        coll = self._get(collection)
        try:
            results = coll.get(where={"doc_id": doc_id}, include=[])
            if results["ids"]:
                coll.delete(ids=results["ids"])
            return True, None
        except Exception as exc:
            logger.exception("Failed to delete doc %s from %s", doc_id, self._full(collection))
            return False, str(exc)
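

# Minimal smoke test: a sketch assuming qdrant-client is installed. The
# collection name, point id, and 3-dim vectors below are illustrative only.
if __name__ == "__main__":
    store = QdrantVectorStore()  # embedded, in-memory Qdrant
    store.ensure_collection("demo", dim=3)
    store.upsert(
        "demo",
        [
            {
                "id": "11111111-1111-1111-1111-111111111111",
                "vector": [0.1, 0.2, 0.3],
                "payload": {"doc_id": "doc-1", "chunk_id": "doc-1:0"},
            }
        ],
    )
    print(store.search("demo", [0.1, 0.2, 0.3], limit=1))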