CatoG committed
Commit ec573e5 · verified · Parent: 17f8afb

Update app.py

Files changed (1):
  1. app.py +29 -13
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import requests
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
@@ -6,7 +7,6 @@ from langchain_community.document_loaders import PyPDFLoader
 from langchain_core.prompts import PromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough
-from huggingface_hub import InferenceClient
 
 import gradio as gr
 import warnings
@@ -53,11 +53,14 @@ def get_huggingface_token():
 # LLM
 # ---------------------------
 def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
+    """
+    Returns API URL, headers, and parameters for HuggingFace Inference API.
+    """
     token = get_huggingface_token()
+    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
+    headers = {"Authorization": f"Bearer {token}"}
 
-    # Use InferenceClient directly for better reliability
-    client = InferenceClient(model=model_id, token=token)
-    return client, max_tokens, temperature
+    return api_url, headers, max_tokens, temperature
 
 
 # ---------------------------
@@ -140,7 +143,7 @@ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_m
 
     try:
         selected_model = model_choice or MODEL_OPTIONS[0]
-        client, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
+        api_url, headers, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
         retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
 
         # Get relevant documents
@@ -155,15 +158,28 @@ Question: {query}
 
 Answer:"""
 
-        # Call the model directly
-        response = client.text_generation(
-            prompt,
-            max_new_tokens=max_tok,
-            temperature=temp,
-            return_full_text=False
-        )
-
-        return response
+        # Call HuggingFace Inference API directly
+        payload = {
+            "inputs": prompt,
+            "parameters": {
+                "max_new_tokens": max_tok,
+                "temperature": temp,
+                "return_full_text": False
+            }
+        }
+
+        response = requests.post(api_url, headers=headers, json=payload)
+        response.raise_for_status()
+
+        result = response.json()
 
-        return response
+        # Handle different response formats
+        if isinstance(result, list) and len(result) > 0:
+            return result[0].get("generated_text", str(result))
+        elif isinstance(result, dict):
+            return result.get("generated_text", str(result))
+        else:
+            return str(result)
     except Exception as e:
         import traceback
         error_details = traceback.format_exc()
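
For reference, the new request path can be exercised on its own. The following is a minimal standalone sketch of what get_llm and the requests.post call now do together; it assumes a valid token in an HF_TOKEN environment variable, and the model id and prompt are placeholders rather than values from app.py.

import os
import requests

# Placeholders for illustration; the app takes these from MODEL_OPTIONS
# and get_huggingface_token() instead.
model_id = "HuggingFaceH4/zephyr-7b-beta"
token = os.environ["HF_TOKEN"]  # assumed to be set

api_url = f"https://api-inference.huggingface.co/models/{model_id}"
headers = {"Authorization": f"Bearer {token}"}
payload = {
    "inputs": "Question: What is retrieval-augmented generation?\n\nAnswer:",
    "parameters": {
        "max_new_tokens": 64,
        "temperature": 0.8,
        "return_full_text": False,
    },
}

response = requests.post(api_url, headers=headers, json=payload)
response.raise_for_status()
result = response.json()

# Text-generation models usually reply with [{"generated_text": "..."}],
# which is why the app checks for a list before a dict.
if isinstance(result, list) and len(result) > 0:
    print(result[0].get("generated_text", str(result)))
else:
    print(str(result))

Setting return_full_text to False asks the endpoint for only the completion, not the prompt, so the result can be handed back to the Gradio UI as-is.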
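
One caveat with response.raise_for_status(): the serverless Inference API can answer 503 while a cold model is still loading, which this code would surface to the user as an error. A possible refinement is sketched below; it assumes the 503 body is JSON with an estimated_time hint, and the helper name query_with_retry is invented for illustration.

import time
import requests

def query_with_retry(api_url, headers, payload, retries=3):
    # Hypothetical helper: retry while the model is cold-loading (HTTP 503).
    for attempt in range(retries):
        response = requests.post(api_url, headers=headers, json=payload, timeout=60)
        if response.status_code == 503 and attempt < retries - 1:
            wait = 10.0
            try:
                # A loading model may report how long the cold start will take.
                wait = float(response.json().get("estimated_time", wait))
            except (ValueError, AttributeError):
                pass
            time.sleep(min(wait, 30))
            continue
        response.raise_for_status()
        return response.json()

Keeping raise_for_status() as the final check means genuine failures, such as 401 for a bad token or 404 for an unknown model id, still reach the existing except branch.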