Update app.py
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import requests
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
@@ -6,7 +7,6 @@ from langchain_community.document_loaders import PyPDFLoader
 from langchain_core.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough
-from huggingface_hub import InferenceClient
 
 import gradio as gr
 import warnings
@@ -53,11 +53,14 @@ def get_huggingface_token():
 # LLM
 # ---------------------------
 def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
+    """
+    Returns API URL, headers, and parameters for HuggingFace Inference API.
+    """
     token = get_huggingface_token()
+    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
+    headers = {"Authorization": f"Bearer {token}"}
 
-
-    client = InferenceClient(model=model_id, token=token)
-    return client, max_tokens, temperature
+    return api_url, headers, max_tokens, temperature
 
 
 # ---------------------------
@@ -140,7 +143,7 @@ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_m
 
     try:
         selected_model = model_choice or MODEL_OPTIONS[0]
-        client, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
+        api_url, headers, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
         retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
 
         # Get relevant documents
@@ -155,15 +158,28 @@ Question: {query}
 
 Answer:"""
 
-        # Call
-        response = client.text_generation(
-            prompt,
-            max_new_tokens=max_tok,
-            temperature=temp,
-            return_full_text=False
-        )
+        # Call HuggingFace Inference API directly
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "max_new_tokens": max_tok,
+                "temperature": temp,
+                "return_full_text": False
+            }
+        }
+
+        response = requests.post(api_url, headers=headers, json=payload)
+        response.raise_for_status()
+
+        result = response.json()
 
-        return response
+        # Handle different response formats
+        if isinstance(result, list) and len(result) > 0:
+            return result[0].get("generated_text", str(result))
+        elif isinstance(result, dict):
+            return result.get("generated_text", str(result))
+        else:
+            return str(result)
     except Exception as e:
         import traceback
         error_details = traceback.format_exc()
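
For reference, a minimal standalone sketch of the inference path this commit switches to: posting a prompt to the HuggingFace Inference API with plain requests instead of going through huggingface_hub's InferenceClient. The model id, the HF_TOKEN environment variable, the example prompt, and the timeout below are illustrative placeholders, not values taken from the app.

import os
import requests

# Placeholder model; any text-generation model served by the
# Inference API follows the same request/response shape.
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
api_url = f"https://api-inference.huggingface.co/models/{model_id}"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

payload = {
    "inputs": "Question: What is retrieval-augmented generation?\n\nAnswer:",
    "parameters": {
        "max_new_tokens": 64,
        "temperature": 0.8,
        "return_full_text": False,  # return only the completion, not the prompt
    },
}

response = requests.post(api_url, headers=headers, json=payload, timeout=60)
response.raise_for_status()  # surfaces 401/404/503 (e.g. model still loading)

result = response.json()
# Successful generations typically come back as [{"generated_text": "..."}].
text = result[0].get("generated_text", "") if isinstance(result, list) else str(result)
print(text)

One tradeoff worth noting: the raw requests call drops the huggingface_hub dependency from the generation path and makes HTTP failures explicit via raise_for_status, at the cost of hand-parsing the response shapes the API can return, which is what the isinstance branching in the commit handles.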