from __future__ import annotations
import logging
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from uuid import uuid4

# Prefer the standalone langchain-text-splitters package; fall back to the
# legacy langchain module path for older installs.
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    try:
        from langchain.text_splitter import RecursiveCharacterTextSplitter
    except ImportError as exc:
        raise RuntimeError(
            "langchain and langchain-text-splitters are required for chunking"
        ) from exc

from .config import RagConfig, get_config
from .embedding import Embedding
from .model import Model
from .vector_store import VectorStore

logger = logging.getLogger(__name__)


class RagService:
    """Orchestrates chunking, embedding, and vector-store access for RAG.

    Ties a Model (for embeddings) to a VectorStore (for storage and search).
    """

    def __init__(self, model: Model, vector_store: VectorStore, config: Optional[RagConfig] = None):
        self._model = model
        self._embedder = Embedding(model)
        self._store = vector_store
        self._config = config or get_config()

    def _split_texts(
        self, texts: Iterable[str], *, chunk_size: int, overlap: int
    ) -> List[str]:
        """Split raw texts into overlapping chunks suitable for embedding."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
        )
        chunks: List[str] = []
        for text in texts:
            chunks.extend(splitter.split_text(text))
        return chunks

    def ingest(
        self,
        collection: str,
        documents: Sequence[Dict[str, Any]],
        *,
        chunk_size: Optional[int] = None,
        overlap: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Chunk, embed, and upsert documents into the given collection."""
        # Probe the model so its embedding dimension is known before the
        # collection is created or validated.
        self._model._inspect()
        if self._model.dim is None:
            raise RuntimeError("Embedding dimension is unknown; failed to inspect model")
        self._store.ensure_collection(collection, dim=self._model.dim)
        chunk_size = chunk_size or self._config.default_chunk_size
        overlap = overlap or self._config.default_chunk_overlap
        doc_ids: List[str] = []
        chunks_indexed = 0
        errors: Dict[str, str] = {}
        points: List[Dict[str, Any]] = []
        for doc in documents:
            raw_text = doc.get("text") or ""
            if not raw_text.strip():
                logger.warning("Skipping empty document payload")
                continue
            doc_id = doc.get("id") or str(uuid4())
            metadata = doc.get("metadata") or {}
            try:
                chunk_texts = self._split_texts([raw_text], chunk_size=chunk_size, overlap=overlap)
                for idx, chunk_text in enumerate(chunk_texts):
                    # Each point gets a fresh UUID; the human-readable
                    # chunk_id ("<doc_id>:<index>") lives in the payload.
                    point_id = str(uuid4())
                    chunk_id = f"{doc_id}:{idx}"
                    vector = self._embedder.embed_text(chunk_text)
                    points.append(
                        {
                            "id": point_id,
                            "vector": vector,
                            "payload": {
                                "chunk_id": chunk_id,
                                "doc_id": doc_id,
                                "text": chunk_text,
                                "metadata": metadata,
                                "chunk_index": idx,
                            },
                        }
                    )
                doc_ids.append(doc_id)
                chunks_indexed += len(chunk_texts)
            except Exception as exc:
                # Record per-document failures instead of aborting the batch.
                logger.exception("Failed to ingest doc %s", doc_id)
                errors[doc_id] = str(exc)
        if points:
            self._store.upsert(collection=collection, points=points)
        return {
            "doc_ids": doc_ids,
            "chunks_indexed": chunks_indexed,
            "errors": errors or None,
        }

    def search(
        self,
        collection: str,
        query: str,
        *,
        top_k: int = 5,
        score_threshold: float = 0.0,
    ) -> List[Dict[str, Any]]:
        """Embed the query and return the top matching chunks."""
        self._model._inspect()
        if self._model.dim is None:
            raise RuntimeError("Embedding dimension is unknown; failed to inspect model")
        self._store.ensure_collection(collection, dim=self._model.dim)
        vector = self._embedder.embed_query(query)
        results = self._store.search(collection=collection, vector=vector, limit=top_k)
        hits: List[Dict[str, Any]] = []
        for hit in results:
            # A threshold of 0.0 disables filtering; otherwise drop
            # low-scoring hits.
            if score_threshold and hit.get("score", 0) < score_threshold:
                continue
            payload = hit.get("payload", {}) or {}
            hits.append(
                {
                    "chunk_id": payload.get("chunk_id") or str(hit.get("id")),
                    "text": payload.get("text", ""),
                    "score": float(hit.get("score", 0.0)),
                    "doc_id": payload.get("doc_id", ""),
                    "metadata": payload.get("metadata"),
                }
            )
        return hits

    def get_chunk(self, collection: str, chunk_id: str) -> Dict[str, Any]:
        """Fetch a single chunk by its store-level id."""
        records = self._store.retrieve(collection=collection, ids=[chunk_id])
        if not records:
            raise KeyError(f"Chunk {chunk_id} not found in {collection}")
        payload = records[0].get("payload", {}) or {}
        return {
            "chunk_id": payload.get("chunk_id") or chunk_id,
            "text": payload.get("text", ""),
            "doc_id": payload.get("doc_id", ""),
            "metadata": payload.get("metadata"),
        }

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
        """Return the document ids stored in the collection."""
        return self._store.list_doc_ids(collection=collection, limit=limit)

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
        """Delete all chunks for a document; returns (deleted, error message)."""
        return self._store.delete_doc(collection=collection, doc_id=doc_id)
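
# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal end-to-end flow, assuming concrete Model and VectorStore
# implementations from this package; the constructor arguments shown below
# are hypothetical and will differ from the real classes.
#
#   model = Model("nomic-embed-text")             # hypothetical constructor
#   store = VectorStore("http://localhost:6333")  # hypothetical constructor
#   rag = RagService(model, store)
#
#   result = rag.ingest(
#       "docs",
#       [{"id": "readme", "text": "RAG pairs retrieval with generation...",
#         "metadata": {"source": "README.md"}}],
#   )
#   hits = rag.search("docs", "how does ingestion work?", top_k=3)
#   for hit in hits:
#       print(hit["chunk_id"], round(hit["score"], 3), hit["text"][:80])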