"""QA chain construction and invocation helpers.

Builds the retrieval QA chain from ``src.retrieval.retriever_chain`` and
exposes response generators for the Streamlit and Gradio front ends.
"""

import time
from typing import Optional

from src.retrieval.retriever_chain import (
    create_qa_chain,
    create_qa_chain_eval,
    get_base_retriever,
    load_hf_llm,
)

HF_MODEL = "HuggingFaceH4/zephyr-7b-beta"
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"


def get_qa_chain():
    """
    Instantiates the QA chain.

    Returns:
        Runnable: An instance of the QA chain.
    """
    retriever = get_base_retriever(embedding_model=EMBEDDING_MODEL, k=4, search_type="mmr")
    llm = load_hf_llm(repo_id=HF_MODEL, max_new_tokens=512, temperature=0.4)
    qa_chain = create_qa_chain(retriever, llm)
    return qa_chain


def set_global_qa_chain(local_qa_chain):
    """
    Sets the global QA chain.

    Args:
        local_qa_chain: Local QA chain instance.
    """
    global global_qa_chain
    global_qa_chain = local_qa_chain


def invoke_chain(query: str):
    """
    Invokes the chain to generate a response, retrying on failure.

    Args:
        query (str): Question asked by the user.

    Returns:
        str: The generated response, or an error message if every
        attempt fails.
    """
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            return global_qa_chain.invoke(query)
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error:", e)
    return "All attempts failed. Unable to get response."


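# Typical wiring (a sketch; assumes the vector store behind
# `get_base_retriever` is already populated and Hugging Face credentials
# are available to `load_hf_llm`):
#
#     set_global_qa_chain(get_qa_chain())
#     print(invoke_chain("What does the knowledge base say about X?"))

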
def generate_response(query: str):
    """
    Generates a response to the question being asked.

    Args:
        query (str): Question asked by the user.

    Returns:
        str: The generated response.
    """
    print("*" * 100)
    print("Question:", query)
    start_time = time.time()
    try:
        response = global_qa_chain.invoke(query)
    except Exception as e:
        # Retry once on a transient error; a second failure propagates.
        print("Error:", e)
        response = global_qa_chain.invoke(query)
    print("Answer:", response)
    end_time = time.time()
    print("Response Time:", "{:.2f}".format(end_time - start_time))
    print("*" * 100)
    return response


def generate_response_streamlit(message: str, history: Optional[dict]):
    """
    Generates a response to the question being asked, streamed word by word.

    Args:
        message (str): Question asked by the user.
        history (dict): Chat history. NOT USED FOR NOW.

    Yields:
        str: The next word of the generated response.
    """
    response = generate_response(message)
    # Escape characters that Streamlit's Markdown renderer would otherwise
    # interpret, and turn newlines into Markdown hard line breaks.
    response = response.replace("$", "\\$").replace(":", "\\:").replace("\n", "  \n")
    for word in response.split(" "):
        yield word + " "
        time.sleep(0.05)


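# Front-end sketch (assumes a Streamlit app imports this module;
# `st.write_stream` consumes a generator such as the one above):
#
#     import streamlit as st
#     if prompt := st.chat_input("Ask a question"):
#         with st.chat_message("assistant"):
#             st.write_stream(generate_response_streamlit(prompt, None))

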
def generate_response_gradio(message: str, history: Optional[dict]):
    """
    Generates a response to the question being asked, streamed character
    by character.

    Args:
        message (str): Question asked by the user.
        history (dict): Chat history. NOT USED FOR NOW.

    Yields:
        str: The response generated so far.
    """
    response = generate_response(message)
    for i in range(len(response)):
        time.sleep(0.01)
        yield response[: i + 1]


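# Front-end sketch (`gr.ChatInterface` streams generators that yield the
# accumulated response, which matches the function above):
#
#     import gradio as gr
#     gr.ChatInterface(generate_response_gradio).launch()

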
def has_global_variable():
    """
    Checks whether global_qa_chain has been set.

    Returns:
        bool: True if set, otherwise False.
    """
    return "global_qa_chain" in globals()


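# Lazy-initialisation sketch (`ensure_qa_chain` is a hypothetical helper;
# handy when a front end reruns without repeating module-level setup):
#
#     def ensure_qa_chain():
#         if not has_global_variable():
#             set_global_qa_chain(get_qa_chain())

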
def get_qa_chain_eval():
    """
    Instantiates the QA chain for evaluation.

    Returns:
        Runnable: An instance of the QA chain.
    """
    retriever = get_base_retriever(embedding_model=EMBEDDING_MODEL, k=4, search_type="mmr")
    llm = load_hf_llm(repo_id=HF_MODEL, max_new_tokens=512, temperature=0.4)
    qa_chain = create_qa_chain_eval(retriever, llm)
    return qa_chain
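

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes the vector store is populated
    # and a Hugging Face API token is set in the environment).
    set_global_qa_chain(get_qa_chain())
    print(generate_response("What topics does the knowledge base cover?"))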