CatoG commited on
Commit
caca053
·
verified ·
1 Parent(s): a971c3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -8
app.py CHANGED
@@ -3,7 +3,9 @@ from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
3
  from langchain_text_splitters import RecursiveCharacterTextSplitter
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
- from langchain.chains.retrieval_qa.base import RetrievalQA
 
 
7
 
8
  import gradio as gr
9
  import warnings
@@ -149,14 +151,30 @@ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_m
149
  selected_model = model_choice or MODEL_OPTIONS[0]
150
  llm = get_llm(selected_model, int(max_tokens), float(temperature))
151
  retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
152
- qa = RetrievalQA.from_chain_type(
153
- llm=llm,
154
- chain_type="stuff",
155
- retriever=retriever_obj,
156
- return_source_documents=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  )
158
- response = qa.invoke({"query": query})
159
- return response['result']
 
160
  except Exception as e:
161
  return f"Error: {str(e)}"
162
 
 
3
  from langchain_text_splitters import RecursiveCharacterTextSplitter
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
+ from langchain_core.prompts import PromptTemplate
7
+ from langchain_core.output_parsers import StrOutputParser
8
+ from langchain_core.runnables import RunnablePassthrough
9
 
10
  import gradio as gr
11
  import warnings
 
151
  selected_model = model_choice or MODEL_OPTIONS[0]
152
  llm = get_llm(selected_model, int(max_tokens), float(temperature))
153
  retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
154
+
155
+ # Create a simple RAG chain
156
+ template = """Answer the question based only on the following context:
157
+ {context}
158
+
159
+ Question: {question}
160
+
161
+ Answer:"""
162
+
163
+ prompt = PromptTemplate.from_template(template)
164
+
165
+ def format_docs(docs):
166
+ return "\n\n".join(doc.page_content for doc in docs)
167
+
168
+ # Build the chain
169
+ rag_chain = (
170
+ {"context": retriever_obj | format_docs, "question": RunnablePassthrough()}
171
+ | prompt
172
+ | llm
173
+ | StrOutputParser()
174
  )
175
+
176
+ response = rag_chain.invoke(query)
177
+ return response
178
  except Exception as e:
179
  return f"Error: {str(e)}"
180