File size: 4,335 Bytes
0a6f6d8 f19541b 27b774b 5e8a58c 0a6f6d8 789a1ff 27b774b 0a6f6d8 27b774b 0a6f6d8 05abef9 0a6f6d8 27b774b 0a6f6d8 05abef9 27b774b 0a6f6d8 27b774b 0a6f6d8 27b774b f19541b 05abef9 f19541b 27b774b 0a6f6d8 27b774b b1bcc1e 05abef9 27b774b 05abef9 5e8a58c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
# import libraries
import time
from typing import Optional
# import functions
from src.retrieval.retriever_chain import get_base_retriever, load_hf_llm, create_qa_chain, create_qa_chain_eval
# constants
HF_MODEL = "huggingfaceh4/zephyr-7b-beta" # "mistralai/Mistral-7B-Instruct-v0.2" # "google/gemma-7b"
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
# get the qa chain
def get_qa_chain():
    """
    Builds and returns the question-answering chain.

    Returns:
        Runnable: A QA chain wired to the default retriever and LLM.
    """
    # build the retriever over the vector store using MMR search
    base_retriever = get_base_retriever(embedding_model=EMBEDDING_MODEL, k=4, search_type="mmr")
    # spin up the hosted Hugging Face LLM endpoint
    language_model = load_hf_llm(repo_id=HF_MODEL, max_new_tokens=512, temperature=0.4)
    # wire retriever + llm together into the chain
    return create_qa_chain(base_retriever, language_model)
# function to get the global qa chain
def set_global_qa_chain(local_qa_chain):
    """
    Publishes the supplied chain as the module-level ``global_qa_chain``.

    Args:
        local_qa_chain: QA chain instance to make globally accessible.
    """
    # write straight into the module namespace; equivalent to declaring
    # ``global global_qa_chain`` and assigning
    globals()["global_qa_chain"] = local_qa_chain
# function to invoke the rag chain
def invoke_chain(query: str):
    """
    Runs the global QA chain on a query, retrying on failure.

    Args:
        query (str): Question asked by the user.

    Returns:
        str: The generated response, or a fixed failure message after
        three unsuccessful attempts.
    """
    attempts_allowed = 3  # maximum number of retry attempts
    attempt = 0
    while attempt < attempts_allowed:
        try:
            # success exits the function immediately
            return global_qa_chain.invoke(query)
        except Exception as err:
            print(f"Attempt {attempt + 1} failed with error:", err)
        attempt += 1
    # every attempt raised — report failure to the caller
    return "All attempts failed. Unable to get response."
# function to generate streamlit response
def generate_response(query: str):
    """
    Generates a response for the given question, logging the Q/A pair
    and the response time to stdout.

    Args:
        query (str): Question asked by the user.

    Returns:
        str: The generated response, or a failure message if both the
        initial call and the single retry raise.
    """
    # log the question
    print("*" * 100)
    print("Question:", query)
    start_time = time.time()
    try:
        response = global_qa_chain.invoke(query)
    except Exception as e:
        print("Error:", e)
        # retry once; guard it so a second failure cannot escape and skip
        # the answer/timing logging below (the original retry was unprotected)
        try:
            response = global_qa_chain.invoke(query)
        except Exception as retry_error:
            print("Error:", retry_error)
            # same failure message used by invoke_chain, for consistency
            response = "All attempts failed. Unable to get response."
    # log the answer and the elapsed time
    print("Answer:", response)
    end_time = time.time()
    print("Response Time:", "{:.2f}".format(round(end_time - start_time, 2)))
    print("*" * 100)
    return response
# function to generate streamlit response
def generate_response_streamlit(message: str, history: Optional[dict]):
    """
    Streams the generated response word by word for Streamlit.

    Args:
        message (str): Question asked by the user.
        history (Optional[dict]): Chat history. NOT USED FOR NOW.

    Yields:
        str: The next word of the response, followed by a space.
    """
    response = generate_response(message)
    # Use raw strings: "\$" and "\:" are invalid escape sequences that raise
    # a SyntaxWarning on Python 3.12+ (and will become errors). Runtime bytes
    # are unchanged. "$" / ":" are escaped so Streamlit markdown renders them
    # literally; "\n" is padded so markdown keeps the line break.
    response = response.replace("$", r"\$").replace(":", r"\:").replace("\n", " \n")
    for word in response.split(" "):
        yield word + " "
        time.sleep(0.05)  # small delay for the typing effect
# function to generate response
def generate_response_gradio(message: str, history: Optional[dict]):
    """
    Streams the generated response character by character for Gradio.

    Args:
        message (str): Question asked by the user.
        history (Optional[dict]): Chat history. NOT USED FOR NOW.

    Yields:
        str: A growing prefix of the response, one extra character each time.
    """
    answer = generate_response(message)
    shown = 0
    # reveal one more character per step for the typing effect
    while shown < len(answer):
        time.sleep(0.01)
        shown += 1
        yield answer[:shown]
# function to check if the global variable has been set
def has_global_variable():
    """
    Checks whether the module-level ``global_qa_chain`` has been set.

    Returns:
        bool: True if ``global_qa_chain`` exists in the module globals,
        otherwise False.
    """
    # direct membership test replaces the explicit if/return pair
    return "global_qa_chain" in globals()
# get the qa chain for evaluation
def get_qa_chain_eval():
    """
    Builds and returns the question-answering chain used for evaluation.

    Returns:
        Runnable: A QA chain wired to the default retriever and LLM.
    """
    # same retriever/llm configuration as get_qa_chain, but assembled with
    # the evaluation chain factory
    eval_retriever = get_base_retriever(embedding_model=EMBEDDING_MODEL, k=4, search_type="mmr")
    eval_llm = load_hf_llm(repo_id=HF_MODEL, max_new_tokens=512, temperature=0.4)
    return create_qa_chain_eval(eval_retriever, eval_llm)