krishnasimha commited on
Commit
fc0c80d
·
verified ·
1 Parent(s): f096e98

Upload app3.py

Browse files
Files changed (1) hide show
  1. app3.py +267 -0
app3.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import time
4
+ from sentence_transformers import SentenceTransformer
5
+ import os
6
+ import datetime
7
+ import feedparser
8
+ from huggingface_hub import hf_hub_download
9
+ import faiss, pickle
10
+ import aiohttp
11
+ import asyncio
12
+
13
+ # -------------------
14
+ # Load prebuilt index
15
+ # -------------------
16
+ import sqlite3
17
+
18
def init_cache_db():
    """Open (or create) the on-disk query cache and ensure its schema exists.

    Returns:
        sqlite3.Connection: open connection to ``query_cache.db`` with the
        ``cache`` table created if it was missing.
    """
    connection = sqlite3.connect("query_cache.db")
    ddl = (
        "CREATE TABLE IF NOT EXISTS cache ("
        " id INTEGER PRIMARY KEY AUTOINCREMENT,"
        " query TEXT UNIQUE,"
        " answer TEXT,"
        " embedding BLOB,"
        " frequency INTEGER DEFAULT 1"
        ")"
    )
    # Connection.execute opens an implicit cursor; same effect as the
    # explicit cursor() + execute() pair.
    connection.execute(ddl)
    connection.commit()
    return connection
32
+
33
# Module-level connection shared by the cache helpers below.
# NOTE(review): sqlite3 connections are thread-bound by default
# (check_same_thread=True), and Streamlit may rerun this script on a
# different thread — confirm this does not raise sqlite3.ProgrammingError
# in deployment.
cache_conn = init_cache_db()
34
+
35
def store_in_cache(query, answer, embedding):
    """Insert or update a cached (query, answer) pair.

    A new query starts with frequency 1; re-storing an existing query
    updates the row in place and increments its frequency.

    Args:
        query: raw user query text (UNIQUE key in the cache table).
        answer: generated answer to cache.
        embedding: 1-D numpy array for the query, stored as raw bytes.
            Read back with np.frombuffer(..., dtype=np.float32), so it is
            assumed to be float32 (the SentenceTransformer default) —
            TODO confirm.
    """
    c = cache_conn.cursor()
    # UPSERT (SQLite >= 3.24) instead of INSERT OR REPLACE: REPLACE deletes
    # and re-inserts the row, churning the AUTOINCREMENT id and firing any
    # delete triggers; ON CONFLICT updates the existing row in place while
    # keeping the same frequency semantics (old + 1).
    c.execute(
        """
        INSERT INTO cache (query, answer, embedding, frequency)
        VALUES (?, ?, ?, 1)
        ON CONFLICT(query) DO UPDATE SET
            answer = excluded.answer,
            embedding = excluded.embedding,
            frequency = frequency + 1
        """,
        (query, answer, embedding.tobytes()),
    )
    cache_conn.commit()
46
+
47
+
48
def search_cache(query, embed_model, threshold=0.85):
    """Look up a semantically similar previous query in the cache.

    Embeds *query*, scans every cached row, and returns the stored answer
    of the most cosine-similar cached query whose similarity exceeds
    *threshold*; returns None on a miss.

    Args:
        query: raw user query text.
        embed_model: SentenceTransformer-style model with .encode().
        threshold: minimum cosine similarity to count as a cache hit.
    """
    q_emb = embed_model.encode([query], convert_to_numpy=True)[0]
    # Hoisted: the query norm is invariant across rows; also guards the
    # degenerate all-zero embedding (would divide by zero below).
    q_norm = np.linalg.norm(q_emb)
    if q_norm == 0:
        return None

    c = cache_conn.cursor()
    c.execute("SELECT query, answer, embedding, frequency FROM cache")

    best_sim = -1.0
    best_answer = None
    for _qry, ans, emb_blob, _freq in c.fetchall():
        # Blobs are assumed float32 (SentenceTransformer default) — they are
        # written by store_in_cache with embedding.tobytes().
        emb = np.frombuffer(emb_blob, dtype=np.float32).reshape(-1)
        emb_norm = np.linalg.norm(emb)
        if emb_norm == 0:
            continue  # guard: skip degenerate rows instead of dividing by zero
        sim = np.dot(q_emb, emb) / (q_norm * emb_norm)
        if sim > threshold and sim > best_sim:
            best_sim = sim
            best_answer = ans

    return best_answer
72
+
73
+
74
@st.cache_resource
def load_index():
    """Download and load the prebuilt FAISS index, its metadata, and the
    sentence-embedding model (cached across Streamlit reruns by
    @st.cache_resource).

    Returns:
        tuple: (faiss index, metadata dict, SentenceTransformer model).
    """
    repo = "krishnasimha/health-chatbot-data"
    paths = {
        name: hf_hub_download(repo_id=repo, filename=name, repo_type="dataset")
        for name in ("health_index.faiss", "health_metadata.pkl")
    }

    index = faiss.read_index(paths["health_index.faiss"])
    with open(paths["health_metadata.pkl"], "rb") as fh:
        metadata = pickle.load(fh)

    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    return index, metadata, embed_model
93
+
94
# Materialize the shared retrieval components once at import time
# (served from the st.cache_resource cache on reruns).
index, metadata, embed_model = load_index()
95
+
96
# -------------------
# FAISS Benchmark
# -------------------
def benchmark_faiss(n_queries=100, k=3):
    """Time k-NN searches against the loaded FAISS index and report the
    average latency (ms/query) in the Streamlit sidebar.

    Args:
        n_queries: number of timed searches to run.
        k: neighbors requested per search.
    """
    sample_queries = ["What is diabetes?", "How to prevent malaria?", "Symptoms of dengue?"]
    sample_embs = embed_model.encode(sample_queries, convert_to_numpy=True)

    elapsed = []
    for _ in range(n_queries):
        # Pick one of the pre-encoded sample queries at random.
        pick = np.random.randint(0, len(sample_embs))
        vec = sample_embs[pick].reshape(1, -1)
        t0 = time.time()
        index.search(vec, k)
        elapsed.append(time.time() - t0)

    avg_ms = np.mean(elapsed) * 1000
    st.sidebar.write(f"⚡ FAISS Benchmark: {avg_ms:.2f} ms/query over {n_queries} queries")
112
+
113
# -------------------
# Chat session management
# -------------------
# Seed message shared by every new conversation.
_SYSTEM_MESSAGE = {"role": "system", "content": "You are a helpful public health awareness chatbot."}

state = st.session_state
if "chats" not in state:
    state.chats = {}
if "current_chat" not in state:
    state.current_chat = "New Chat 1"
    state.chats["New Chat 1"] = [dict(_SYSTEM_MESSAGE)]

st.sidebar.header("Chat Manager")

if st.sidebar.button("➕ New Chat"):
    fresh_name = f"New Chat {len(state.chats) + 1}"
    state.chats[fresh_name] = [dict(_SYSTEM_MESSAGE)]
    state.current_chat = fresh_name

benchmark_faiss()

# Chat picker: keep current_chat in sync with the sidebar selection.
chat_names = list(state.chats)
state.current_chat = st.sidebar.selectbox(
    "Your chats:", chat_names, index=chat_names.index(state.current_chat)
)

# Rename only when the new name is non-empty, different, and unused.
renamed = st.sidebar.text_input("Rename Chat:", state.current_chat)
if renamed and renamed != state.current_chat and renamed not in state.chats:
    state.chats[renamed] = state.chats.pop(state.current_chat)
    state.current_chat = renamed
145
+
146
# -------------------
# RSS News Fetcher (async)
# -------------------
RSS_URL = "https://news.google.com/rss/search?q=health+disease+awareness&hl=en-IN&gl=IN&ceid=IN:en"

async def fetch_rss_url(url):
    """Asynchronously GET *url* and return the response body as text."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            # Surface HTTP errors instead of handing an error page to feedparser.
            resp.raise_for_status()
            return await resp.text()

def fetch_news():
    """Fetch and parse the health-news RSS feed.

    Returns:
        list[dict]: up to 5 articles with 'title', 'link' and 'published'
        keys; missing fields fall back to "" — not every RSS entry carries
        a published date, and attribute access would raise AttributeError.
    """
    raw_xml = asyncio.run(fetch_rss_url(RSS_URL))
    feed = feedparser.parse(raw_xml)
    articles = []
    for entry in feed.entries[:5]:
        articles.append({
            "title": entry.get("title", ""),
            "link": entry.get("link", ""),
            "published": entry.get("published", ""),  # fix: bare entry.published crashed on date-less entries
        })
    return articles
167
+
168
def update_news_hourly():
    """Refresh the cached news articles at most once per hour.

    Bug fix: the original compared timedelta.seconds, which is only the
    seconds *component* of the delta (0-86399, resetting every full day),
    so an app left open for slightly over a day would skip the refresh.
    total_seconds() is the true elapsed time.
    """
    now = datetime.datetime.now()
    if "last_news_update" not in st.session_state or \
            (now - st.session_state.last_news_update).total_seconds() > 3600:
        st.session_state.last_news_update = now
        st.session_state.news_articles = fetch_news()
173
+
174
# -------------------
# Async Together API
# -------------------
async def async_together_chat(messages):
    """Send the chat history to the Together chat-completions API and
    return the assistant's reply text.

    Args:
        messages: list of {"role", "content"} dicts (full conversation).

    Raises:
        KeyError: if TOGETHER_API_KEY is not set in the environment.
        aiohttp.ClientResponseError: on a non-2xx HTTP status — checked
            explicitly so an API failure no longer surfaces as a confusing
            KeyError on the missing 'choices' field.
    """
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "deepseek-ai/DeepSeek-V3",
        "messages": messages,
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            resp.raise_for_status()
            result = await resp.json()
    return result["choices"][0]["message"]["content"]
192
+
193
# -------------------
# Query function (async call inside)
# -------------------
def retrieve_answer(query, k=3):
    """Answer *query*: semantic-cache lookup first, then RAG over FAISS.

    Returns:
        tuple: (answer text, list of sources); sources is empty on a
        cache hit.
    """
    # Fast path: a sufficiently similar past query short-circuits the
    # whole retrieval + LLM pipeline.
    hit = search_cache(query, embed_model)
    if hit:
        st.sidebar.success("⚡ Retrieved from cache")
        return hit, []

    # Slow path: embed, retrieve top-k passages, build the grounded prompt.
    query_emb = embed_model.encode([query], convert_to_numpy=True)
    _, neighbor_ids = index.search(query_emb, k)
    top = neighbor_ids[0]
    passages = [metadata["texts"][i] for i in top]
    sources = [metadata["sources"][i] for i in top]
    context = "\n".join(passages)

    # Append to (and send) the live conversation for the active chat;
    # `history` aliases the session-state list, so appends persist.
    history = st.session_state.chats[st.session_state.current_chat]
    history.append({
        "role": "user",
        "content": f"Answer based on the context below:\n\n{context}\n\nQuestion: {query}",
    })

    answer = asyncio.run(async_together_chat(history))

    # Remember this Q/A so an equivalent future query hits the fast path.
    store_in_cache(query, answer, query_emb[0])

    history.append({"role": "assistant", "content": answer})
    return answer, sources
224
+
225
+
226
# -------------------
# Background task for news refresh
# -------------------
async def background_news_updater():
    # Intended to refresh the article list every hour in the background.
    # NOTE(review): fetch_news() calls asyncio.run() internally; awaiting this
    # coroutine inside a running event loop would raise RuntimeError
    # ("asyncio.run() cannot be called from a running event loop"). In practice
    # it never executes — see the scheduling block below — so the bug is latent.
    # NOTE(review): mutating st.session_state from a non-Streamlit thread/loop
    # is unsupported; confirm against Streamlit's threading docs before enabling.
    while True:
        st.session_state.news_articles = fetch_news()
        await asyncio.sleep(3600)  # refresh every hour
233
+
234
+ if "news_task" not in st.session_state:
235
+ loop = asyncio.new_event_loop()
236
+ asyncio.set_event_loop(loop)
237
+ st.session_state.news_task = loop.create_task(background_news_updater())
238
+
239
# -------------------
# Streamlit UI
# -------------------
st.title(st.session_state.current_chat)

update_news_hourly()
st.subheader("📰 Latest Health Updates")
if "news_articles" in st.session_state:
    for article in st.session_state.news_articles:
        st.markdown(
            f"**{article['title']}** \n[Read more]({article['link']}) \n*Published: {article['published']}*"
        )
st.write("---")

user_query = st.text_input("Ask me about health, prevention, or awareness:")

if user_query:
    with st.spinner("Searching knowledge base..."):
        answer, sources = retrieve_answer(user_query)

    st.write("### 💡 Answer")
    st.write(answer)

    st.write("### 📖 Sources")
    for source in sources:
        st.write(f"- {source}")

# Render the running transcript of the active chat (system messages hidden).
for message in st.session_state.chats[st.session_state.current_chat]:
    if message["role"] == "user":
        st.write(f"🧑 **You:** {message['content']}")
    elif message["role"] == "assistant":
        st.write(f"🤖 **Bot:** {message['content']}")