from __future__ import annotations

import logging
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from uuid import uuid4

try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
except ImportError:
    try:
        from langchain.text_splitter import RecursiveCharacterTextSplitter
    except ImportError as exc:
        raise RuntimeError("langchain and langchain-text-splitters are required for chunking") from exc

from .config import RagConfig, get_config
from .embedding import Embedding
from .model import Model
from .vector_store import VectorStore

logger = logging.getLogger(__name__)


class RagService:
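    """Retrieval-augmented generation service.

    Splits documents into chunks, embeds them with ``Embedding``, and
    indexes and searches them through a ``VectorStore``.
    """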

    def __init__(self, model: Model, vector_store: VectorStore, config: Optional[RagConfig] = None):
        self._model = model
        self._embedder = Embedding(model)
        self._store = vector_store
        self._config = config or get_config()

    def _split_texts(
        self, texts: Iterable[str], *, chunk_size: int, overlap: int
    ) -> List[str]:
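        """Split each text into chunks of roughly ``chunk_size`` characters,
        with ``overlap`` characters shared between consecutive chunks."""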
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
        )
        chunks: List[str] = []
        for text in texts:
            chunks.extend(splitter.split_text(text))
        return chunks

    def ingest(
        self,
        collection: str,
        documents: Sequence[Dict[str, Any]],
        *,
        chunk_size: Optional[int] = None,
        overlap: Optional[int] = None,
    ) -> Dict[str, Any]:
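        """Chunk, embed, and index ``documents`` into ``collection``.

        Each document is a mapping with a required ``text`` key and
        optional ``id`` and ``metadata`` keys. Returns a summary of the
        ingested doc ids, the number of chunks indexed, and any
        per-document errors.
        """
        # Probe the model so its embedding dimension is known before the
        # collection is created with a matching vector size.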
        self._model._inspect()
        if self._model.dim is None:
            raise RuntimeError("Embedding dimension is unknown; failed to inspect model")
        self._store.ensure_collection(collection, dim=self._model.dim)

        # Use ``is None`` checks so an explicit ``overlap=0`` is honoured
        # rather than being replaced by the config default.
        chunk_size = chunk_size if chunk_size is not None else self._config.default_chunk_size
        overlap = overlap if overlap is not None else self._config.default_chunk_overlap

        doc_ids: List[str] = []
        chunks_indexed = 0
        errors: Dict[str, str] = {}
        points: List[Dict[str, Any]] = []

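        # Chunk and embed every document first, accumulating points so the
        # store receives a single batched upsert at the end.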
        for doc in documents:
            raw_text = doc.get("text") or ""
            if not raw_text.strip():
                logger.warning("Skipping empty document payload")
                continue

            doc_id = doc.get("id") or str(uuid4())
            metadata = doc.get("metadata") or {}

            try:
                chunk_texts = self._split_texts([raw_text], chunk_size=chunk_size, overlap=overlap)
                for idx, chunk_text in enumerate(chunk_texts):
                    point_id = str(uuid4())
                    chunk_id = f"{doc_id}:{idx}"
                    vector = self._embedder.embed_text(chunk_text)
                    points.append(
                        {
                            "id": point_id,
                            "vector": vector,
                            "payload": {
                                "chunk_id": chunk_id,
                                "doc_id": doc_id,
                                "text": chunk_text,
                                "metadata": metadata,
                                "chunk_index": idx,
                            },
                        }
                    )
                doc_ids.append(doc_id)
                chunks_indexed += len(chunk_texts)
            except Exception as exc:
                logger.exception("Failed to ingest doc %s", doc_id)
                errors[doc_id] = str(exc)

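        # One batched upsert for all successfully embedded chunks.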
        if points:
            self._store.upsert(collection=collection, points=points)

        return {
            "doc_ids": doc_ids,
            "chunks_indexed": chunks_indexed,
            "errors": errors or None,
        }

    def search(
        self,
        collection: str,
        query: str,
        *,
        top_k: int = 5,
        score_threshold: float = 0.0,
    ) -> List[Dict[str, Any]]:
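        """Embed ``query`` and return up to ``top_k`` matching chunks.

        Hits scoring below ``score_threshold`` are dropped; the default
        of 0.0 disables filtering.
        """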
        self._model._inspect()
        if self._model.dim is None:
            raise RuntimeError("Embedding dimension is unknown; failed to inspect model")
        self._store.ensure_collection(collection, dim=self._model.dim)

        vector = self._embedder.embed_query(query)
        results = self._store.search(collection=collection, vector=vector, limit=top_k)

        hits: List[Dict[str, Any]] = []
        for hit in results:
            if score_threshold and hit.get("score", 0) < score_threshold:
                continue
            payload = hit.get("payload", {}) or {}
            hits.append(
                {
                    "chunk_id": payload.get("chunk_id") or str(hit.get("id")),
                    "text": payload.get("text", ""),
                    "score": float(hit.get("score", 0.0)),
                    "doc_id": payload.get("doc_id", ""),
                    "metadata": payload.get("metadata"),
                }
            )
        return hits

    def get_chunk(self, collection: str, chunk_id: str) -> Dict[str, Any]:
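        """Fetch a single chunk by id, raising ``KeyError`` if absent.

        Note: whether ``chunk_id`` values of the form ``"doc_id:idx"`` are
        resolvable here depends on the ``VectorStore.retrieve``
        implementation, since points are upserted under UUID ids.
        """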
        records = self._store.retrieve(collection=collection, ids=[chunk_id])
        if not records:
            raise KeyError(f"Chunk {chunk_id} not found in {collection}")

        payload = records[0].get("payload", {}) or {}
        return {
            "chunk_id": payload.get("chunk_id") or chunk_id,
            "text": payload.get("text", ""),
            "doc_id": payload.get("doc_id", ""),
            "metadata": payload.get("metadata"),
        }

    def list_doc_ids(self, collection: str, *, limit: int = 10_000) -> List[str]:
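        """Return up to ``limit`` distinct document ids in ``collection``."""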
        return self._store.list_doc_ids(collection=collection, limit=limit)

    def delete_doc(self, collection: str, doc_id: str) -> Tuple[bool, Optional[str]]:
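        """Delete all chunks for ``doc_id``; returns the store's
        ``(success, error)`` pair, i.e. a success flag and an optional
        error message."""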
        return self._store.delete_doc(collection=collection, doc_id=doc_id)
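

# Minimal usage sketch (illustrative only: the ``Model`` and
# ``VectorStore`` constructor arguments below are assumptions, not part
# of this module's contract):
#
#     model = Model("all-MiniLM-L6-v2")
#     store = VectorStore(url="http://localhost:6333")
#     rag = RagService(model, store)
#
#     result = rag.ingest("docs", [{"text": "Qdrant is a vector database."}])
#     for hit in rag.search("docs", "what is qdrant?", top_k=3):
#         print(f'{hit["score"]:.3f} {hit["text"]}')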