# import libraries
import time
from typing import Optional

# import functions
from src.retrieval.retriever_chain import get_base_retriever, load_hf_llm, create_qa_chain, create_qa_chain_eval

# constants
HF_MODEL        = "huggingfaceh4/zephyr-7b-beta"  # "mistralai/Mistral-7B-Instruct-v0.2" # "google/gemma-7b"
EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"


# get the qa chain
def get_qa_chain():
    """
    Instantiates QA Chain.

    Returns:
        Runnable: Returns an instance of QA Chain.
    """

    # get retriever
    retriever = get_base_retriever(embedding_model=EMBEDDING_MODEL, k=4, search_type="mmr")

    # instantiate llm
    llm = load_hf_llm(repo_id=HF_MODEL, max_new_tokens=512, temperature=0.4)

    # instantiate qa chain
    qa_chain = create_qa_chain(retriever, llm)

    return qa_chain

# function to set the global qa chain
def set_global_qa_chain(local_qa_chain):
    """
    Sets the Global QA Chain.

    Args:
        local_qa_chain: Local QA Chain.
    """
    global global_qa_chain
    global_qa_chain = local_qa_chain

# function to invoke the rag chain
def invoke_chain(query: str):
    """
    Invokes the chain to generate the response.

    Args:
        query (str): Question asked by the user.

    Returns:
        str: Returns the generated response.
    """
    max_attempts = 3  # maximum number of retry attempts

    for attempt in range(max_attempts):
        try:
            response = global_qa_chain.invoke(query)
            return response  # return the response on the first successful attempt
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error:", e)

    # reached only if every attempt raised an exception
    return "All attempts failed. Unable to get response."

    
# function to generate a response and log timing
def generate_response(query: str):
    """
    Generates a response based on the question being asked.

    Args:
        query (str): Question asked by the user.

    Returns:
        str: Returns the generated response.
    """

    # invoke chain (with retries) and log the question, answer, and response time
    print("*" * 100)
    print("Question:", query)
    start_time = time.time()
    response = invoke_chain(query)
    end_time = time.time()
    print("Answer:", response)
    print("Response Time:", "{:.2f}".format(end_time - start_time))
    print("*" * 100)
    return response

# function to generate streamlit response
def generate_response_streamlit(message: str, history: Optional[dict]):
    """
    Generates response based on the question being asked.

    Args:
        message (str): Question asked by the user.
        history (dict): Chat history. NOT USED FOR NOW.

    Returns:
        str: Returns the generated response.
    """

    response = generate_response(message)
    response = response.replace("$", "\$").replace(":", "\:").replace("\n", "  \n")
    for word in response.split(" "):
        yield word + " "
        time.sleep(0.05)
        
# function to generate gradio response
def generate_response_gradio(message: str, history: Optional[dict]):
    """
    Generates a streamed response based on the question being asked.

    Args:
        message (str): Question asked by the user.
        history (Optional[dict]): Chat history. NOT USED FOR NOW.

    Yields:
        str: Returns the generated response as a progressively longer prefix,
            one character at a time (the streaming style Gradio expects).
    """

    response = generate_response(message)
    for i in range(len(response)):
        time.sleep(0.01)
        yield response[: i + 1]

# function to check if the global variable has been set
def has_global_variable():
    """
    Checks if global_qa_chain has been set.

    Returns:
        bool: Returns True if set. Otherwise False.
    """
    return 'global_qa_chain' in globals()

# get the qa chain for evaluation
def get_qa_chain_eval():
    """
    Instantiates QA Chain for evaluation.

    Returns:
        Runnable: Returns an instance of QA Chain.
    """

    # get retriever
    retriever = get_base_retriever(embedding_model=EMBEDDING_MODEL, k=4, search_type="mmr")

    # instantiate llm
    llm = load_hf_llm(repo_id=HF_MODEL, max_new_tokens=512, temperature=0.4)

    # instantiate qa chain
    qa_chain = create_qa_chain_eval(retriever, llm)

    return qa_chain
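
# A minimal usage sketch (an illustrative addition, not part of the original
# module): wires the helpers above together for a command-line smoke test.
# It assumes the src.retrieval package is importable and that the environment
# provides valid Hugging Face credentials (e.g. HUGGINGFACEHUB_API_TOKEN);
# the question string is a hypothetical example.
if __name__ == "__main__":
    local_qa_chain = get_qa_chain()
    set_global_qa_chain(local_qa_chain)
    if has_global_variable():
        print(invoke_chain("What topics does this knowledge base cover?"))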