CatoG committed on
Commit
0c0b9d5
·
verified ·
1 Parent(s): 09754a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -27
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
- from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
3
  from langchain_text_splitters import RecursiveCharacterTextSplitter
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
  from langchain_core.prompts import PromptTemplate
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables import RunnablePassthrough
 
9
 
10
  import gradio as gr
11
  import warnings
@@ -54,16 +55,9 @@ def get_huggingface_token():
54
  def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
55
  token = get_huggingface_token()
56
 
57
- # Use simpler initialization without specifying task
58
- # Let HuggingFace auto-detect the best configuration
59
- llm = HuggingFaceEndpoint(
60
- repo_id=model_id,
61
- max_new_tokens=max_tokens,
62
- temperature=temperature,
63
- huggingfacehub_api_token=token,
64
- timeout=120, # Increase timeout for large models
65
- )
66
- return llm
67
 
68
 
69
  # ---------------------------
@@ -146,31 +140,29 @@ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_m
146
 
147
  try:
148
  selected_model = model_choice or MODEL_OPTIONS[0]
149
- llm = get_llm(selected_model, int(max_tokens), float(temperature))
150
  retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
151
 
152
- # Create a simple RAG chain
153
- template = """Answer the question based only on the following context:
 
 
 
 
154
  {context}
155
 
156
- Question: {question}
157
 
158
  Answer:"""
159
 
160
- prompt = PromptTemplate.from_template(template)
161
-
162
- def format_docs(docs):
163
- return "\n\n".join(doc.page_content for doc in docs)
164
-
165
- # Build the chain
166
- rag_chain = (
167
- {"context": retriever_obj | format_docs, "question": RunnablePassthrough()}
168
- | prompt
169
- | llm
170
- | StrOutputParser()
171
  )
172
 
173
- response = rag_chain.invoke(query)
174
  return response
175
  except Exception as e:
176
  import traceback
 
1
  import os
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
  from langchain_text_splitters import RecursiveCharacterTextSplitter
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
  from langchain_core.prompts import PromptTemplate
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables import RunnablePassthrough
9
+ from huggingface_hub import InferenceClient
10
 
11
  import gradio as gr
12
  import warnings
 
55
  def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
56
  token = get_huggingface_token()
57
 
58
+ # Use InferenceClient directly for better reliability
59
+ client = InferenceClient(model=model_id, token=token)
60
+ return client, max_tokens, temperature
 
 
 
 
 
 
 
61
 
62
 
63
  # ---------------------------
 
140
 
141
  try:
142
  selected_model = model_choice or MODEL_OPTIONS[0]
143
+ client, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
144
  retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
145
 
146
+ # Get relevant documents
147
+ docs = retriever_obj.get_relevant_documents(query)
148
+ context = "\n\n".join(doc.page_content for doc in docs)
149
+
150
+ # Create prompt
151
+ prompt = f"""Answer the question based only on the following context:
152
  {context}
153
 
154
+ Question: {query}
155
 
156
  Answer:"""
157
 
158
+ # Call the model directly
159
+ response = client.text_generation(
160
+ prompt,
161
+ max_new_tokens=max_tok,
162
+ temperature=temp,
163
+ return_full_text=False
 
 
 
 
 
164
  )
165
 
 
166
  return response
167
  except Exception as e:
168
  import traceback