Update app.py
Browse files
app.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
import os
|
| 2 |
-
from langchain_huggingface import
|
| 3 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 4 |
from langchain_community.vectorstores import Chroma
|
| 5 |
from langchain_community.document_loaders import PyPDFLoader
|
| 6 |
from langchain_core.prompts import PromptTemplate
|
| 7 |
from langchain_core.output_parsers import StrOutputParser
|
| 8 |
from langchain_core.runnables import RunnablePassthrough
|
|
|
|
| 9 |
|
| 10 |
import gradio as gr
|
| 11 |
import warnings
|
|
@@ -54,16 +55,9 @@ def get_huggingface_token():
|
|
| 54 |
def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
|
| 55 |
token = get_huggingface_token()
|
| 56 |
|
| 57 |
-
# Use
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
repo_id=model_id,
|
| 61 |
-
max_new_tokens=max_tokens,
|
| 62 |
-
temperature=temperature,
|
| 63 |
-
huggingfacehub_api_token=token,
|
| 64 |
-
timeout=120, # Increase timeout for large models
|
| 65 |
-
)
|
| 66 |
-
return llm
|
| 67 |
|
| 68 |
|
| 69 |
# ---------------------------
|
|
@@ -146,31 +140,29 @@ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_m
|
|
| 146 |
|
| 147 |
try:
|
| 148 |
selected_model = model_choice or MODEL_OPTIONS[0]
|
| 149 |
-
|
| 150 |
retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
|
| 151 |
|
| 152 |
-
#
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
{context}
|
| 155 |
|
| 156 |
-
Question: {
|
| 157 |
|
| 158 |
Answer:"""
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
rag_chain = (
|
| 167 |
-
{"context": retriever_obj | format_docs, "question": RunnablePassthrough()}
|
| 168 |
-
| prompt
|
| 169 |
-
| llm
|
| 170 |
-
| StrOutputParser()
|
| 171 |
)
|
| 172 |
|
| 173 |
-
response = rag_chain.invoke(query)
|
| 174 |
return response
|
| 175 |
except Exception as e:
|
| 176 |
import traceback
|
|
|
|
| 1 |
import os
|
| 2 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 3 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 4 |
from langchain_community.vectorstores import Chroma
|
| 5 |
from langchain_community.document_loaders import PyPDFLoader
|
| 6 |
from langchain_core.prompts import PromptTemplate
|
| 7 |
from langchain_core.output_parsers import StrOutputParser
|
| 8 |
from langchain_core.runnables import RunnablePassthrough
|
| 9 |
+
from huggingface_hub import InferenceClient
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
import warnings
|
|
|
|
def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
    """Build a Hugging Face InferenceClient plus its generation settings.

    Returns a ``(client, max_tokens, temperature)`` tuple; the caller is
    expected to pass the two settings on to ``client.text_generation``.

    Args:
        model_id: Hugging Face model repo id. Defaults to the first entry
            of ``MODEL_OPTIONS``.
        max_tokens: cap on newly generated tokens; returned unchanged.
        temperature: sampling temperature; returned unchanged.
    """
    token = get_huggingface_token()

    # Use InferenceClient directly for better reliability.
    # timeout=120 restores the longer deadline the previous endpoint-based
    # code used ("Increase timeout for large models"), which this refactor
    # had dropped — large models can exceed the client's default timeout.
    client = InferenceClient(model=model_id, token=token, timeout=120)
    return client, max_tokens, temperature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
# ---------------------------
|
|
|
|
| 140 |
|
| 141 |
try:
|
| 142 |
selected_model = model_choice or MODEL_OPTIONS[0]
|
| 143 |
+
client, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
|
| 144 |
retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
|
| 145 |
|
| 146 |
+
# Get relevant documents
|
| 147 |
+
docs = retriever_obj.get_relevant_documents(query)
|
| 148 |
+
context = "\n\n".join(doc.page_content for doc in docs)
|
| 149 |
+
|
| 150 |
+
# Create prompt
|
| 151 |
+
prompt = f"""Answer the question based only on the following context:
|
| 152 |
{context}
|
| 153 |
|
| 154 |
+
Question: {query}
|
| 155 |
|
| 156 |
Answer:"""
|
| 157 |
|
| 158 |
+
# Call the model directly
|
| 159 |
+
response = client.text_generation(
|
| 160 |
+
prompt,
|
| 161 |
+
max_new_tokens=max_tok,
|
| 162 |
+
temperature=temp,
|
| 163 |
+
return_full_text=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
)
|
| 165 |
|
|
|
|
| 166 |
return response
|
| 167 |
except Exception as e:
|
| 168 |
import traceback
|