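"""AI microservice for a personal cloud: classifies uploaded files,
computes SBERT/CLIP embeddings, detects exact (SHA-256) and near
(perceptual-hash) duplicates, and answers semantic search queries.
File bytes live in Supabase storage; file metadata lives in MongoDB.
"""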
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import os
from dotenv import load_dotenv

load_dotenv()

import hashlib
import io
import math
import traceback

from supabase import create_client, Client
from pymongo import MongoClient
from bson import ObjectId

# If hex_to_hash is unavailable, fall back to a stub: near-duplicate
# checks are then skipped rather than crashing the service.
try:
    from imagehash import hex_to_hash
except Exception:
    def hex_to_hash(s):
        return None

from sentence_transformers import SentenceTransformer
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import imagehash

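# Models are loaded once at import time, so startup is slow but requests
# do not pay the load cost. all-MiniLM-L6-v2 produces 384-dim text
# embeddings; CLIP ViT-B/32 produces 512-dim embeddings in a shared
# text/image space, which is what makes text-to-image search possible
# in /semantic-search.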
sbert = SentenceTransformer('all-MiniLM-L6-v2')
clip_model_name = 'openai/clip-vit-base-patch32'
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

SUPABASE_URL = os.getenv('SUPABASE_URL')
SUPABASE_KEY = os.getenv('SUPABASE_SERVICE_ROLE_KEY')
SUPABASE_BUCKET = os.getenv('SUPABASE_BUCKET', 'files')

if not SUPABASE_URL or not SUPABASE_KEY:
    print("Warning: SUPABASE_URL or SUPABASE_SERVICE_ROLE_KEY not set.")
    supabase = None
else:
    try:
        supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
        print(f"Supabase client initialized for bucket: {SUPABASE_BUCKET}")
    except Exception as e:
        print(f"Supabase init failed: {e}")
        supabase = None

# Use the database named in MONGO_URI when present; otherwise fall back
# to a fixed database name.
mongo_client = MongoClient(os.getenv('MONGO_URI'))
try:
    db = mongo_client.get_default_database()
except Exception:
    db = mongo_client['ai_personal_cloud']

files_col = db['files']

app = FastAPI(title='AI Microservice')


class ProcessInput(BaseModel):
    fileId: str
    # Historical name: this key now points at the object in Supabase
    # storage, not MinIO.
    minioKey: str
    mimetype: Optional[str] = None


class QueryInput(BaseModel):
    query: str
    userId: Optional[str] = None

def download_from_supabase(key):
    if not supabase:
        raise Exception("Supabase client not initialized")
    data = supabase.storage.from_(SUPABASE_BUCKET).download(key)
    return data

def sha256_hash(buffer):
    h = hashlib.sha256()
    h.update(buffer)
    return h.hexdigest()

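# Small helper shared by every branch of /process-file: another document
# with the same SHA-256 content hash (and a different _id) marks this
# file as an exact duplicate.
def find_exact_duplicate(h, file_id):
    if not h:
        return None
    return files_col.find_one({'hash': h, '_id': {'$ne': ObjectId(file_id)}})
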
def phash_image(buffer):
    # Perceptual hash for near-duplicate detection; returns None for
    # anything PIL cannot open.
    try:
        img = Image.open(io.BytesIO(buffer)).convert('RGB')
        return str(imagehash.phash(img))
    except Exception:
        return None

def embedding_from_text_sbert(text):
    vec = sbert.encode(text)
    return vec.tolist()

def embedding_from_text_clip(text):
    inputs = clip_processor(text=[text], return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
    # L2-normalize so cosine similarity reduces to a dot product.
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features[0].cpu().numpy().tolist()

def embedding_from_image_clip(buffer):
    img = Image.open(io.BytesIO(buffer)).convert('RGB')
    inputs = clip_processor(images=img, return_tensors='pt')
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features[0].cpu().numpy().tolist()

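# /process-file pipeline: download the object from Supabase storage,
# compute a SHA-256 content hash, branch on mimetype to pick an
# embedding and category (CLIP for images, SBERT over extracted text or
# the filename otherwise), flag exact and near duplicates, and persist
# everything back onto the file's Mongo document.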
@app.post('/process-file')
async def process_file(payload: ProcessInput):
    key = payload.minioKey
    file_id = payload.fileId
    mimetype = (payload.mimetype or '').lower()

    if not ObjectId.is_valid(file_id):
        raise HTTPException(status_code=400, detail=f"Invalid ObjectId: {file_id}")

    try:
        try:
            buffer = download_from_supabase(key)
        except Exception as e:
            print(f"Supabase Error for key {key}: {e}")
            raise HTTPException(status_code=404, detail=f"Could not download file from Supabase: {str(e)}")

        h = sha256_hash(buffer)

        embedding = []
        category = 'unknown'
        phash = None
        embedding_type = 'sbert'
        duplicate = False
        duplicate_of = None
        user_id = None

        # Look up the owning user so near-duplicate checks can be
        # scoped per user.
        file_doc = files_col.find_one({'_id': ObjectId(file_id)})
        if file_doc:
            user_id = file_doc.get('userId')
        else:
            print(f"Warning: No document found in Mongo for ID {file_id}")

        if 'image' in mimetype:
            phash = phash_image(buffer)
            embedding = embedding_from_image_clip(buffer)
            category = 'image'
            embedding_type = 'clip_image'

            existing = find_exact_duplicate(h, file_id)
            if existing:
                duplicate = True
                duplicate_of = str(existing['_id'])

            # Near-duplicate pass over perceptual hashes: a Hamming
            # distance of <= 6 (out of 64 bits) is treated as the same
            # image.
            if not duplicate and phash:
                current_hash = hex_to_hash(phash)
                if current_hash:
                    query = {'pHash': {'$exists': True}}
                    if user_id:
                        query['userId'] = user_id
                    candidates = list(files_col.find(query))
                    for c in candidates:
                        try:
                            ch = hex_to_hash(c.get('pHash'))
                            if ch and (current_hash - ch <= 6):
                                duplicate = True
                                duplicate_of = str(c.get('_id'))
                                break
                        except Exception:
                            continue

        elif any(x in mimetype for x in ['pdf', 'text', 'msword', 'officedocument']):
            filename = key.split('/')[-1].lower()

            # Plain-text files are embedded from their content; binary
            # documents fall back to a cleaned-up filename.
            if 'text' in mimetype:
                try:
                    txt = buffer.decode('utf-8', errors='ignore')
                except Exception:
                    txt = filename
            else:
                txt = filename.replace('_', ' ').replace('-', ' ')

            embedding = embedding_from_text_sbert(txt[:5000])

            # Keyword heuristics for categorization.
            if 'presentation' in mimetype or filename.endswith(('.ppt', '.pptx')):
                category = 'presentation'
            elif 'spreadsheet' in mimetype or filename.endswith(('.xls', '.xlsx', '.csv')):
                category = 'spreadsheet'
            elif 'pdf' in mimetype or filename.endswith('.pdf'):
                if any(word in filename for word in ['invoice', 'bill', 'receipt']):
                    category = 'invoice'
                elif any(word in filename for word in ['resume', 'cv']):
                    category = 'resume'
                elif any(word in filename for word in ['report', 'analysis']):
                    category = 'report'
                else:
                    category = 'pdf'
            elif 'text' in mimetype:
                if any(word in txt.lower() for word in ['invoice', 'bill']):
                    category = 'invoice'
                elif any(word in txt.lower() for word in ['note', 'memo']):
                    category = 'notes'
                else:
                    category = 'text'
            else:
                category = 'document'

            existing = find_exact_duplicate(h, file_id)
            if existing:
                duplicate = True
                duplicate_of = str(existing['_id'])

        else:
            # Anything else: embed a cleaned-up filename so the file is
            # still reachable via semantic search.
            embedding = embedding_from_text_sbert(key.split('/')[-1].replace('_', ' '))
            category = 'file'

            existing = find_exact_duplicate(h, file_id)
            if existing:
                duplicate = True
                duplicate_of = str(existing['_id'])

        update_data = {
            'category': category,
            'embedding': embedding,
            'embeddingType': embedding_type,
            'hash': h,
            'pHash': phash,
            'duplicate': duplicate,
            'duplicateOf': ObjectId(duplicate_of) if duplicate_of else None
        }

        files_col.update_one({'_id': ObjectId(file_id)}, {'$set': update_data})

        return {
            'status': 'processed',
            'fileId': file_id,
            'category': category,
            'duplicate': duplicate
        }

    except HTTPException:
        raise
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal Error: {str(e)}")

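# Example request for /process-file (hypothetical values):
#   curl -X POST http://localhost:8000/process-file \
#     -H 'Content-Type: application/json' \
#     -d '{"fileId": "<24-char Mongo ObjectId>", "minioKey": "user123/photo.jpg", "mimetype": "image/jpeg"}'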
@app.post('/semantic-search')
async def semantic_search(payload: QueryInput):
    q = payload.query
    user_id = payload.userId
    if not q:
        raise HTTPException(status_code=400, detail='Missing query')

    # Embed the query once per model: SBERT scores text documents,
    # CLIP scores image embeddings.
    sbert_emb = embedding_from_text_sbert(q)
    clip_text_emb = embedding_from_text_clip(q)

    query = {}
    if user_id:
        # userId may be stored as an ObjectId or a plain string.
        try:
            query['userId'] = ObjectId(user_id)
        except Exception:
            query['userId'] = user_id
    files = list(files_col.find(query, {'filename': 1, 'category': 1, 'embedding': 1, 'embeddingType': 1}))

    def cosine(a, b):
        if not a or not b:
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        mag_a = math.sqrt(sum(x * x for x in a))
        mag_b = math.sqrt(sum(x * x for x in b))
        if mag_a == 0 or mag_b == 0:
            return 0.0
        return dot / (mag_a * mag_b)

    # Brute-force scan over all stored embeddings; fine for small
    # collections, but a vector index would be needed at scale.
    results = []
    for f in files:
        emb = f.get('embedding') or []
        if not emb:
            continue
        etype = f.get('embeddingType', 'sbert')
        score = 0.0
        try:
            if etype == 'clip_image':
                score = cosine(clip_text_emb, emb)
            elif etype == 'sbert':
                score = cosine(sbert_emb, emb)
            else:
                # Unknown embedding type: try both query embeddings and
                # keep the better score.
                score = max(cosine(sbert_emb, emb), cosine(clip_text_emb, emb))
        except Exception:
            score = 0.0
        results.append({
            'fileId': str(f['_id']),
            'filename': f.get('filename'),
            'category': f.get('category'),
            'score': float(score),
        })

    results.sort(key=lambda x: x['score'], reverse=True)
    return {'results': results}

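# Example search request (hypothetical values):
#   curl -X POST http://localhost:8000/semantic-search \
#     -H 'Content-Type: application/json' \
#     -d '{"query": "beach sunset photos", "userId": null}'

# Local dev entry point: a minimal sketch assuming uvicorn is installed;
# host and port are assumptions, adjust for your deployment.
if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)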