ishanprogs committed on
Commit
c2717cb
·
verified ·
1 Parent(s): 73266fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -68
app.py CHANGED
@@ -33,7 +33,7 @@ CAR_PART_CLASSES = [
33
  NUM_CAR_PART_CLASSES = len(CAR_PART_CLASSES)
34
 
35
  CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
36
- DAMAGE_MODEL_WEIGHTS_PATH = "./best.pt"
37
  PART_MODEL_WEIGHTS_PATH = "./partdetection_yolobest.pt"
38
  DEFAULT_DAMAGE_PRED_THRESHOLD = 0.4
39
  DEFAULT_PART_PRED_THRESHOLD = 0.3
@@ -52,14 +52,18 @@ try:
52
  logger.info("Loading CLIP model (ViT-B/16)...")
53
  clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE, jit=False)
54
  clip_model.eval()
55
- if not os.path.exists(CLIP_TEXT_FEATURES_PATH): raise FileNotFoundError(f"CLIP text features not found: {CLIP_TEXT_FEATURES_PATH}.")
 
56
  clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
57
  logger.info(f"CLIP loaded (Text Features dtype: {clip_text_features.dtype}).")
58
- except Exception as e: clip_load_error_msg = f"CLIP load error: {e}"; logger.error(clip_load_error_msg, exc_info=True)
 
 
59
 
60
  try:
61
  logger.info(f"Loading Damage YOLOv8 model from {DAMAGE_MODEL_WEIGHTS_PATH}...")
62
- if not os.path.exists(DAMAGE_MODEL_WEIGHTS_PATH): raise FileNotFoundError(f"Damage model weights not found: {DAMAGE_MODEL_WEIGHTS_PATH}.")
 
63
  damage_model = YOLO(DAMAGE_MODEL_WEIGHTS_PATH)
64
  damage_model.to(DEVICE)
65
  logger.info(f"Damage model task: {damage_model.task}")
@@ -70,13 +74,19 @@ try:
70
  else:
71
  loaded_damage_names = list(damage_model.names.values())
72
  if loaded_damage_names != DAMAGE_CLASSES:
73
- logger.warning(f"Mismatch: Defined DAMAGE_CLASSES vs names in {DAMAGE_MODEL_WEIGHTS_PATH}"); DAMAGE_CLASSES = loaded_damage_names; logger.warning(f"Updated DAMAGE_CLASSES to: {DAMAGE_CLASSES}")
 
 
74
  logger.info("Damage YOLOv8 model loaded.")
75
- except Exception as e: damage_load_error_msg = f"Damage YOLO load error: {e}"; logger.error(damage_load_error_msg, exc_info=True); damage_model = None
 
 
 
76
 
77
  try:
78
  logger.info(f"Loading Part YOLOv8 model from {PART_MODEL_WEIGHTS_PATH}...")
79
- if not os.path.exists(PART_MODEL_WEIGHTS_PATH): raise FileNotFoundError(f"Part model weights not found: {PART_MODEL_WEIGHTS_PATH}.")
 
80
  part_model = YOLO(PART_MODEL_WEIGHTS_PATH)
81
  part_model.to(DEVICE)
82
  logger.info(f"Part model task: {part_model.task}")
@@ -87,20 +97,27 @@ try:
87
  else:
88
  loaded_part_names = list(part_model.names.values())
89
  if loaded_part_names != CAR_PART_CLASSES:
90
- logger.warning(f"Mismatch: Defined CAR_PART_CLASSES vs names in {PART_MODEL_WEIGHTS_PATH}"); CAR_PART_CLASSES = loaded_part_names; logger.warning(f"Updated CAR_PART_CLASSES to: {CAR_PART_CLASSES}")
 
 
91
  logger.info("Part YOLOv8 model loaded.")
92
- except Exception as e: part_load_error_msg = f"Part YOLO load error: {e}"; logger.error(part_load_error_msg, exc_info=True); part_model = None
 
 
 
93
 
94
- print("--- Model loading process finished. ---");
95
- if clip_load_error_msg: print(f"CLIP STATUS: {clip_load_error_msg}"); else: print("CLIP STATUS: Loaded OK.")
96
- if damage_load_error_msg: print(f"DAMAGE MODEL STATUS: {damage_load_error_msg}"); else: print("DAMAGE MODEL STATUS: Loaded OK.")
97
- if part_load_error_msg: print(f"PART MODEL STATUS: {part_load_error_msg}"); else: print("PART MODEL STATUS: Loaded OK.")
98
 
99
  # --- Prediction Functions ---
100
  def classify_image_clip(image_pil):
101
- if clip_model is None: return "Error: CLIP Model Not Loaded", {"Error": 1.0}
 
102
  try:
103
- if image_pil.mode != "RGB": image_pil = image_pil.convert("RGB")
 
104
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
105
  with torch.no_grad():
106
  image_features = clip_model.encode_image(image_input)
@@ -111,13 +128,19 @@ def classify_image_clip(image_pil):
111
  similarity = (image_features @ text_features_matched.T) * clip_model.logit_scale.exp()
112
  probs = similarity.softmax(dim=-1).squeeze().cpu()
113
  return ("Car" if probs[0] > probs[1] else "Not Car"), {"Car": f"{probs[0]:.3f}", "Not Car": f"{probs[1]:.3f}"}
114
- except Exception as e: logger.error(f"CLIP Error: {e}", exc_info=True); return "Error: CLIP", {"Error": 1.0}
 
 
115
 
116
  def process_car_image(image_np_bgr, damage_threshold, part_threshold):
117
- if damage_model is None: return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Damage model failed to load ({damage_load_error_msg})"
118
- if part_model is None: return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Part model failed to load ({part_load_error_msg})"
119
- if damage_model.task != 'segment': return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), "Error: Damage model is not a segmentation model."
120
- if part_model.task != 'segment': return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), "Error: Part model is not a segmentation model."
 
 
 
 
121
 
122
  final_assignments = []
123
  annotated_image_bgr = image_np_bgr.copy()
@@ -147,8 +170,10 @@ def process_car_image(image_np_bgr, damage_threshold, part_threshold):
147
  damage_result = damage_results[0]
148
  logger.info(f"Found {len(damage_result.boxes)} potential damages.")
149
  damage_masks_raw = damage_result.masks.data if damage_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
150
- if damage_result.masks is None: logger.warning("No damage masks in result! Check if damage model is segmentation type.")
151
- else: logger.info(f"Damage masks available: shape={damage_masks_raw.shape if damage_masks_raw.numel() > 0 else 'Empty'}")
 
 
152
  damage_classes_ids_cpu = damage_result.boxes.cls.cpu().numpy().astype(int) if damage_result.boxes is not None else np.array([])
153
  damage_boxes_xyxy_cpu = damage_result.boxes.xyxy.cpu() if damage_result.boxes is not None else torch.empty((0,4))
154
 
@@ -158,28 +183,69 @@ def process_car_image(image_np_bgr, damage_threshold, part_threshold):
158
  part_result = part_results[0]
159
  logger.info(f"Found {len(part_result.boxes)} potential parts.")
160
  part_masks_raw = part_result.masks.data if part_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
161
- if part_result.masks is None: logger.warning("No part masks in result! Check if part model is segmentation type.")
162
- else: logger.info(f"Part masks available: shape={part_masks_raw.shape if part_masks_raw.numel() > 0 else 'Empty'}")
 
 
163
  part_classes_ids_cpu = part_result.boxes.cls.cpu().numpy().astype(int) if part_result.boxes is not None else np.array([])
164
  part_boxes_xyxy_cpu = part_result.boxes.xyxy.cpu() if part_result.boxes is not None else torch.empty((0,4))
165
 
166
  # --- 3. Resize Masks ---
167
  def resize_masks(masks_tensor, target_h, target_w):
168
- # ... (resize logic remains the same - uses CPU numpy) ...
169
- masks_np_bool = masks_tensor.cpu().numpy().astype(bool); if masks_np_bool.shape[0] == 0 or (masks_np_bool.shape[1] == target_h and masks_np_bool.shape[2] == target_w): return masks_np_bool; resized_masks_list = []; for i in range(masks_np_bool.shape[0]): mask = masks_np_bool[i]; mask_resized = cv2.resize(mask.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST); resized_masks_list.append(mask_resized.astype(bool)); return np.array(resized_masks_list)
 
 
 
 
 
 
 
 
170
  damage_masks_np = resize_masks(damage_masks_raw, img_h, img_w)
171
  part_masks_np = resize_masks(part_masks_raw, img_h, img_w)
172
 
173
  # --- 4. Calculate Overlap ---
174
  logger.info("Calculating overlap...")
175
- # ... (Overlap calculation logic remains the same - uses CPU numpy) ...
176
- if damage_masks_np.shape[0] > 0 and part_masks_np.shape[0] > 0: overlap_threshold = 0.4;
177
- for i in range(len(damage_masks_np)): damage_mask = damage_masks_np[i]; damage_class_id = damage_classes_ids_cpu[i]; try: damage_name = DAMAGE_CLASSES[damage_class_id]; except IndexError: logger.warning(f"Invalid damage ID {damage_class_id}"); continue; damage_area = np.sum(damage_mask); if damage_area < 10: continue; max_overlap = 0; assigned_part_name = "Unknown / Outside Parts";
178
- for j in range(len(part_masks_np)): part_mask = part_masks_np[j]; part_class_id = part_classes_ids_cpu[j]; try: part_name = CAR_PART_CLASSES[part_class_id]; except IndexError: logger.warning(f"Invalid part ID {part_class_id}"); continue; intersection = np.logical_and(damage_mask, part_mask); overlap_ratio = np.sum(intersection) / damage_area if damage_area > 0 else 0; if overlap_ratio > max_overlap: max_overlap = overlap_ratio; if max_overlap >= overlap_threshold: assigned_part_name = part_name;
179
- assignment_desc = f"{damage_name} in {assigned_part_name}"; if assigned_part_name == "Unknown / Outside Parts": assignment_desc += f" (Overlap < {overlap_threshold*100:.0f}%)"; final_assignments.append(assignment_desc)
180
- elif damage_masks_np.shape[0] > 0: final_assignments.append(f"{len(damage_masks_np)} damages found, but no parts detected/matched above threshold {part_threshold}.")
181
- elif part_masks_np.shape[0] > 0: final_assignments.append(f"No damages detected above threshold {damage_threshold}.")
182
- else: final_assignments.append(f"No damages or parts detected above thresholds.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  logger.info(f" Assignment results: {final_assignments}")
184
 
185
  # --- 5. Visualization using YOLO Annotator ---
@@ -192,15 +258,21 @@ def process_car_image(image_np_bgr, damage_threshold, part_threshold):
192
  logger.info("Attempting to draw part masks...")
193
  colors_part = [(0, random.randint(100, 200), 0) for _ in part_classes_ids_cpu]
194
  mask_data_part = part_masks_raw
195
- if mask_data_part.device != im_tensor_gpu_for_annotator.device: mask_data_part = mask_data_part.to(im_tensor_gpu_for_annotator.device)
 
196
  annotator.masks(mask_data_part, colors=colors_part, im_gpu=im_tensor_gpu_for_annotator, alpha=0.3)
197
  logger.info("Successfully drew part masks.")
198
  for box, cls_id in zip(part_boxes_xyxy_cpu, part_classes_ids_cpu):
199
- try: label = f"{CAR_PART_CLASSES[cls_id]}"; annotator.box_label(box, label=label, color=(0, 200, 0))
200
- except IndexError: logger.warning(f"Invalid part ID {cls_id} during drawing")
201
- except Exception as e_part_vis: logger.error(f"Error drawing part masks/boxes: {e_part_vis}", exc_info=True); traceback.print_exc()
202
- elif part_masks_raw.numel() > 0: logger.warning("Part masks exist but image tensor for annotator is None. Skipping part mask drawing.")
203
-
 
 
 
 
 
204
 
205
  # Draw DAMAGE masks
206
  if damage_masks_raw.numel() > 0 and im_tensor_gpu_for_annotator is not None:
@@ -215,16 +287,22 @@ def process_car_image(image_np_bgr, damage_threshold, part_threshold):
215
  annotator.masks(mask_data_dmg, colors=colors_dmg, im_gpu=im_tensor_gpu_for_annotator, alpha=0.4)
216
  logger.info("Successfully drew damage masks.")
217
  for box, cls_id in zip(damage_boxes_xyxy_cpu, damage_classes_ids_cpu):
218
- try: label = f"{DAMAGE_CLASSES[cls_id]}"; annotator.box_label(box, label=label, color=(200, 0, 0))
219
- except IndexError: logger.warning(f"Invalid damage ID {cls_id} during drawing")
220
- except Exception as e_dmg_vis: logger.error(f"Error drawing damage masks/boxes: {e_dmg_vis}", exc_info=True); traceback.print_exc()
221
- elif damage_masks_raw.numel() > 0: logger.warning("Damage masks exist but image tensor for annotator is None. Skipping damage mask drawing.")
222
-
 
 
 
 
 
223
 
224
  annotated_image_bgr = annotator.result()
225
 
226
  except Exception as e:
227
- logger.error(f"Error during combined processing: {e}", exc_info=True); traceback.print_exc()
 
228
  final_assignments.append("Error during segmentation/processing.")
229
 
230
  assignment_text = "\n".join(final_assignments) if final_assignments else "No damage assignments generated."
@@ -233,33 +311,62 @@ def process_car_image(image_np_bgr, damage_threshold, part_threshold):
233
 
234
  # --- Main Gradio Function ---
235
  def predict_pipeline(image_np_input, damage_thresh, part_thresh):
236
- if image_np_input is None: return "Please upload an image.", {}, None, "N/A";
237
- logger.info(f"--- New Request (Damage Thr: {damage_thresh:.2f}, Part Thr: {part_thresh:.2f}) ---"); start_time = time.time();
238
- image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR); image_pil = Image.fromarray(image_np_input);
239
- final_output_image, assignment_text, classification_result, probabilities = None, "Processing...", "Error", {};
240
- try: classification_result, probabilities = classify_image_clip(image_pil)
241
- except Exception as e: logger.error(f"CLIP Error: {e}", exc_info=True); assignment_text = f"CLIP Error: {e}"; final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB);
 
 
 
 
 
 
 
 
 
242
  if classification_result == "Car":
243
- try: final_output_image, assignment_text = process_car_image(image_np_bgr, damage_thresh, part_thresh)
244
- except Exception as e: logger.error(f"Seg/Assign Error: {e}", exc_info=True); assignment_text = f"Seg Error: {e}"; final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB);
245
- elif classification_result == "Not Car": final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB); assignment_text = "Image classified as Not Car.";
246
- elif final_output_image is None: final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB); assignment_text = "Error during classification.";
247
- gc.collect();
248
- if torch.cuda.is_available(): torch.cuda.empty_cache();
249
- logger.info(f"Total processing time: {time.time() - start_time:.2f}s.");
 
 
 
 
 
 
 
 
 
 
250
  return classification_result, probabilities, final_output_image, assignment_text
251
 
252
  # --- Gradio Interface ---
253
  logger.info("Setting up Gradio interface...")
254
  title = "🚗 Car Damage Detection"
255
  description = "1. Upload... 2. Classify... 3. Segment... 4. Assign... 5. Output..." # Shortened
256
- input_image = gr.Image(type="numpy", label="Upload Car Image");
257
- damage_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_DAMAGE_PRED_THRESHOLD, label="Damage Confidence Threshold");
258
- part_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_PART_PRED_THRESHOLD, label="Part Confidence Threshold");
259
- output_classification = gr.Textbox(label="1. Classification Result");
260
- output_probabilities = gr.Label(label="Classification Probabilities");
261
- output_image_display = gr.Image(type="numpy", label="3. Segmentation Visualization");
262
- output_assignment = gr.Textbox(label="2. Damage Assignments", lines=5, interactive=False);
263
- iface = gr.Interface(fn=predict_pipeline, inputs=[input_image, damage_threshold_slider, part_threshold_slider], outputs=[output_classification, output_probabilities, output_image_display, output_assignment], title=title, description=description, allow_flagging="never" );
 
 
 
 
 
 
 
 
264
 
265
- if __name__ == "__main__": logger.info("Launching Gradio app..."); iface.launch()
 
 
 
33
  NUM_CAR_PART_CLASSES = len(CAR_PART_CLASSES)
34
 
35
  CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
36
+ DAMAGE_MODEL_WEIGHTS_PATH = "./model_best.pt"
37
  PART_MODEL_WEIGHTS_PATH = "./partdetection_yolobest.pt"
38
  DEFAULT_DAMAGE_PRED_THRESHOLD = 0.4
39
  DEFAULT_PART_PRED_THRESHOLD = 0.3
 
52
  logger.info("Loading CLIP model (ViT-B/16)...")
53
  clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE, jit=False)
54
  clip_model.eval()
55
+ if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
56
+ raise FileNotFoundError(f"CLIP text features not found: {CLIP_TEXT_FEATURES_PATH}.")
57
  clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
58
  logger.info(f"CLIP loaded (Text Features dtype: {clip_text_features.dtype}).")
59
+ except Exception as e:
60
+ clip_load_error_msg = f"CLIP load error: {e}"
61
+ logger.error(clip_load_error_msg, exc_info=True)
62
 
63
  try:
64
  logger.info(f"Loading Damage YOLOv8 model from {DAMAGE_MODEL_WEIGHTS_PATH}...")
65
+ if not os.path.exists(DAMAGE_MODEL_WEIGHTS_PATH):
66
+ raise FileNotFoundError(f"Damage model weights not found: {DAMAGE_MODEL_WEIGHTS_PATH}.")
67
  damage_model = YOLO(DAMAGE_MODEL_WEIGHTS_PATH)
68
  damage_model.to(DEVICE)
69
  logger.info(f"Damage model task: {damage_model.task}")
 
74
  else:
75
  loaded_damage_names = list(damage_model.names.values())
76
  if loaded_damage_names != DAMAGE_CLASSES:
77
+ logger.warning(f"Mismatch: Defined DAMAGE_CLASSES vs names in {DAMAGE_MODEL_WEIGHTS_PATH}")
78
+ DAMAGE_CLASSES = loaded_damage_names
79
+ logger.warning(f"Updated DAMAGE_CLASSES to: {DAMAGE_CLASSES}")
80
  logger.info("Damage YOLOv8 model loaded.")
81
+ except Exception as e:
82
+ damage_load_error_msg = f"Damage YOLO load error: {e}"
83
+ logger.error(damage_load_error_msg, exc_info=True)
84
+ damage_model = None
85
 
86
  try:
87
  logger.info(f"Loading Part YOLOv8 model from {PART_MODEL_WEIGHTS_PATH}...")
88
+ if not os.path.exists(PART_MODEL_WEIGHTS_PATH):
89
+ raise FileNotFoundError(f"Part model weights not found: {PART_MODEL_WEIGHTS_PATH}.")
90
  part_model = YOLO(PART_MODEL_WEIGHTS_PATH)
91
  part_model.to(DEVICE)
92
  logger.info(f"Part model task: {part_model.task}")
 
97
  else:
98
  loaded_part_names = list(part_model.names.values())
99
  if loaded_part_names != CAR_PART_CLASSES:
100
+ logger.warning(f"Mismatch: Defined CAR_PART_CLASSES vs names in {PART_MODEL_WEIGHTS_PATH}")
101
+ CAR_PART_CLASSES = loaded_part_names
102
+ logger.warning(f"Updated CAR_PART_CLASSES to: {CAR_PART_CLASSES}")
103
  logger.info("Part YOLOv8 model loaded.")
104
+ except Exception as e:
105
+ part_load_error_msg = f"Part YOLO load error: {e}"
106
+ logger.error(part_load_error_msg, exc_info=True)
107
+ part_model = None
108
 
109
+ print("--- Model loading process finished. ---")
110
+ print(f"CLIP STATUS: {clip_load_error_msg}" if clip_load_error_msg else "CLIP STATUS: Loaded OK.")
111
+ print(f"DAMAGE MODEL STATUS: {damage_load_error_msg}" if damage_load_error_msg else "DAMAGE MODEL STATUS: Loaded OK.")
112
+ print(f"PART MODEL STATUS: {part_load_error_msg}" if part_load_error_msg else "PART MODEL STATUS: Loaded OK.")
113
 
114
  # --- Prediction Functions ---
115
  def classify_image_clip(image_pil):
116
+ if clip_model is None:
117
+ return "Error: CLIP Model Not Loaded", {"Error": 1.0}
118
  try:
119
+ if image_pil.mode != "RGB":
120
+ image_pil = image_pil.convert("RGB")
121
  image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
122
  with torch.no_grad():
123
  image_features = clip_model.encode_image(image_input)
 
128
  similarity = (image_features @ text_features_matched.T) * clip_model.logit_scale.exp()
129
  probs = similarity.softmax(dim=-1).squeeze().cpu()
130
  return ("Car" if probs[0] > probs[1] else "Not Car"), {"Car": f"{probs[0]:.3f}", "Not Car": f"{probs[1]:.3f}"}
131
+ except Exception as e:
132
+ logger.error(f"CLIP Error: {e}", exc_info=True)
133
+ return "Error: CLIP", {"Error": 1.0}
134
 
135
  def process_car_image(image_np_bgr, damage_threshold, part_threshold):
136
+ if damage_model is None:
137
+ return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Damage model failed to load ({damage_load_error_msg})"
138
+ if part_model is None:
139
+ return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Part model failed to load ({part_load_error_msg})"
140
+ if damage_model.task != 'segment':
141
+ return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), "Error: Damage model is not a segmentation model."
142
+ if part_model.task != 'segment':
143
+ return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), "Error: Part model is not a segmentation model."
144
 
145
  final_assignments = []
146
  annotated_image_bgr = image_np_bgr.copy()
 
170
  damage_result = damage_results[0]
171
  logger.info(f"Found {len(damage_result.boxes)} potential damages.")
172
  damage_masks_raw = damage_result.masks.data if damage_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
173
+ if damage_result.masks is None:
174
+ logger.warning("No damage masks in result! Check if damage model is segmentation type.")
175
+ else:
176
+ logger.info(f"Damage masks available: shape={damage_masks_raw.shape if damage_masks_raw.numel() > 0 else 'Empty'}")
177
  damage_classes_ids_cpu = damage_result.boxes.cls.cpu().numpy().astype(int) if damage_result.boxes is not None else np.array([])
178
  damage_boxes_xyxy_cpu = damage_result.boxes.xyxy.cpu() if damage_result.boxes is not None else torch.empty((0,4))
179
 
 
183
  part_result = part_results[0]
184
  logger.info(f"Found {len(part_result.boxes)} potential parts.")
185
  part_masks_raw = part_result.masks.data if part_result.masks is not None else torch.empty((0,0,0), device=DEVICE)
186
+ if part_result.masks is None:
187
+ logger.warning("No part masks in result! Check if part model is segmentation type.")
188
+ else:
189
+ logger.info(f"Part masks available: shape={part_masks_raw.shape if part_masks_raw.numel() > 0 else 'Empty'}")
190
  part_classes_ids_cpu = part_result.boxes.cls.cpu().numpy().astype(int) if part_result.boxes is not None else np.array([])
191
  part_boxes_xyxy_cpu = part_result.boxes.xyxy.cpu() if part_result.boxes is not None else torch.empty((0,4))
192
 
193
  # --- 3. Resize Masks ---
194
  def resize_masks(masks_tensor, target_h, target_w):
195
+ masks_np_bool = masks_tensor.cpu().numpy().astype(bool)
196
+ if masks_np_bool.shape[0] == 0 or (masks_np_bool.shape[1] == target_h and masks_np_bool.shape[2] == target_w):
197
+ return masks_np_bool
198
+ resized_masks_list = []
199
+ for i in range(masks_np_bool.shape[0]):
200
+ mask = masks_np_bool[i]
201
+ mask_resized = cv2.resize(mask.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST)
202
+ resized_masks_list.append(mask_resized.astype(bool))
203
+ return np.array(resized_masks_list)
204
+
205
  damage_masks_np = resize_masks(damage_masks_raw, img_h, img_w)
206
  part_masks_np = resize_masks(part_masks_raw, img_h, img_w)
207
 
208
  # --- 4. Calculate Overlap ---
209
  logger.info("Calculating overlap...")
210
+ if damage_masks_np.shape[0] > 0 and part_masks_np.shape[0] > 0:
211
+ overlap_threshold = 0.4
212
+ for i in range(len(damage_masks_np)):
213
+ damage_mask = damage_masks_np[i]
214
+ damage_class_id = damage_classes_ids_cpu[i]
215
+ try:
216
+ damage_name = DAMAGE_CLASSES[damage_class_id]
217
+ except IndexError:
218
+ logger.warning(f"Invalid damage ID {damage_class_id}")
219
+ continue
220
+ damage_area = np.sum(damage_mask)
221
+ if damage_area < 10:
222
+ continue
223
+ max_overlap = 0
224
+ assigned_part_name = "Unknown / Outside Parts"
225
+ for j in range(len(part_masks_np)):
226
+ part_mask = part_masks_np[j]
227
+ part_class_id = part_classes_ids_cpu[j]
228
+ try:
229
+ part_name = CAR_PART_CLASSES[part_class_id]
230
+ except IndexError:
231
+ logger.warning(f"Invalid part ID {part_class_id}")
232
+ continue
233
+ intersection = np.logical_and(damage_mask, part_mask)
234
+ overlap_ratio = np.sum(intersection) / damage_area if damage_area > 0 else 0
235
+ if overlap_ratio > max_overlap:
236
+ max_overlap = overlap_ratio
237
+ if max_overlap >= overlap_threshold:
238
+ assigned_part_name = part_name
239
+ assignment_desc = f"{damage_name} in {assigned_part_name}"
240
+ if assigned_part_name == "Unknown / Outside Parts":
241
+ assignment_desc += f" (Overlap < {overlap_threshold*100:.0f}%)"
242
+ final_assignments.append(assignment_desc)
243
+ elif damage_masks_np.shape[0] > 0:
244
+ final_assignments.append(f"{len(damage_masks_np)} damages found, but no parts detected/matched above threshold {part_threshold}.")
245
+ elif part_masks_np.shape[0] > 0:
246
+ final_assignments.append(f"No damages detected above threshold {damage_threshold}.")
247
+ else:
248
+ final_assignments.append(f"No damages or parts detected above thresholds.")
249
  logger.info(f" Assignment results: {final_assignments}")
250
 
251
  # --- 5. Visualization using YOLO Annotator ---
 
258
  logger.info("Attempting to draw part masks...")
259
  colors_part = [(0, random.randint(100, 200), 0) for _ in part_classes_ids_cpu]
260
  mask_data_part = part_masks_raw
261
+ if mask_data_part.device != im_tensor_gpu_for_annotator.device:
262
+ mask_data_part = mask_data_part.to(im_tensor_gpu_for_annotator.device)
263
  annotator.masks(mask_data_part, colors=colors_part, im_gpu=im_tensor_gpu_for_annotator, alpha=0.3)
264
  logger.info("Successfully drew part masks.")
265
  for box, cls_id in zip(part_boxes_xyxy_cpu, part_classes_ids_cpu):
266
+ try:
267
+ label = f"{CAR_PART_CLASSES[cls_id]}"
268
+ annotator.box_label(box, label=label, color=(0, 200, 0))
269
+ except IndexError:
270
+ logger.warning(f"Invalid part ID {cls_id} during drawing")
271
+ except Exception as e_part_vis:
272
+ logger.error(f"Error drawing part masks/boxes: {e_part_vis}", exc_info=True)
273
+ traceback.print_exc()
274
+ elif part_masks_raw.numel() > 0:
275
+ logger.warning("Part masks exist but image tensor for annotator is None. Skipping part mask drawing.")
276
 
277
  # Draw DAMAGE masks
278
  if damage_masks_raw.numel() > 0 and im_tensor_gpu_for_annotator is not None:
 
287
  annotator.masks(mask_data_dmg, colors=colors_dmg, im_gpu=im_tensor_gpu_for_annotator, alpha=0.4)
288
  logger.info("Successfully drew damage masks.")
289
  for box, cls_id in zip(damage_boxes_xyxy_cpu, damage_classes_ids_cpu):
290
+ try:
291
+ label = f"{DAMAGE_CLASSES[cls_id]}"
292
+ annotator.box_label(box, label=label, color=(200, 0, 0))
293
+ except IndexError:
294
+ logger.warning(f"Invalid damage ID {cls_id} during drawing")
295
+ except Exception as e_dmg_vis:
296
+ logger.error(f"Error drawing damage masks/boxes: {e_dmg_vis}", exc_info=True)
297
+ traceback.print_exc()
298
+ elif damage_masks_raw.numel() > 0:
299
+ logger.warning("Damage masks exist but image tensor for annotator is None. Skipping damage mask drawing.")
300
 
301
  annotated_image_bgr = annotator.result()
302
 
303
  except Exception as e:
304
+ logger.error(f"Error during combined processing: {e}", exc_info=True)
305
+ traceback.print_exc()
306
  final_assignments.append("Error during segmentation/processing.")
307
 
308
  assignment_text = "\n".join(final_assignments) if final_assignments else "No damage assignments generated."
 
311
 
312
  # --- Main Gradio Function ---
313
  def predict_pipeline(image_np_input, damage_thresh, part_thresh):
314
+ if image_np_input is None:
315
+ return "Please upload an image.", {}, None, "N/A"
316
+ logger.info(f"--- New Request (Damage Thr: {damage_thresh:.2f}, Part Thr: {part_thresh:.2f}) ---")
317
+ start_time = time.time()
318
+ image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
319
+ image_pil = Image.fromarray(image_np_input)
320
+ final_output_image, assignment_text, classification_result, probabilities = None, "Processing...", "Error", {}
321
+
322
+ try:
323
+ classification_result, probabilities = classify_image_clip(image_pil)
324
+ except Exception as e:
325
+ logger.error(f"CLIP Error: {e}", exc_info=True)
326
+ assignment_text = f"CLIP Error: {e}"
327
+ final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
328
+
329
  if classification_result == "Car":
330
+ try:
331
+ final_output_image, assignment_text = process_car_image(image_np_bgr, damage_thresh, part_thresh)
332
+ except Exception as e:
333
+ logger.error(f"Seg/Assign Error: {e}", exc_info=True)
334
+ assignment_text = f"Seg Error: {e}"
335
+ final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
336
+ elif classification_result == "Not Car":
337
+ final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
338
+ assignment_text = "Image classified as Not Car."
339
+ elif final_output_image is None:
340
+ final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
341
+ assignment_text = "Error during classification."
342
+
343
+ gc.collect()
344
+ if torch.cuda.is_available():
345
+ torch.cuda.empty_cache()
346
+ logger.info(f"Total processing time: {time.time() - start_time:.2f}s.")
347
  return classification_result, probabilities, final_output_image, assignment_text
348
 
349
  # --- Gradio Interface ---
350
  logger.info("Setting up Gradio interface...")
351
  title = "🚗 Car Damage Detection"
352
  description = "1. Upload... 2. Classify... 3. Segment... 4. Assign... 5. Output..." # Shortened
353
+ input_image = gr.Image(type="numpy", label="Upload Car Image")
354
+ damage_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_DAMAGE_PRED_THRESHOLD, label="Damage Confidence Threshold")
355
+ part_threshold_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=DEFAULT_PART_PRED_THRESHOLD, label="Part Confidence Threshold")
356
+ output_classification = gr.Textbox(label="1. Classification Result")
357
+ output_probabilities = gr.Label(label="Classification Probabilities")
358
+ output_image_display = gr.Image(type="numpy", label="3. Segmentation Visualization")
359
+ output_assignment = gr.Textbox(label="2. Damage Assignments", lines=5, interactive=False)
360
+
361
+ iface = gr.Interface(
362
+ fn=predict_pipeline,
363
+ inputs=[input_image, damage_threshold_slider, part_threshold_slider],
364
+ outputs=[output_classification, output_probabilities, output_image_display, output_assignment],
365
+ title=title,
366
+ description=description,
367
+ allow_flagging="never"
368
+ )
369
 
370
+ if __name__ == "__main__":
371
+ logger.info("Launching Gradio app...")
372
+ iface.launch()