Spaces:
Sleeping
Sleeping
Update app.py
Browse filesAdded Grad-CAM Overlay
app.py
CHANGED
|
@@ -1,19 +1,17 @@
|
|
| 1 |
-
import os
|
| 2 |
import numpy as np
|
| 3 |
import torch
|
| 4 |
import joblib
|
| 5 |
import torch.nn as nn
|
|
|
|
| 6 |
from transformers import AutoImageProcessor, AutoModel
|
| 7 |
from PIL import Image
|
| 8 |
import requests
|
| 9 |
import gradio as gr
|
| 10 |
-
|
| 11 |
|
| 12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
# Your model class (unchanged)
|
| 16 |
-
# -----------------------
|
| 17 |
class ImageAuthenticityClassifier(nn.Module):
|
| 18 |
def __init__(self, backbone, w, b):
|
| 19 |
super().__init__()
|
|
@@ -42,9 +40,7 @@ class ImageAuthenticityClassifier(nn.Module):
|
|
| 42 |
|
| 43 |
patch_tokens = hidden[:, 1:, :]
|
| 44 |
emb = patch_tokens.mean(dim = 1)
|
| 45 |
-
|
| 46 |
-
# Apply classifier head to mean patch token embeddings
|
| 47 |
-
logits = self.head(emb)
|
| 48 |
prob = torch.sigmoid(logits)
|
| 49 |
|
| 50 |
if (return_tokens):
|
|
@@ -53,30 +49,23 @@ class ImageAuthenticityClassifier(nn.Module):
|
|
| 53 |
return logits, prob, emb
|
| 54 |
|
| 55 |
|
| 56 |
-
# -----------------------
|
| 57 |
# Load linear classifier head for logistic regression
|
| 58 |
-
# -----------------------
|
| 59 |
model_save_path = "logisticRegressionClassifier.joblib"
|
| 60 |
logisticRegressionClassifier = joblib.load(model_save_path)
|
| 61 |
-
|
| 62 |
coef = logisticRegressionClassifier.coef_
|
| 63 |
w = torch.from_numpy(coef.squeeze(0)).float()
|
| 64 |
intercept = logisticRegressionClassifier.intercept_
|
| 65 |
b = float(intercept[0])
|
| 66 |
|
| 67 |
|
| 68 |
-
# -----------------------
|
| 69 |
# Load DinoV3 backbone + processor (gated repo via token)
|
| 70 |
-
# -----------------------
|
| 71 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 72 |
backbone = AutoModel.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN).to(device)
|
| 73 |
processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN,)
|
| 74 |
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
|
| 75 |
|
| 76 |
|
| 77 |
-
# -----------------------
|
| 78 |
# Inference helper functions (unchanged)
|
| 79 |
-
# -----------------------
|
| 80 |
def load_image(online_image_url):
|
| 81 |
img = Image.open(requests.get(online_image_url, stream=True).raw).convert("RGB")
|
| 82 |
return img
|
|
@@ -86,6 +75,7 @@ def prepare_pixel_values(img):
|
|
| 86 |
pixel_values = inputs["pixel_values"].to(device)
|
| 87 |
return pixel_values
|
| 88 |
|
|
|
|
| 89 |
def predict_from_online_url(online_image_url):
|
| 90 |
img = load_image(online_image_url)
|
| 91 |
pixel_values = prepare_pixel_values(img)
|
|
@@ -95,26 +85,99 @@ def predict_from_online_url(online_image_url):
|
|
| 95 |
return float(prob[0][0].item())
|
| 96 |
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
# -----------------------
|
| 99 |
# Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
|
| 100 |
# -----------------------
|
| 101 |
-
|
| 102 |
def ui_predict(image_url: str):
|
| 103 |
if not image_url:
|
| 104 |
-
return None, "Awaiting input", "Enter an image URL to run a prediction."
|
| 105 |
-
|
| 106 |
try:
|
| 107 |
img = load_image(image_url)
|
| 108 |
-
|
|
|
|
| 109 |
percent = ai_prob * 100.0
|
| 110 |
-
|
| 111 |
verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
|
| 112 |
headline = verdict
|
| 113 |
detail = f"{percent:.1f}% probability the image is AI-generated"
|
| 114 |
-
|
| 115 |
-
|
| 116 |
except Exception as e:
|
| 117 |
-
return None, "Error", str(e)
|
| 118 |
|
| 119 |
demo = gr.Interface(
|
| 120 |
fn=ui_predict,
|
|
@@ -126,6 +189,7 @@ demo = gr.Interface(
|
|
| 126 |
gr.Image(label="Preview"),
|
| 127 |
gr.Textbox(label="Verdict"),
|
| 128 |
gr.Textbox(label="Details"),
|
|
|
|
| 129 |
],
|
| 130 |
title="Image Authenticicity",
|
| 131 |
description="Paste an image URL to estimate how likely it is AI-generated.",
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import joblib
|
| 4 |
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
from transformers import AutoImageProcessor, AutoModel
|
| 7 |
from PIL import Image
|
| 8 |
import requests
|
| 9 |
import gradio as gr
|
| 10 |
+
import cv2
|
| 11 |
|
| 12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
|
| 14 |
+
# My model from Collab (unchanged)
|
|
|
|
|
|
|
| 15 |
class ImageAuthenticityClassifier(nn.Module):
|
| 16 |
def __init__(self, backbone, w, b):
|
| 17 |
super().__init__()
|
|
|
|
| 40 |
|
| 41 |
patch_tokens = hidden[:, 1:, :]
|
| 42 |
emb = patch_tokens.mean(dim = 1)
|
| 43 |
+
logits = self.head(emb) # Apply classifier head to mean patch token embeddings
|
|
|
|
|
|
|
| 44 |
prob = torch.sigmoid(logits)
|
| 45 |
|
| 46 |
if (return_tokens):
|
|
|
|
| 49 |
return logits, prob, emb
|
| 50 |
|
| 51 |
|
|
|
|
| 52 |
# Load linear classifier head for logistic regression
|
|
|
|
| 53 |
model_save_path = "logisticRegressionClassifier.joblib"
|
| 54 |
logisticRegressionClassifier = joblib.load(model_save_path)
|
|
|
|
| 55 |
coef = logisticRegressionClassifier.coef_
|
| 56 |
w = torch.from_numpy(coef.squeeze(0)).float()
|
| 57 |
intercept = logisticRegressionClassifier.intercept_
|
| 58 |
b = float(intercept[0])
|
| 59 |
|
| 60 |
|
|
|
|
| 61 |
# Load DinoV3 backbone + processor (gated repo via token)
|
|
|
|
| 62 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 63 |
backbone = AutoModel.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN).to(device)
|
| 64 |
processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN,)
|
| 65 |
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
|
| 66 |
|
| 67 |
|
|
|
|
| 68 |
# Inference helper functions (unchanged)
|
|
|
|
| 69 |
def load_image(online_image_url):
|
| 70 |
img = Image.open(requests.get(online_image_url, stream=True).raw).convert("RGB")
|
| 71 |
return img
|
|
|
|
| 75 |
pixel_values = inputs["pixel_values"].to(device)
|
| 76 |
return pixel_values
|
| 77 |
|
| 78 |
+
# Unused
|
| 79 |
def predict_from_online_url(online_image_url):
|
| 80 |
img = load_image(online_image_url)
|
| 81 |
pixel_values = prepare_pixel_values(img)
|
|
|
|
| 85 |
return float(prob[0][0].item())
|
| 86 |
|
| 87 |
|
| 88 |
+
# Grad-CAM Helper Functions (Unchanged) -------------------
|
| 89 |
+
def compute_cam_from_tokens(patch_tokens, pixel_values, patch_size=16):
|
| 90 |
+
# Dimension calculations
|
| 91 |
+
H_in, W_in = pixel_values.shape[-2], pixel_values.shape[-1]
|
| 92 |
+
H_p = H_in // patch_size
|
| 93 |
+
W_p = W_in // patch_size
|
| 94 |
+
num_spatial = H_p * W_p
|
| 95 |
+
|
| 96 |
+
# Tokens and grads for all 200 tokens after CLS. Keep only the spatial patch tokens (drop the 4 global tokens at start)
|
| 97 |
+
tokens_all = patch_tokens[0] # (200, D)
|
| 98 |
+
grads_all = patch_tokens.grad[0] # (200, D)
|
| 99 |
+
tokens_spatial = tokens_all[-num_spatial:, :] # (196, D)
|
| 100 |
+
grads_spatial = grads_all[-num_spatial:, :] # (196, D)
|
| 101 |
+
|
| 102 |
+
# Get a single weight per feature dimension averaged over all patches
|
| 103 |
+
weights = grads_spatial.mean(dim=0) # (D,)
|
| 104 |
+
|
| 105 |
+
# For each patch, combine activation and weights to make different importance for each patch, and normalize results.
|
| 106 |
+
cam_per_patch = (tokens_spatial * weights).sum(dim=-1)
|
| 107 |
+
cam_per_patch = torch.relu(cam_per_patch)
|
| 108 |
+
cam_per_patch = cam_per_patch - cam_per_patch.min()
|
| 109 |
+
cam_per_patch = cam_per_patch / (cam_per_patch.max() + 1e-8) # shape: (N,)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
cam_grid = cam_per_patch.reshape(H_p, W_p)
|
| 113 |
+
cam = cam_grid.unsqueeze(0).unsqueeze(0) # (1, 1, H_p, W_p)
|
| 114 |
+
cam_up = F.interpolate(
|
| 115 |
+
cam,
|
| 116 |
+
size=(H_in, W_in),
|
| 117 |
+
mode="bilinear",
|
| 118 |
+
align_corners=False,
|
| 119 |
+
)[0, 0] # (H_in, W_in)
|
| 120 |
+
|
| 121 |
+
return cam_up
|
| 122 |
+
|
| 123 |
+
def grad_cam_from_online_url(online_image_url):
|
| 124 |
+
# Load image and get pixel_values
|
| 125 |
+
img = load_image(online_image_url)
|
| 126 |
+
pixel_values = prepare_pixel_values(img)
|
| 127 |
+
|
| 128 |
+
# Run prediction with return_tokens=True
|
| 129 |
+
logits, prob, emb, patch_tokens = image_auth_model(pixel_values, return_tokens=True)
|
| 130 |
+
ai_prob = float(prob[0][0].item())
|
| 131 |
+
target_logit = logits[0, 0]
|
| 132 |
+
|
| 133 |
+
image_auth_model.zero_grad()
|
| 134 |
+
|
| 135 |
+
if patch_tokens.grad is not None:
|
| 136 |
+
patch_tokens.grad.zero_()
|
| 137 |
+
|
| 138 |
+
patch_tokens.retain_grad()
|
| 139 |
+
target_logit.backward(retain_graph=True) # Finds d_target_logit/d_patch_tokens in patch_tokens.grad()
|
| 140 |
+
|
| 141 |
+
# Compute Grad-CAM heatmap
|
| 142 |
+
cam_up = compute_cam_from_tokens(patch_tokens, pixel_values)
|
| 143 |
+
cam_np = cam_up.detach().cpu().numpy()
|
| 144 |
+
orig_np = np.array(img).astype(np.float32) / 255.0
|
| 145 |
+
H0, W0, _ = orig_np.shape
|
| 146 |
+
|
| 147 |
+
cam = cam_np.astype(np.float32)
|
| 148 |
+
if cam.shape != (H0, W0):
|
| 149 |
+
cam_t = torch.from_numpy(cam).unsqueeze(0).unsqueeze(0)
|
| 150 |
+
cam_t = F.interpolate(cam_t, size=(H0, W0), mode="bilinear", align_corners=False)
|
| 151 |
+
cam = cam_t[0, 0].cpu().numpy()
|
| 152 |
+
|
| 153 |
+
cam_uint8 = np.uint8(cam * 255)
|
| 154 |
+
heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
|
| 155 |
+
heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
|
| 156 |
+
|
| 157 |
+
alpha = 0.5
|
| 158 |
+
overlay = alpha * heatmap_rgb + (1.0 - alpha) * orig_np
|
| 159 |
+
overlay = np.clip(overlay, 0.0, 1.0)
|
| 160 |
+
|
| 161 |
+
return ai_prob, orig_np, overlay
|
| 162 |
+
|
| 163 |
# -----------------------
|
| 164 |
# Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
|
| 165 |
# -----------------------
|
|
|
|
| 166 |
def ui_predict(image_url: str):
|
| 167 |
if not image_url:
|
| 168 |
+
return None, "Awaiting input", "Enter an image URL to run a prediction.", None
|
|
|
|
| 169 |
try:
|
| 170 |
img = load_image(image_url)
|
| 171 |
+
|
| 172 |
+
ai_prob, img, img_with_gradcam_overlay = grad_cam_from_online_url(image_url)
|
| 173 |
percent = ai_prob * 100.0
|
|
|
|
| 174 |
verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
|
| 175 |
headline = verdict
|
| 176 |
detail = f"{percent:.1f}% probability the image is AI-generated"
|
| 177 |
+
return img, headline, detail, img_with_gradcam_overlay
|
| 178 |
+
|
| 179 |
except Exception as e:
|
| 180 |
+
return None, "Error", str(e), None
|
| 181 |
|
| 182 |
demo = gr.Interface(
|
| 183 |
fn=ui_predict,
|
|
|
|
| 189 |
gr.Image(label="Preview"),
|
| 190 |
gr.Textbox(label="Verdict"),
|
| 191 |
gr.Textbox(label="Details"),
|
| 192 |
+
gr.Image(label="Grad-CAM"),
|
| 193 |
],
|
| 194 |
title="Image Authenticicity",
|
| 195 |
description="Paste an image URL to estimate how likely it is AI-generated.",
|