jerry2247 committed on
Commit
30eefab
·
verified ·
1 Parent(s): 9b77aaa

Update app.py

Browse files

Added Grad-CAM Overlay

Files changed (1) hide show
  1. app.py +87 -23
app.py CHANGED
@@ -1,19 +1,17 @@
1
- import os
2
  import numpy as np
3
  import torch
4
  import joblib
5
  import torch.nn as nn
 
6
  from transformers import AutoImageProcessor, AutoModel
7
  from PIL import Image
8
  import requests
9
  import gradio as gr
10
-
11
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- # -----------------------
15
- # Your model class (unchanged)
16
- # -----------------------
17
  class ImageAuthenticityClassifier(nn.Module):
18
  def __init__(self, backbone, w, b):
19
  super().__init__()
@@ -42,9 +40,7 @@ class ImageAuthenticityClassifier(nn.Module):
42
 
43
  patch_tokens = hidden[:, 1:, :]
44
  emb = patch_tokens.mean(dim = 1)
45
-
46
- # Apply classifier head to mean patch token embeddings
47
- logits = self.head(emb)
48
  prob = torch.sigmoid(logits)
49
 
50
  if (return_tokens):
@@ -53,30 +49,23 @@ class ImageAuthenticityClassifier(nn.Module):
53
  return logits, prob, emb
54
 
55
 
56
- # -----------------------
57
  # Load linear classifier head for logistic regression
58
- # -----------------------
59
  model_save_path = "logisticRegressionClassifier.joblib"
60
  logisticRegressionClassifier = joblib.load(model_save_path)
61
-
62
  coef = logisticRegressionClassifier.coef_
63
  w = torch.from_numpy(coef.squeeze(0)).float()
64
  intercept = logisticRegressionClassifier.intercept_
65
  b = float(intercept[0])
66
 
67
 
68
- # -----------------------
69
  # Load DinoV3 backbone + processor (gated repo via token)
70
- # -----------------------
71
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
72
  backbone = AutoModel.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN).to(device)
73
  processor = AutoImageProcessor.from_pretrained("facebook/dinov3-vitb16-pretrain-lvd1689m", token=HF_TOKEN,)
74
  image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
75
 
76
 
77
- # -----------------------
78
  # Inference helper functions (unchanged)
79
- # -----------------------
80
  def load_image(online_image_url):
81
  img = Image.open(requests.get(online_image_url, stream=True).raw).convert("RGB")
82
  return img
@@ -86,6 +75,7 @@ def prepare_pixel_values(img):
86
  pixel_values = inputs["pixel_values"].to(device)
87
  return pixel_values
88
 
 
89
  def predict_from_online_url(online_image_url):
90
  img = load_image(online_image_url)
91
  pixel_values = prepare_pixel_values(img)
@@ -95,26 +85,99 @@ def predict_from_online_url(online_image_url):
95
  return float(prob[0][0].item())
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # -----------------------
99
  # Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
100
  # -----------------------
101
-
102
  def ui_predict(image_url: str):
103
  if not image_url:
104
- return None, "Awaiting input", "Enter an image URL to run a prediction."
105
-
106
  try:
107
  img = load_image(image_url)
108
- ai_prob = float(predict_from_online_url(image_url))
 
109
  percent = ai_prob * 100.0
110
-
111
  verdict = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
112
  headline = verdict
113
  detail = f"{percent:.1f}% probability the image is AI-generated"
114
-
115
- return img, headline, detail
116
  except Exception as e:
117
- return None, "Error", str(e)
118
 
119
  demo = gr.Interface(
120
  fn=ui_predict,
@@ -126,6 +189,7 @@ demo = gr.Interface(
126
  gr.Image(label="Preview"),
127
  gr.Textbox(label="Verdict"),
128
  gr.Textbox(label="Details"),
 
129
  ],
130
  title="Image Authenticicity",
131
  description="Paste an image URL to estimate how likely it is AI-generated.",
 
 
1
import io
import os

import cv2
import gradio as gr
import joblib
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
11
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
+ # My model from Colab (unchanged)
 
 
15
  class ImageAuthenticityClassifier(nn.Module):
16
  def __init__(self, backbone, w, b):
17
  super().__init__()
 
40
 
41
  patch_tokens = hidden[:, 1:, :]
42
  emb = patch_tokens.mean(dim = 1)
43
+ logits = self.head(emb) # Apply classifier head to mean patch token embeddings
 
 
44
  prob = torch.sigmoid(logits)
45
 
46
  if (return_tokens):
 
49
  return logits, prob, emb
50
 
51
 
 
52
# Load linear classifier head for logistic regression.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code. Only ship a .joblib artifact produced by a trusted run.
model_save_path = "logisticRegressionClassifier.joblib"
logisticRegressionClassifier = joblib.load(model_save_path)

# Axis 0 of coef_ is squeezed away to get a flat weight vector, and the first
# intercept entry becomes the scalar bias of the torch head.
coef = logisticRegressionClassifier.coef_
w = torch.from_numpy(coef.squeeze(0)).float()
intercept = logisticRegressionClassifier.intercept_
b = float(intercept[0])


# Load DinoV3 backbone + processor (gated repo via token).
# The repo id is named once so the backbone and its processor cannot drift
# apart; HF_TOKEN is None when the environment variable is not set.
DINOV3_MODEL_ID = "facebook/dinov3-vitb16-pretrain-lvd1689m"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
backbone = AutoModel.from_pretrained(DINOV3_MODEL_ID, token=HF_TOKEN).to(device)
processor = AutoImageProcessor.from_pretrained(DINOV3_MODEL_ID, token=HF_TOKEN)
image_auth_model = ImageAuthenticityClassifier(backbone, w, b).to(device)
66
 
67
 
 
68
  # Inference helper functions (unchanged)
 
69
def load_image(online_image_url):
    """Download an image from a URL and return it as an RGB PIL image.

    Args:
        online_image_url: HTTP(S) URL of the image to fetch.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            non-2xx HTTP status.
        OSError: If the downloaded bytes cannot be decoded as an image.
    """
    # NOTE(review): the URL comes straight from the web UI, so this performs
    # a server-side request to a user-chosen address. Consider restricting
    # schemes/hosts if the app is exposed publicly.
    #
    # A timeout keeps a stalled server from hanging the request forever.
    response = requests.get(online_image_url, timeout=30)
    # Fail with a clear HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    # `.content` is the fully read, content-decoded body, unlike `.raw`,
    # which exposes the undecoded socket stream and leaves it open.
    img = Image.open(io.BytesIO(response.content)).convert("RGB")
    return img
 
75
  pixel_values = inputs["pixel_values"].to(device)
76
  return pixel_values
77
 
78
+ # Unused
79
  def predict_from_online_url(online_image_url):
80
  img = load_image(online_image_url)
81
  pixel_values = prepare_pixel_values(img)
 
85
  return float(prob[0][0].item())
86
 
87
 
88
+ # Grad-CAM Helper Functions (Unchanged) -------------------
89
def compute_cam_from_tokens(patch_tokens, pixel_values, patch_size=16):
    """Build a Grad-CAM style heatmap from patch tokens and their gradients.

    Only the first batch element is used. The last ``H_p * W_p`` tokens are
    treated as the spatial patch grid; any tokens before them are dropped.

    Args:
        patch_tokens: Tensor indexed as ``(batch, tokens, dim)`` whose
            ``.grad`` has already been populated by a backward pass.
        pixel_values: Model input tensor; its last two dimensions give the
            input height and width.
        patch_size: Side length, in pixels, of one patch.

    Returns:
        A ``(H_in, W_in)`` tensor with values in ``[0, 1]``.

    Raises:
        RuntimeError: If ``patch_tokens.grad`` is ``None``.
        ValueError: If the input is smaller than one patch, or there are
            fewer tokens than patches in the spatial grid.
    """
    if patch_tokens.grad is None:
        raise RuntimeError(
            "patch_tokens.grad is None; call retain_grad() and backward() "
            "before computing the CAM."
        )

    # Patch-grid geometry implied by the input resolution.
    H_in, W_in = pixel_values.shape[-2], pixel_values.shape[-1]
    H_p = H_in // patch_size
    W_p = W_in // patch_size
    num_spatial = H_p * W_p

    # Guard explicitly: with num_spatial == 0 the slice `[-0:]` would select
    # every token instead of none.
    if num_spatial == 0:
        raise ValueError(
            f"Input of size {H_in}x{W_in} is smaller than one "
            f"{patch_size}x{patch_size} patch."
        )

    tokens_all = patch_tokens[0]      # (T, D)
    grads_all = patch_tokens.grad[0]  # (T, D)
    if tokens_all.shape[0] < num_spatial:
        raise ValueError(
            f"Expected at least {num_spatial} patch tokens for a "
            f"{H_p}x{W_p} grid, got {tokens_all.shape[0]}."
        )

    # Keep only the trailing spatial patch tokens.
    tokens_spatial = tokens_all[-num_spatial:, :]  # (H_p * W_p, D)
    grads_spatial = grads_all[-num_spatial:, :]    # (H_p * W_p, D)

    # One weight per feature dimension: the gradient averaged over patches.
    weights = grads_spatial.mean(dim=0)  # (D,)

    # Per-patch importance: weighted activation sum, rectified, then
    # min-max normalised (epsilon avoids division by zero on a flat map).
    cam_per_patch = torch.relu((tokens_spatial * weights).sum(dim=-1))
    cam_per_patch = cam_per_patch - cam_per_patch.min()
    cam_per_patch = cam_per_patch / (cam_per_patch.max() + 1e-8)

    # Upsample the patch grid back to the model input resolution.
    cam = cam_per_patch.reshape(H_p, W_p).unsqueeze(0).unsqueeze(0)
    cam_up = F.interpolate(
        cam,
        size=(H_in, W_in),
        mode="bilinear",
        align_corners=False,
    )[0, 0]  # (H_in, W_in)

    return cam_up
122
+
123
def grad_cam_from_online_url(online_image_url):
    """Classify an image from a URL and build its Grad-CAM overlay.

    Args:
        online_image_url: URL of the image to download and analyse.

    Returns:
        A tuple ``(ai_prob, orig_np, overlay)``:
          * ``ai_prob``: the model's probability output as a Python float.
          * ``orig_np``: the downloaded image as a float32 ``(H, W, 3)``
            array scaled to ``[0, 1]``.
          * ``overlay``: a same-sized array blending a JET heatmap with the
            image at equal weight, clipped to ``[0, 1]``.
    """
    # Load image and get pixel_values.
    img = load_image(online_image_url)
    pixel_values = prepare_pixel_values(img)

    # Forward pass that also hands back the patch tokens to differentiate
    # against; the pooled embedding is not needed here.
    logits, prob, _emb, patch_tokens = image_auth_model(pixel_values, return_tokens=True)
    ai_prob = float(prob[0][0].item())
    target_logit = logits[0, 0]

    image_auth_model.zero_grad()

    # Ask autograd to keep this tensor's gradient *before* touching `.grad`.
    patch_tokens.retain_grad()
    if patch_tokens.grad is not None:
        patch_tokens.grad.zero_()

    # Only one backward pass is run, so the graph does not need retaining.
    target_logit.backward()

    # Compute the heatmap at model-input resolution.
    cam_t = compute_cam_from_tokens(patch_tokens, pixel_values).detach().cpu().float()

    orig_np = np.array(img).astype(np.float32) / 255.0
    H0, W0, _ = orig_np.shape

    # Resize the heatmap to the original image size when they differ.
    if tuple(cam_t.shape) != (H0, W0):
        cam_t = F.interpolate(
            cam_t.unsqueeze(0).unsqueeze(0),
            size=(H0, W0),
            mode="bilinear",
            align_corners=False,
        )[0, 0]
    # Clamp defensively so the uint8 conversion below cannot wrap around.
    cam = np.clip(cam_t.numpy(), 0.0, 1.0)

    # Colour-map the heatmap; OpenCV returns BGR, so convert to RGB.
    cam_uint8 = np.uint8(cam * 255)
    heatmap_bgr = cv2.applyColorMap(cam_uint8, cv2.COLORMAP_JET)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    # Blend heatmap and image with equal weight.
    alpha = 0.5
    overlay = alpha * heatmap_rgb + (1.0 - alpha) * orig_np
    overlay = np.clip(overlay, 0.0, 1.0)

    return ai_prob, orig_np, overlay
162
+
163
  # -----------------------
164
  # Gradio interface exposing ui_predict as a web UI/API. (AI Generated lol)
165
  # -----------------------
 
166
def ui_predict(image_url: str):
    """Gradio callback: classify the image at ``image_url``.

    Args:
        image_url: URL typed into the UI; may be empty or ``None``.

    Returns:
        A 4-tuple ``(preview, headline, detail, gradcam_overlay)``. The two
        image slots are ``None`` when there is no input or when any step
        fails; on failure the headline is ``"Error"`` and the detail carries
        the exception message.
    """
    if not image_url:
        return None, "Awaiting input", "Enter an image URL to run a prediction.", None

    try:
        # grad_cam_from_online_url downloads the image itself and returns
        # it, so no separate load_image call (a second download) is needed.
        ai_prob, img, img_with_gradcam_overlay = grad_cam_from_online_url(image_url)
    except Exception as e:
        # UI boundary: surface the failure in the output boxes rather than
        # letting the exception propagate out of the callback.
        return None, "Error", str(e), None

    percent = ai_prob * 100.0
    headline = "AI-generated" if ai_prob >= 0.5 else "Not AI-generated"
    detail = f"{percent:.1f}% probability the image is AI-generated"
    return img, headline, detail, img_with_gradcam_overlay
181
 
182
  demo = gr.Interface(
183
  fn=ui_predict,
 
189
  gr.Image(label="Preview"),
190
  gr.Textbox(label="Verdict"),
191
  gr.Textbox(label="Details"),
192
+ gr.Image(label="Grad-CAM"),
193
  ],
194
  title="Image Authenticicity",
195
  description="Paste an image URL to estimate how likely it is AI-generated.",