# rag-mcp-agent / rag_core/vector_store.py
# Vector store backends (embedded Qdrant and ChromaDB) behind a common protocol.
# History note: "Change to local memory" (commit bda9d9e).
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional, Protocol, Tuple
from typing_extensions import TypedDict
from .config import RagConfig
try:
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
except ImportError:
QdrantClient = None
rest = None
try:
import chromadb
from chromadb.config import Settings as ChromaSettings
except ImportError:
chromadb = None
ChromaSettings = None
logger = logging.getLogger(__name__)
class VectorHit(TypedDict, total=False):
    """One search/retrieval result from a :class:`VectorStore`.

    ``total=False``: implementations may omit any key.
    """

    id: str  # point id, stringified by the store implementations
    score: float  # similarity score; exact-id retrievals report 1.0
    payload: Dict[str, Any]  # metadata stored alongside the vector
class VectorStore(Protocol):
    """Structural (duck-typed) interface implemented by the backends below."""

    def ensure_collection(self, name: str, dim: int) -> None:
        """Create collection ``name`` for vectors of size ``dim`` if missing."""
        ...

    def upsert(self, collection: str, points: List[Dict[str, Any]]) -> None:
        """Insert or update points (dicts with "id", "vector", "payload" keys)."""
        ...

    def search(self, collection: str, vector: List[float], limit: int) -> List[VectorHit]:
        """Return up to ``limit`` nearest hits for ``vector``."""
        ...

    def retrieve(self, collection: str, ids: List[str]) -> List[VectorHit]:
        """Fetch hits by id (id semantics are implementation-defined:
        Qdrant matches payload ``chunk_id``, Chroma matches primary ids)."""
        ...

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
        """Return up to ``limit`` distinct ``doc_id`` payload values."""
        ...

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
        """Delete every point for ``doc_id``; return (success, error_message)."""
        ...
class QdrantVectorStore:
    """Vector store backed by an embedded (in-process, in-memory) Qdrant.

    Physical collection names are namespaced with ``collection_prefix`` so
    several logical stores can share one client.
    """

    def __init__(
        self,
        client: Any = None,
        *,
        collection_prefix: str = "",
    ):
        """Wrap an existing qdrant client, or create an embedded one.

        Args:
            client: pre-built qdrant client; when None an in-memory client
                is created.
            collection_prefix: prepended to every collection name.

        Raises:
            RuntimeError: if ``client`` is None and qdrant-client is not
                installed.
        """
        if client is None:
            if QdrantClient is None or rest is None:
                raise RuntimeError("Install qdrant-client to use QdrantVectorStore")
            # Use in-process embedded Qdrant for Hugging Face Spaces compatibility
            client = QdrantClient(":memory:")
        self._client = client
        self._collection_prefix = collection_prefix

    @classmethod
    def from_config(cls, config: RagConfig) -> "QdrantVectorStore":
        """Build a store from config using embedded (in-memory) Qdrant.

        Embedded mode works on Hugging Face Spaces where no external Qdrant
        server/container can run.  Client creation is delegated to
        ``__init__`` so a missing qdrant-client raises a clear
        ``RuntimeError`` instead of ``TypeError: 'NoneType' ... not callable``.
        """
        return cls(collection_prefix=config.collection_prefix)

    def _full(self, name: str) -> str:
        """Return the prefixed (physical) collection name."""
        return f"{self._collection_prefix}{name}"

    def ensure_collection(self, name: str, dim: int) -> None:
        """Create ``name`` with cosine distance for ``dim``-sized vectors if absent."""
        collection_name = self._full(name)
        try:
            self._client.get_collection(collection_name=collection_name)
            return  # already exists
        except Exception:
            pass  # qdrant raises on unknown collections -> fall through to create
        if rest is None:
            raise RuntimeError("qdrant-client models not available")
        self._client.create_collection(
            collection_name=collection_name,
            vectors_config=rest.VectorParams(size=dim, distance=rest.Distance.COSINE),
        )

    def upsert(self, collection: str, points: List[Dict[str, Any]]) -> None:
        """Insert/update points; each dict carries "id", "vector", "payload"."""
        if rest is None:
            # Same guard as the sibling methods; previously this crashed with
            # AttributeError on rest.PointStruct when qdrant was missing.
            raise RuntimeError("qdrant-client models not available")
        collection_name = self._full(collection)
        structs = [
            rest.PointStruct(
                id=point["id"],
                vector=point["vector"],
                payload=point["payload"],
            )
            for point in points
        ]
        self._client.upsert(collection_name=collection_name, points=structs)

    def search(self, collection: str, vector: List[float], limit: int) -> List[VectorHit]:
        """Return the ``limit`` nearest points to ``vector``, with payloads."""
        collection_name = self._full(collection)
        results = self._client.query_points(
            collection_name=collection_name,
            query=vector,
            limit=limit,
            with_payload=True,
        )
        return [
            {
                "id": str(hit.id),
                "score": float(hit.score),
                "payload": hit.payload or {},
            }
            for hit in results.points
        ]

    def retrieve(self, collection: str, ids: List[str]) -> List[VectorHit]:
        """Fetch points whose payload ``chunk_id`` equals each given id."""
        if rest is None:
            raise RuntimeError("qdrant-client models not available")
        collection_name = self._full(collection)
        hits: List[VectorHit] = []
        for chunk_id in ids:
            # One scroll per id: points are located via payload["chunk_id"],
            # not via their Qdrant point id.
            points, _ = self._client.scroll(
                collection_name=collection_name,
                scroll_filter=rest.Filter(
                    must=[
                        rest.FieldCondition(
                            key="chunk_id",
                            match=rest.MatchValue(value=chunk_id),
                        )
                    ]
                ),
                limit=1,
                with_payload=True,
            )
            for point in points:
                hits.append(
                    {
                        "id": str(point.id),
                        "score": 1.0,  # exact lookup, not a similarity search
                        "payload": point.payload or {},
                    }
                )
        return hits

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
        """Return up to ``limit`` distinct payload ``doc_id`` values in scan order."""
        collection_name = self._full(collection)
        ids: List[str] = []
        seen = set()  # O(1) dedup instead of O(n) "doc_id not in ids" scans
        next_offset = None
        while True:
            points, next_offset = self._client.scroll(
                collection_name=collection_name,
                limit=256,
                with_payload=True,
                offset=next_offset,
            )
            for point in points:
                doc_id = (point.payload or {}).get("doc_id")
                if doc_id and doc_id not in seen:
                    seen.add(doc_id)
                    ids.append(doc_id)
                    if len(ids) >= limit:
                        return ids
            if next_offset is None:
                break
        return ids

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
        """Delete every point whose payload ``doc_id`` matches.

        Returns:
            (True, None) on success, (False, error_message) on failure.
        """
        if rest is None:
            raise RuntimeError("qdrant-client models not available")
        collection_name = self._full(collection)
        try:
            filter_payload = rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="doc_id",
                        match=rest.MatchValue(value=doc_id),
                    )
                ]
            )
            self._client.delete(
                collection_name=collection_name,
                points_selector=rest.FilterSelector(filter=filter_payload),
            )
            return True, None
        except Exception as exc:
            logger.exception("Failed to delete doc %s from %s", doc_id, collection_name)
            return False, str(exc)
class ChromaVectorStore:
    """Vector store backed by ChromaDB (persistent on disk, or in-memory).

    Physical collection names are namespaced with ``collection_prefix``;
    collection handles are cached per name in ``self._collections``.
    """

    def __init__(
        self,
        client: Any = None,
        *,
        persist_directory: Optional[str] = None,
        collection_prefix: str = "",
    ):
        """Wrap an existing chroma client, or create one.

        Args:
            client: pre-built chroma client; when None, a PersistentClient
                is created if ``persist_directory`` is set, else an
                in-memory client.
            persist_directory: on-disk storage path (ignored when ``client``
                is given).
            collection_prefix: prepended to every collection name.

        Raises:
            RuntimeError: if ``client`` is None and chromadb is not installed.
        """
        if client is None:
            if chromadb is None:
                raise RuntimeError("Install chromadb to use ChromaVectorStore: pip install chromadb")
            if persist_directory:
                client = chromadb.PersistentClient(path=persist_directory)
            else:
                client = chromadb.Client()
        self._client = client
        self._collection_prefix = collection_prefix
        # Cache: full collection name -> chroma collection handle.
        self._collections: Dict[str, Any] = {}

    @classmethod
    def from_config(cls, config: RagConfig) -> "ChromaVectorStore":
        """Build a persistent store from config (default dir ./chroma_data)."""
        persist_dir = getattr(config, "chroma_persist_directory", "./chroma_data")
        return cls(
            persist_directory=persist_dir,
            collection_prefix=config.collection_prefix,
        )

    def _full(self, name: str) -> str:
        """Return the prefixed (physical) collection name."""
        return f"{self._collection_prefix}{name}"

    def _get_collection(self, collection: str) -> Any:
        """Fetch (and memoize) the chroma collection handle for ``collection``.

        Extracted helper: this lookup/cache snippet was previously duplicated
        in five methods.
        """
        collection_name = self._full(collection)
        coll = self._collections.get(collection_name)
        if coll is None:
            coll = self._client.get_collection(collection_name)
            self._collections[collection_name] = coll
        return coll

    def ensure_collection(self, name: str, dim: int) -> None:
        """Create the collection (cosine space) if needed; chroma infers ``dim``."""
        collection_name = self._full(name)
        if collection_name not in self._collections:
            self._collections[collection_name] = self._client.get_or_create_collection(
                name=collection_name,
                metadata={"hnsw:space": "cosine"},
            )

    def upsert(self, collection: str, points: List[Dict[str, Any]]) -> None:
        """Insert/update points; each dict carries "id", "vector", "payload"."""
        coll = self._get_collection(collection)
        ids = [str(point["id"]) for point in points]
        embeddings = [point["vector"] for point in points]
        metadatas = [point["payload"] for point in points]
        coll.upsert(ids=ids, embeddings=embeddings, metadatas=metadatas)

    def search(self, collection: str, vector: List[float], limit: int) -> List[VectorHit]:
        """Return up to ``limit`` nearest points; score = 1 - cosine distance."""
        coll = self._get_collection(collection)
        results = coll.query(query_embeddings=[vector], n_results=limit)
        hits: List[VectorHit] = []
        if results["ids"] and len(results["ids"]) > 0:
            # Chroma nests results per query; we issued one query -> index [0].
            for i, doc_id in enumerate(results["ids"][0]):
                distance = results["distances"][0][i] if results["distances"] else 0.0
                metadata = results["metadatas"][0][i] if results["metadatas"] else {}
                hits.append(
                    {
                        "id": str(doc_id),
                        # Convert chroma's distance into a similarity score.
                        "score": float(1.0 - distance),
                        "payload": metadata or {},
                    }
                )
        return hits

    def retrieve(self, collection: str, ids: List[str]) -> List[VectorHit]:
        """Fetch points by primary id (score fixed at 1.0)."""
        coll = self._get_collection(collection)
        results = coll.get(ids=ids, include=["metadatas"])
        hits: List[VectorHit] = []
        if results["ids"]:
            for i, doc_id in enumerate(results["ids"]):
                metadata = results["metadatas"][i] if results["metadatas"] else {}
                hits.append(
                    {
                        "id": str(doc_id),
                        "score": 1.0,  # exact lookup, not a similarity search
                        "payload": metadata or {},
                    }
                )
        return hits

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
        """Return distinct metadata ``doc_id`` values from up to ``limit`` points."""
        coll = self._get_collection(collection)
        results = coll.get(limit=limit, include=["metadatas"])
        doc_ids: List[str] = []
        seen = set()
        if results["metadatas"]:
            for metadata in results["metadatas"]:
                doc_id = metadata.get("doc_id")
                if doc_id and doc_id not in seen:
                    doc_ids.append(doc_id)
                    seen.add(doc_id)
        return doc_ids

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
        """Delete every point whose metadata ``doc_id`` matches.

        Returns:
            (True, None) on success (including "nothing to delete"),
            (False, error_message) on failure.
        """
        collection_name = self._full(collection)
        coll = self._get_collection(collection)
        try:
            # include=[] -> only ids come back; embeddings/metadatas not needed.
            results = coll.get(where={"doc_id": doc_id}, include=[])
            if results["ids"]:
                coll.delete(ids=results["ids"])
            return True, None
        except Exception as exc:
            logger.exception("Failed to delete doc %s from %s", doc_id, collection_name)
            return False, str(exc)