ishanprogs committed on
Commit e56a7d9 · verified · 1 Parent(s): 5697a81

Upload 4 files

Files changed (4)
  1. app.py +259 -0
  2. clip_text_features.pt +3 -0
  3. model_best.pth +3 -0
  4. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,259 @@
+ import gradio as gr
+ import torch
+ import clip
+ from PIL import Image
+ import numpy as np
+ import os
+ import cv2
+ import gc # Garbage collector
+ import logging
+
+ # --- Detectron2 Imports ---
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer, ColorMode
+ from detectron2.data import MetadataCatalog
+
+ # --- Setup Logging ---
+ # Reduce default Detectron2 logging noise if needed
+ logging.getLogger("detectron2").setLevel(logging.WARNING)
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # --- Constants ---
+ # Damage segmentation classes (MUST match training order)
+ DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
+ NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)
+
+ # Paths within the Hugging Face Space repository
+ CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
+ # CLIP_MODEL_WEIGHTS_PATH = "./clip_model/clip_vit_b16.pth" # Alt: Load state dict
+ MASKRCNN_MODEL_WEIGHTS_PATH = "./model_best.pth" # Your best Mask R-CNN weights
+ MASKRCNN_BASE_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
+
+ # Prediction Thresholds
+ DAMAGE_PRED_THRESHOLD = 0.4 # Threshold for showing damage masks
+
+ # --- Device Setup ---
+ if torch.cuda.is_available():
+     DEVICE = "cuda"
+     logger.info("CUDA available, using GPU.")
+ else:
+     DEVICE = "cpu"
+     logger.info("CUDA not available, using CPU.")
+
+ # --- MODEL LOADING (Load models globally ONCE on startup) ---
+ print("Loading models...")
+
+ # --- Load CLIP Model ---
+ try:
+     logger.info("Loading CLIP model...")
+     clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE)
+     # Optional: Load state dict if you saved it manually
+     # clip_model.load_state_dict(torch.load(CLIP_MODEL_WEIGHTS_PATH, map_location=DEVICE))
+     clip_model.eval()
+     logger.info("CLIP model loaded.")
+
+     logger.info(f"Loading CLIP text features from {CLIP_TEXT_FEATURES_PATH}...")
+     if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
+         raise FileNotFoundError(f"CLIP text features not found at {CLIP_TEXT_FEATURES_PATH}. Make sure it's uploaded.")
+     clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
+     logger.info("CLIP text features loaded.")
+ except Exception as e:
+     logger.error(f"Error loading CLIP model or features: {e}", exc_info=True)
+     clip_model = None # Set to None if loading fails
+     clip_text_features = None # Also clear features so the availability check below stays consistent
+
+
+ # --- Load Mask R-CNN Model ---
+ maskrcnn_predictor = None
+ maskrcnn_metadata = None
+ try:
+     logger.info("Setting up Mask R-CNN configuration...")
+     cfg_mrcnn = get_cfg()
+     cfg_mrcnn.merge_from_file(model_zoo.get_config_file(MASKRCNN_BASE_CONFIG))
+
+     # Manual configuration based on your previous working setup
+     cfg_mrcnn.defrost()
+     cfg_mrcnn.MODEL.WEIGHTS = MASKRCNN_MODEL_WEIGHTS_PATH
+     if not os.path.exists(MASKRCNN_MODEL_WEIGHTS_PATH):
+         raise FileNotFoundError(f"Mask R-CNN weights not found at {MASKRCNN_MODEL_WEIGHTS_PATH}. Make sure it's uploaded.")
+
+     cfg_mrcnn.MODEL.ROI_HEADS.NUM_CLASSES = NUM_DAMAGE_CLASSES
+     cfg_mrcnn.MODEL.ROI_HEADS.SCORE_THRESH_TEST = DAMAGE_PRED_THRESHOLD
+     cfg_mrcnn.MODEL.DEVICE = DEVICE
+     # Apply necessary norm settings if changed during training
+     cfg_mrcnn.MODEL.FPN.NORM = "GN"
+     cfg_mrcnn.MODEL.ROI_HEADS.NORM = "GN"
+     cfg_mrcnn.freeze()
+     logger.info("Mask R-CNN configuration loaded.")
+
+     logger.info("Creating Mask R-CNN predictor...")
+     maskrcnn_predictor = DefaultPredictor(cfg_mrcnn)
+     logger.info("Mask R-CNN predictor created.")
+
+     # Setup metadata for visualization
+     metadata_name = "car_damage_inference_app"
+     if metadata_name not in MetadataCatalog.list():
+         MetadataCatalog.get(metadata_name).set(thing_classes=DAMAGE_CLASSES)
+     maskrcnn_metadata = MetadataCatalog.get(metadata_name)
+     logger.info("Mask R-CNN metadata prepared.")
+
+ except Exception as e:
+     logger.error(f"Error setting up Mask R-CNN predictor: {e}", exc_info=True)
+     maskrcnn_predictor = None # Set to None if loading fails
+
+ print("Model loading complete.")
+
+
+ # --- Prediction Functions ---
+
+ def classify_image_clip(image_pil):
+     """Classifies image using CLIP. Returns label and probabilities."""
+     if clip_model is None or clip_text_features is None:
+         return "Error: CLIP Model Not Loaded", {"Error": 1.0}
+
+     try:
+         # Basic preprocessing (CLIP handles resizing)
+         image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
+
+         with torch.no_grad():
+             image_features = clip_model.encode_image(image_input)
+             image_features /= image_features.norm(dim=-1, keepdim=True)
+
+             # Calculate similarity
+             logit_scale = clip_model.logit_scale.exp()
+             similarity = (image_features @ clip_text_features.T) * logit_scale
+             probs = similarity.softmax(dim=-1).squeeze().cpu() # Move probs to CPU
+
+         # Get prediction
+         # Index 0 = Car, Index 1 = Not Car (based on your feature creation)
+         predicted_label = "Car" if probs[0] > probs[1] else "Not Car"
+         prob_dict = {"Car": float(probs[0]), "Not Car": float(probs[1])} # gr.Label expects numeric confidences
+
+         return predicted_label, prob_dict
+
+     except Exception as e:
+         logger.error(f"Error during CLIP prediction: {e}", exc_info=True)
+         return "Error during CLIP processing", {"Error": 1.0}
+
+
+ def segment_damage(image_np_bgr):
+     """Segments damage using Mask R-CNN. Returns visualized image."""
+     if maskrcnn_predictor is None or maskrcnn_metadata is None:
+         logger.error("Mask R-CNN predictor or metadata not available.")
+         # Return None here; the caller falls back to showing the original image
+         return None
+
+     try:
+         logger.info("Running Mask R-CNN inference...")
+         outputs = maskrcnn_predictor(image_np_bgr) # Predictor expects a BGR numpy array
+         predictions = outputs["instances"].to("cpu")
+         logger.info(f"Mask R-CNN detected {len(predictions)} instances.")
+
+         # Visualize
+         v = Visualizer(image_np_bgr[:, :, ::-1], # Convert BGR to RGB for Visualizer
+                        metadata=maskrcnn_metadata,
+                        scale=0.8,
+                        instance_mode=ColorMode.SEGMENTATION)
+
+         # Draw predictions only if any exist
+         if len(predictions) > 0:
+             out = v.draw_instance_predictions(predictions)
+             output_image_np_rgb = out.get_image() # Visualizer gives RGB
+         else:
+             # If no detections, return the original image (converted to RGB)
+             logger.info("No damage instances detected above threshold.")
+             output_image_np_rgb = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
+
+         return output_image_np_rgb
+
+     except Exception as e:
+         logger.error(f"Error during Mask R-CNN prediction/visualization: {e}", exc_info=True)
+         # On error, return the original image converted to RGB
+         return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
+
+
+
+ # --- Main Gradio Function ---
+
+ def predict_pipeline(image_np_input):
+     """
+     Takes a numpy image input, runs CLIP, then optionally Mask R-CNN.
+     Returns: classification text, probability dict, output image (numpy RGB)
+     """
+     if image_np_input is None:
+         return "Please upload an image.", {}, None
+
+     logger.info("Received image for processing...")
+
+     # --- Stage 1: CLIP Classification ---
+     # gr.Image(type="numpy") supplies an RGB array, so it can go straight to PIL for CLIP
+     image_pil = Image.fromarray(image_np_input)
+     classification_result, probabilities = classify_image_clip(image_pil)
+     logger.info(f"CLIP Result: {classification_result}, Probs: {probabilities}")
+
+     output_image = None # Initialize output image
+
+     # --- Stage 2: Damage Segmentation (if classified as 'Car') ---
+     if classification_result == "Car":
+         logger.info("Image classified as Car. Proceeding to damage segmentation...")
+         # Detectron2's DefaultPredictor expects BGR, so convert before segmentation
+         image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
+         output_image = segment_damage(image_np_bgr)
+         if output_image is None: # Handle potential error in segmentation
+             logger.warning("Damage segmentation returned None. Displaying original image.")
+             output_image = image_np_input
+     elif classification_result == "Not Car":
+         logger.info("Image classified as Not Car. Skipping damage segmentation.")
+         # Show the original image if it's not a car
+         output_image = image_np_input
+     else: # Handle CLIP error case
+         logger.error("CLIP classification failed.")
+         output_image = image_np_input
+
+     # --- Cleanup ---
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return classification_result, probabilities, output_image
+
+
+ # --- Gradio Interface ---
+ logger.info("Setting up Gradio interface...")
+
+ title = "Car Damage Segmentation Pipeline"
+ description = """
+ Upload an image.
+ 1. The first model (CLIP) classifies whether the image shows a car.
+ 2. If it is a car, the second model (Mask R-CNN) segments potential damage.
+ """
+ examples = [
+     # Add paths to example images if you upload them to the repo
+     # ["./example_car_damaged.jpg"],
+     # ["./example_car_ok.jpg"],
+     # ["./example_not_car.jpg"],
+ ]
+
+ # Define Inputs and Outputs
+ input_image = gr.Image(type="numpy", label="Upload Car Image")
+ output_classification = gr.Textbox(label="Classification Result")
+ output_probabilities = gr.Label(label="Class Probabilities") # Label is good for dicts
+ output_segmentation = gr.Image(type="numpy", label="Damage Segmentation / Original Image")
+
+ # Build the interface
+ iface = gr.Interface(
+     fn=predict_pipeline,
+     inputs=input_image,
+     outputs=[output_classification, output_probabilities, output_segmentation],
+     title=title,
+     description=description,
+     examples=examples,
+     allow_flagging="never" # Disable flagging unless needed
+ )
+
+ if __name__ == "__main__":
+     logger.info("Launching Gradio app...")
+     iface.launch() # share=True to create a public link (use with caution)
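
For local testing outside the Gradio UI, predict_pipeline can also be called directly. The following is a usage sketch, not part of this commit; the image path is a placeholder and it assumes the file above is saved as app.py.

import cv2
from app import predict_pipeline  # importing app.py loads both models once

# cv2.imread returns BGR, while the pipeline expects the RGB arrays Gradio provides
img_bgr = cv2.imread("test_car.jpg")  # placeholder path
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

label, probs, annotated_rgb = predict_pipeline(img_rgb)
print(label, probs)

# Convert back to BGR before writing with OpenCV
cv2.imwrite("annotated.jpg", cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR))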
clip_text_features.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28315215c9429a04e5aafd99cf8a0292a489bf2937d44d580a3cf1c78ee84f94
+ size 3283
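
clip_text_features.pt is a Git LFS pointer to the pre-computed CLIP text embeddings that app.py compares image features against. app.py only assumes the features are L2-normalized and ordered with index 0 = "Car" and index 1 = "Not Car". A minimal sketch of how such a file could be produced (the prompt wording is an assumption, not taken from this commit):

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)

# Hypothetical prompts; only the [Car, Not Car] order and the normalization matter to app.py
prompts = ["a photo of a car", "a photo of something that is not a car"]
tokens = clip.tokenize(prompts).to(device)

with torch.no_grad():
    text_features = model.encode_text(tokens)
    text_features /= text_features.norm(dim=-1, keepdim=True)  # match the image-feature normalization

torch.save(text_features, "clip_text_features.pt")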
model_best.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ec615e5c5a490ed7f70c848208b583870a88bde585b1ce1243f6fbac2509958
+ size 503386528
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch
+ torchvision
+ torchaudio
+ gradio
+ ultralytics
+ opencv-python-headless # Use headless version for servers
+ matplotlib # If used by ultralytics plotting or your code
+ ftfy # CLIP dependency
+ regex # CLIP dependency
+ git+https://github.com/openai/CLIP.git # Install CLIP directly
+ Pillow # PIL dependency for CLIP/images
+ # Add any other specific libraries you might need
+ pyyaml # Usually needed by ultralytics/detectron2 indirectly
+ # NOTE: app.py also imports detectron2, which is not listed here; it is usually installed separately (e.g. from the facebookresearch/detectron2 source repo)