krishnasimha committed
Commit 92ca83b · verified
1 Parent(s): 8eb4002

Upload app5.py

Files changed (1)
  1. app5.py +438 -0
app5.py ADDED
@@ -0,0 +1,438 @@
# app.py
import os
import time
import datetime
import asyncio
import sqlite3
import pickle

import streamlit as st
import numpy as np
import pandas as pd
import feedparser
import aiohttp

from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from huggingface_hub import hf_hub_download

# Optional S3 support
try:
    import boto3
    BOTO3_AVAILABLE = True
except Exception:
    BOTO3_AVAILABLE = False

import faiss

# -------------------------
# Initialize DB & helpers
# -------------------------
DB_PATH = "query_cache.db"

def init_cache_db():
    conn = sqlite3.connect(DB_PATH)
    c = conn.cursor()
    c.execute("""
        CREATE TABLE IF NOT EXISTS cache (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query TEXT UNIQUE,
            answer TEXT,
            embedding BLOB,
            frequency INTEGER DEFAULT 1
        )
    """)
    conn.commit()
    conn.close()

def init_export_logs():
    conn = sqlite3.connect(DB_PATH)
    c = conn.cursor()
    c.execute("""
        CREATE TABLE IF NOT EXISTS export_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            exported_on TEXT,
            file_name TEXT
        )
    """)
    conn.commit()
    conn.close()

init_cache_db()
init_export_logs()

def get_db_connection():
    return sqlite3.connect(DB_PATH)

# -------------------------
# Cache store/search
# -------------------------
def store_in_cache(query, answer, embedding):
    conn = get_db_connection()
    c = conn.cursor()
    c.execute("""
        INSERT OR REPLACE INTO cache (query, answer, embedding, frequency)
        VALUES (?, ?, ?, COALESCE(
            (SELECT frequency FROM cache WHERE query=?), 0
        ) + 1)
    """, (query, answer, embedding.tobytes(), query))
    conn.commit()
    conn.close()

def search_cache(query, embed_model, threshold=0.85):
    q_emb = embed_model.encode([query], convert_to_numpy=True)[0]

    conn = get_db_connection()
    c = conn.cursor()
    c.execute("SELECT query, answer, embedding, frequency FROM cache")
    rows = c.fetchall()
    conn.close()

    best_sim = -1
    best_row = None

    for qry, ans, emb_blob, freq in rows:
        try:
            emb = np.frombuffer(emb_blob, dtype=np.float32)
        except Exception:
            continue
        emb = emb.reshape(-1)
        sim = np.dot(q_emb, emb) / (np.linalg.norm(q_emb) * np.linalg.norm(emb) + 1e-12)
        if sim > threshold and sim > best_sim:
            best_sim = sim
            best_row = (qry, ans, freq)

    if best_row:
        return best_row[1]
    return None

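# Minimal usage sketch of the cache path above (illustrative only, never called
# by the app): store one answer, then a paraphrased query that clears the 0.85
# cosine threshold should be served from SQLite instead of the LLM.
def _cache_roundtrip_example(model):
    emb = model.encode(["What is dengue fever?"], convert_to_numpy=True)[0]
    store_in_cache("What is dengue fever?", "Dengue is a mosquito-borne viral infection.", emb)
    # Returns the cached answer if the paraphrase is similar enough, else None.
    return search_cache("Tell me about dengue fever", model)
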
# -------------------------
# Exports
# -------------------------
def export_cache_to_excel():
    conn = get_db_connection()
    df = pd.read_sql_query("SELECT * FROM cache", conn)
    conn.close()
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"cache_export_{timestamp}.xlsx"
    df.to_excel(file_name, index=False)
    # log export
    conn = get_db_connection()
    c = conn.cursor()
    c.execute("INSERT INTO export_logs (exported_on, file_name) VALUES (?, ?)",
              (datetime.datetime.now().isoformat(), file_name))
    conn.commit()
    conn.close()
    return file_name

def export_cache_to_sql():
    conn = get_db_connection()
    dump_path = f"cache_dump_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.sql"
    with open(dump_path, "w", encoding="utf-8") as f:
        for line in conn.iterdump():
            f.write("%s\n" % line)
    conn.close()
    # log export
    conn = get_db_connection()
    c = conn.cursor()
    c.execute("INSERT INTO export_logs (exported_on, file_name) VALUES (?, ?)",
              (datetime.datetime.now().isoformat(), dump_path))
    conn.commit()
    conn.close()
    return dump_path

# -------------------------
# Optional: S3 upload
# -------------------------
def upload_file_to_s3(local_path, bucket_name, object_name=None):
    if not BOTO3_AVAILABLE:
        return False, "boto3 not installed"
    if object_name is None:
        object_name = os.path.basename(local_path)
    try:
        s3 = boto3.client(
            "s3",
            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
            region_name=os.environ.get("AWS_DEFAULT_REGION")
        )
        s3.upload_file(local_path, bucket_name, object_name)
        return True, f"s3://{bucket_name}/{object_name}"
    except Exception as e:
        return False, str(e)

# -------------------------
# Load FAISS index
# -------------------------
@st.cache_resource
def load_index():
    faiss_path = hf_hub_download("krishnasimha/health-chatbot-data", "health_index.faiss", repo_type="dataset")
    pkl_path = hf_hub_download("krishnasimha/health-chatbot-data", "health_metadata.pkl", repo_type="dataset")
    index = faiss.read_index(faiss_path)
    with open(pkl_path, "rb") as f:
        metadata = pickle.load(f)
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    return index, metadata, embed_model

index, metadata, embed_model = load_index()

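# Sketch of the artifact layout this app assumes (the index-building script is
# not part of this upload): retrieve_answer() below reads metadata["texts"][i]
# and metadata["sources"][i], so health_metadata.pkl is expected to hold two
# parallel lists aligned with the vectors in health_index.faiss. The function
# name and the flat L2 index type here are illustrative, not taken from the
# dataset repo.
def _build_index_example(chunks, sources, model):
    embs = model.encode(chunks, convert_to_numpy=True).astype("float32")
    idx = faiss.IndexFlatL2(embs.shape[1])  # flat L2 index over MiniLM vectors
    idx.add(embs)
    meta = {"texts": list(chunks), "sources": list(sources)}
    return idx, meta
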
# -------------------------
# Load Reranker (Cross-Encoder)
# -------------------------
@st.cache_resource
def load_reranker():
    # Cross-encoder: good speed/quality tradeoff
    return CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

reranker = load_reranker()

# -------------------------
# FAISS benchmark
# -------------------------
def benchmark_faiss(n_queries=100, k=3):
    queries = ["What is diabetes?", "How to prevent malaria?", "Symptoms of dengue?"]
    query_embs = embed_model.encode(queries, convert_to_numpy=True)
    times = []
    for _ in range(n_queries):
        q = query_embs[np.random.randint(0, len(query_embs))].reshape(1, -1)
        start = time.time()
        D, I = index.search(q, k)
        times.append(time.time() - start)
    avg_time = np.mean(times) * 1000
    st.sidebar.write(f"⚡ FAISS Benchmark: {avg_time:.2f} ms/query over {n_queries} queries")

# -------------------------
# RSS / News
# -------------------------
RSS_URL = "https://news.google.com/rss/search?q=health+disease+awareness&hl=en-IN&gl=IN&ceid=IN:en"

async def fetch_rss_url(url):
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            return await resp.text()

def fetch_news():
    try:
        raw_xml = asyncio.run(fetch_rss_url(RSS_URL))
        feed = feedparser.parse(raw_xml)
        articles = [{"title": e.get("title", ""), "link": e.get("link", ""), "published": e.get("published", "")} for e in feed.entries[:5]]
        return articles
    except Exception:
        return []

def update_news_hourly():
    now = datetime.datetime.now()
    # total_seconds() so the hourly check also holds across day boundaries
    # (timedelta.seconds wraps back to 0 after each full day)
    if "last_news_update" not in st.session_state or (now - st.session_state.last_news_update).total_seconds() > 3600:
        st.session_state.last_news_update = now
        st.session_state.news_articles = fetch_news()

# -------------------------
# Together API
# -------------------------
async def async_together_chat(messages):
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.environ.get('TOGETHER_API_KEY','')}",
        "Content-Type": "application/json",
    }
    payload = {"model": "deepseek-ai/DeepSeek-V3", "messages": messages}
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            result = await resp.json()
            return result["choices"][0]["message"]["content"]

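# Defensive variant (illustrative sketch, not used by the app): the call above
# raises a bare KeyError whenever the response has no "choices" key, e.g. on a
# missing API key or a wrong model name. Surfacing the raw payload instead
# makes such failures readable in the UI's "Error: ..." message.
async def _async_together_chat_checked(messages):
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.environ.get('TOGETHER_API_KEY','')}",
        "Content-Type": "application/json",
    }
    payload = {"model": "deepseek-ai/DeepSeek-V3", "messages": messages}
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            result = await resp.json()
    if "choices" not in result:
        raise RuntimeError(f"Together API error: {result}")
    return result["choices"][0]["message"]["content"]
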
# -------------------------
# Retrieve answer (with reranker)
# -------------------------
def retrieve_answer(query, k=3):
    cached_answer = search_cache(query, embed_model)
    if cached_answer:
        st.sidebar.success("⚡ Retrieved from cache")
        return cached_answer, []

    # Encode query
    query_emb = embed_model.encode([query], convert_to_numpy=True)

    # 1) FAISS retrieves more candidates (we fetch 10 for reranking)
    fetch_k = max(k, 10)
    D, I = index.search(query_emb, fetch_k)

    retrieved = [metadata["texts"][i] for i in I[0]]
    sources = [metadata["sources"][i] for i in I[0]]

    # -----------------------------
    # Cross-Encoder Reranking step
    # -----------------------------
    try:
        pairs = [[query, chunk] for chunk in retrieved]
        scores = reranker.predict(pairs)
        reranked = sorted(zip(scores, retrieved, sources), key=lambda x: x[0], reverse=True)
        top_reranked = reranked[:k]
        top_chunks = [c for _, c, _ in top_reranked]
        top_sources = [s for _, _, s in top_reranked]
        context = "\n".join(top_chunks)
        sources = top_sources
    except Exception:
        # fallback: if reranker fails for any reason, use the original retrieved top-k
        context = "\n".join(retrieved[:k])
        sources = sources[:k]

    user_message = {"role": "user", "content": f"Answer based on context:\n{context}\n\nQuestion: {query}"}
    st.session_state.chats[st.session_state.current_chat].append(user_message)

    try:
        answer = asyncio.run(async_together_chat(st.session_state.chats[st.session_state.current_chat]))
    except Exception as e:
        answer = f"Error: {e}"

    try:
        store_in_cache(query, answer, query_emb[0])
    except Exception:
        pass

    st.session_state.chats[st.session_state.current_chat].append({"role": "assistant", "content": answer})
    return answer, sources

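# Optional robustness sketch (not wired in): FAISS pads its result ids with -1
# when the index holds fewer than fetch_k vectors, and -1 would silently index
# the last metadata entry above. Filtering ids first avoids that edge case.
def _valid_hits(ids):
    return [i for i in ids if i != -1]
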
# -------------------------
# Background news
# -------------------------
async def background_news_updater():
    while True:
        st.session_state.news_articles = fetch_news()
        await asyncio.sleep(3600)

if "news_task" not in st.session_state:
    # Note: this event loop is created but never run, so the scheduled task does
    # not actually execute; the effective hourly refresh is done by
    # update_news_hourly(), called from the main page flow below.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    st.session_state.news_task = loop.create_task(background_news_updater())

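# Illustrative alternative (a sketch only, not wired in): drive the hourly
# refresh from a daemon thread instead of an asyncio loop that never runs. The
# worker writes into a plain module-level dict rather than st.session_state,
# since Streamlit session state is not meant to be written from background
# threads.
import threading

_NEWS_CACHE = {"articles": []}

def _start_news_thread(interval_seconds=3600):
    def _worker():
        while True:
            _NEWS_CACHE["articles"] = fetch_news()
            time.sleep(interval_seconds)
    t = threading.Thread(target=_worker, daemon=True)
    t.start()
    return t
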
# -------------------------
# Streamlit UI + Chat manager
# -------------------------
if "chats" not in st.session_state:
    st.session_state.chats = {}
if "current_chat" not in st.session_state:
    st.session_state.current_chat = "New Chat 1"
    st.session_state.chats["New Chat 1"] = [{"role": "system", "content": "You are a helpful public health chatbot."}]

st.sidebar.header("Chat Manager")
if st.sidebar.button("➕ New Chat"):
    chat_count = len(st.session_state.chats) + 1
    new_chat_name = f"New Chat {chat_count}"
    st.session_state.chats[new_chat_name] = [{"role": "system", "content": "You are a helpful public health chatbot."}]
    st.session_state.current_chat = new_chat_name

benchmark_faiss()

chat_list = list(st.session_state.chats.keys())
selected_chat = st.sidebar.selectbox("Your chats:", chat_list, index=chat_list.index(st.session_state.current_chat))
st.session_state.current_chat = selected_chat

new_name = st.sidebar.text_input("Rename Chat:", st.session_state.current_chat)
if new_name and new_name != st.session_state.current_chat:
    if new_name not in st.session_state.chats:
        st.session_state.chats[new_name] = st.session_state.chats.pop(st.session_state.current_chat)
        st.session_state.current_chat = new_name

# -------------------------
# Admin Panel
# -------------------------
query_params = st.query_params
is_admin_mode = (query_params.get("admin") == "1")

def rerun_app():
    # Button clicks already trigger a Streamlit rerun; flipping this dummy flag
    # only marks that a rerun-worthy state change happened.
    st.session_state['__rerun'] = not st.session_state.get('__rerun', False)

if is_admin_mode or st.session_state.get("is_admin", False):
    st.sidebar.markdown("---")
    st.sidebar.subheader("🔐 Admin Panel (dev only)")
    ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD", "")
    if "is_admin" not in st.session_state:
        st.session_state.is_admin = False

    if st.session_state.is_admin:
        st.sidebar.success("Admin authenticated")
        if st.sidebar.button("🚪 Logout Admin"):
            st.session_state.is_admin = False
            rerun_app()
    else:
        admin_input = st.sidebar.text_input("Enter admin password:", type="password")
        if st.sidebar.button("Login"):
            if admin_input == ADMIN_PASSWORD and ADMIN_PASSWORD != "":
                st.session_state.is_admin = True
                st.sidebar.success("Admin authenticated")
                rerun_app()
            else:
                st.sidebar.error("Wrong password or ADMIN_PASSWORD not set")

    if st.session_state.is_admin:
        if st.sidebar.button("⬇️ Export Query Cache to Excel"):
            file_path = export_cache_to_excel()
            st.sidebar.success(f"Exported: {file_path}")
            with open(file_path, "rb") as f:
                st.sidebar.download_button("Download Excel", f, file_name=file_path, mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")

        if st.sidebar.button("💾 Export SQL Dump"):
            dump_path = export_cache_to_sql()
            st.sidebar.success(f"SQL dump created: {dump_path}")
            with open(dump_path, "rb") as f:
                st.sidebar.download_button("Download SQL Dump", f, file_name=dump_path, mime="application/sql")

        s3_bucket = os.environ.get("S3_BUCKET_NAME", "")
        if s3_bucket and BOTO3_AVAILABLE:
            if st.sidebar.button("📤 Upload last Excel to S3"):
                excel_files = sorted([f for f in os.listdir(".") if f.startswith("cache_export_") and f.endswith(".xlsx")])
                if excel_files:
                    last_file = excel_files[-1]
                    ok, msg = upload_file_to_s3(last_file, s3_bucket)
                    if ok:
                        st.sidebar.success(f"Uploaded: {msg}")
                    else:
                        st.sidebar.error(f"S3 upload failed: {msg}")
                else:
                    st.sidebar.warning("No Excel export file found")
        elif not BOTO3_AVAILABLE:
            st.sidebar.info("S3 upload: boto3 not installed")
        elif not s3_bucket:
            st.sidebar.info("S3 upload disabled: set S3_BUCKET_NAME env var")

        if st.sidebar.checkbox("📜 Show Export History"):
            conn = get_db_connection()
            logs = pd.read_sql_query("SELECT * FROM export_logs ORDER BY id DESC", conn)
            conn.close()
            st.sidebar.write(logs)

# -------------------------
# Main UI: News + Chat
# -------------------------
st.title(st.session_state.current_chat)

update_news_hourly()
st.subheader("📰 Latest Health Updates")
if "news_articles" in st.session_state:
    for art in st.session_state.news_articles:
        st.markdown(f"**{art['title']}**  \n[Read more]({art['link']})  \n*Published: {art['published']}*")
        st.write("---")

user_query = st.text_input("Ask me about health, prevention, or awareness:")

if user_query:
    with st.spinner("Searching knowledge base..."):
        answer, sources = retrieve_answer(user_query)
    st.write("### 💡 Answer")
    st.write(answer)

    st.write("### 📖 Sources")
    for src in sources:
        st.write(f"- {src}")

# render chat history
for msg in st.session_state.chats[st.session_state.current_chat]:
    if msg["role"] == "user":
        st.write(f"🧑 **You:** {msg['content']}")
    elif msg["role"] == "assistant":
        st.write(f"🤖 **Bot:** {msg['content']}")