ishanprogs committed
Commit ad31d5e · verified · 1 Parent(s): a883ac2

Upload app.py

Files changed (1)
  1. app.py +407 -0
app.py ADDED
@@ -0,0 +1,407 @@
import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
import os
import cv2
import gc  # Garbage collector
import logging
import random  # For annotator colors
import time  # For timing checks
import traceback  # For detailed error printing

# --- YOLOv8 Imports ---
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator  # For drawing YOLO results

# --- Setup Logging ---
logging.getLogger("ultralytics").setLevel(logging.WARNING)  # Reduce YOLO logging noise
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Constants ---
# Damage segmentation classes (order MUST match the training of 'best.pt')
DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)

# Part segmentation classes (order MUST match the training of 'partdetection_yolobest.pt')
CAR_PART_CLASSES = [
    "Quarter-panel", "Front-wheel", "Back-window", "Trunk", "Front-door",
    "Rocker-panel", "Grille", "Windshield", "Front-window", "Back-door",
    "Headlight", "Back-wheel", "Back-windshield", "Hood", "Fender",
    "Tail-light", "License-plate", "Front-bumper", "Back-bumper", "Mirror",
    "Roof"
]
NUM_CAR_PART_CLASSES = len(CAR_PART_CLASSES)

# Paths within the Hugging Face Space repository
CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
DAMAGE_MODEL_WEIGHTS_PATH = "./best.pt"  # YOLOv8 damage model weights
PART_MODEL_WEIGHTS_PATH = "./partdetection_yolobest.pt"  # YOLOv8 part model weights

# Prediction Thresholds
DAMAGE_PRED_THRESHOLD = 0.4  # Threshold for showing damage masks
PART_PRED_THRESHOLD = 0.3  # Threshold for showing part masks
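# Both thresholds are passed as `conf=` to the YOLO predict() calls further down, so
# any detection scoring under them is discarded before the overlap logic ever runs.
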
# --- Device Setup ---
if torch.cuda.is_available():
    DEVICE = "cuda"
    logger.info("CUDA available, using GPU.")
else:
    DEVICE = "cpu"
    logger.info("CUDA not available, using CPU.")

# --- MODEL LOADING (Load models globally ONCE on startup) ---
print("Loading models...")
clip_model = None
clip_preprocess = None
clip_text_features = None
damage_model = None
part_model = None
clip_load_error_msg = None
damage_load_error_msg = None
part_load_error_msg = None

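# Each model is loaded once at startup; if a load fails, its handle stays None and
# the stored error message is surfaced in the UI instead of crashing the Space.
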
# --- Load CLIP Model (Model 1) ---
try:
    logger.info("Loading CLIP model (ViT-B/16)...")
    # jit=False might improve stability/compatibility in some environments
    clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE, jit=False)
    clip_model.eval()
    logger.info("CLIP model loaded.")

    logger.info(f"Loading CLIP text features from {CLIP_TEXT_FEATURES_PATH}...")
    if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
        raise FileNotFoundError(f"CLIP text features not found: {CLIP_TEXT_FEATURES_PATH}.")
    # Load text features directly onto the designated device
    clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
    logger.info(f"CLIP text features loaded (dtype: {clip_text_features.dtype}).")

except Exception as e:
    clip_load_error_msg = f"CLIP load error: {e}"
    logger.error(clip_load_error_msg, exc_info=True)
    clip_model = None  # Set to None if loading fails

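# For reference, a minimal sketch of how a file like clip_text_features.pt could be
# produced offline (the prompt wording is an assumption, not taken from this repo;
# index 0 must correspond to "Car" and index 1 to "Not Car"):
#
#     import clip, torch
#     model, _ = clip.load("ViT-B/16", device="cpu", jit=False)
#     tokens = clip.tokenize(["a photo of a car", "a photo of something that is not a car"])
#     with torch.no_grad():
#         feats = model.encode_text(tokens)
#         feats /= feats.norm(dim=-1, keepdim=True)  # normalized like the image features
#     torch.save(feats, "clip_text_features.pt")
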
# --- Load Damage Segmentation Model (Model 2 - YOLOv8) ---
try:
    logger.info(f"Loading Damage Segmentation (YOLOv8) model from {DAMAGE_MODEL_WEIGHTS_PATH}...")
    if not os.path.exists(DAMAGE_MODEL_WEIGHTS_PATH):
        raise FileNotFoundError(f"Damage model weights not found: {DAMAGE_MODEL_WEIGHTS_PATH}.")
    damage_model = YOLO(DAMAGE_MODEL_WEIGHTS_PATH)
    damage_model.to(DEVICE)
    # Verify that the class names match
    loaded_damage_names = list(damage_model.names.values())
    if loaded_damage_names != DAMAGE_CLASSES:
        logger.warning(f"Mismatch: Defined DAMAGE_CLASSES vs names in {DAMAGE_MODEL_WEIGHTS_PATH}")
        DAMAGE_CLASSES = loaded_damage_names  # Use the names from the model file
        logger.warning(f"Updated DAMAGE_CLASSES to: {DAMAGE_CLASSES}")
    logger.info("Damage Segmentation (YOLOv8) model loaded.")
except Exception as e:
    damage_load_error_msg = f"Damage YOLO load error: {e}"
    logger.error(damage_load_error_msg, exc_info=True)
    damage_model = None

# --- Load Part Segmentation Model (Model 3 - YOLOv8) ---
try:
    logger.info(f"Loading Part Segmentation (YOLOv8) model from {PART_MODEL_WEIGHTS_PATH}...")
    if not os.path.exists(PART_MODEL_WEIGHTS_PATH):
        raise FileNotFoundError(f"Part model weights not found: {PART_MODEL_WEIGHTS_PATH}.")
    part_model = YOLO(PART_MODEL_WEIGHTS_PATH)
    part_model.to(DEVICE)
    # Verify that the class names match
    loaded_part_names = list(part_model.names.values())
    if loaded_part_names != CAR_PART_CLASSES:
        logger.warning(f"Mismatch: Defined CAR_PART_CLASSES vs names in {PART_MODEL_WEIGHTS_PATH}")
        CAR_PART_CLASSES = loaded_part_names  # Use the names from the model file
        logger.warning(f"Updated CAR_PART_CLASSES to: {CAR_PART_CLASSES}")
    logger.info("Part Segmentation (YOLOv8) model loaded.")
except Exception as e:
    part_load_error_msg = f"Part YOLO load error: {e}"
    logger.error(part_load_error_msg, exc_info=True)
    part_model = None

print("Model loading process finished.")
if clip_load_error_msg:
    print(f"WARNING: {clip_load_error_msg}")
if damage_load_error_msg:
    print(f"WARNING: {damage_load_error_msg}")
if part_load_error_msg:
    print(f"WARNING: {part_load_error_msg}")

# --- Prediction Functions ---

# CLIP classification (with dtype handling between image and text features)
def classify_image_clip(image_pil):
    """Classifies an image with CLIP. Returns a label and a probability dictionary."""
    if clip_model is None or clip_text_features is None:
        logger.error(f"CLIP model or text features not loaded. Error: {clip_load_error_msg}")
        return "Error: CLIP Model Not Loaded", {"Error": 1.0}

    logger.info("Running CLIP classification...")
    try:
        # Ensure the image is an RGB PIL image
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")

        logger.info("  Preprocessing image for CLIP...")
        image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
        logger.info(f"  Image input tensor created (device: {image_input.device}, dtype: {image_input.dtype}).")

        with torch.no_grad():
            logger.info("  Encoding image with CLIP...")
            image_features = clip_model.encode_image(image_input)
            logger.info(f"  Image features encoded (dtype: {image_features.dtype}).")
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # --- Ensure the text features match the image features' dtype ---
            text_features_matched = clip_text_features
            if image_features.dtype != clip_text_features.dtype:
                logger.warning(f"  Dtype mismatch! Image: {image_features.dtype}, Text: {clip_text_features.dtype}. Converting text features...")
                text_features_matched = clip_text_features.to(image_features.dtype)

            logit_scale = clip_model.logit_scale.exp()
            logger.info("  Calculating similarity...")
            similarity = (image_features @ text_features_matched.T) * logit_scale
            probs = similarity.softmax(dim=-1).squeeze().cpu()  # Move probabilities to CPU
            logger.info("  Similarity calculated.")

        # Index order follows the text-feature file: 0 = Car, 1 = Not Car
        car_prob = probs[0].item()
        not_car_prob = probs[1].item()

        predicted_label = "Car" if car_prob > not_car_prob else "Not Car"
        # gr.Label expects numeric confidences, so keep the probabilities as floats
        prob_dict = {"Car": round(car_prob, 3), "Not Car": round(not_car_prob, 3)}
        logger.info(f"CLIP Result: {predicted_label}, Probs: {prob_dict}")

        return predicted_label, prob_dict

    except Exception as e:
        logger.error(f"Error during CLIP prediction: {e}", exc_info=True)
        traceback.print_exc()  # Print a detailed traceback to the logs
        return "Error during CLIP processing", {"Error": 1.0}

# --- Combined Processing and Overlap Logic ---
def process_car_image(image_np_bgr):
    """
    Runs damage and part segmentation (both YOLOv8), calculates overlap, and returns results.

    Returns:
        - combined_image_rgb: Image with both part and damage masks drawn.
        - assignment_text: String describing damage-part assignments.
    """
    if damage_model is None:
        logger.error("Damage YOLOv8 model not available.")
        return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Damage model not loaded ({damage_load_error_msg})"
    if part_model is None:
        logger.error("Part YOLOv8 model not available.")
        return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB), f"Error: Part model not loaded ({part_load_error_msg})"

    final_assignments = []
    annotated_image_bgr = image_np_bgr.copy()
    img_h, img_w = image_np_bgr.shape[:2]
    logger.info("Starting combined image processing...")

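    # Note: this function expects a BGR array (OpenCV's convention) and returns an
    # RGB array, since Gradio image components work in RGB.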
    try:
        # --- 1. Predict Damages (YOLOv8) ---
        logger.info(f"Running Damage Segmentation (Threshold: {DAMAGE_PRED_THRESHOLD})...")
        damage_results = damage_model.predict(image_np_bgr, verbose=False, device=DEVICE, conf=DAMAGE_PRED_THRESHOLD)
        damage_result = damage_results[0]
        logger.info(f"Found {len(damage_result.boxes)} potential damages.")
        damage_masks_raw = damage_result.masks.data.cpu() if damage_result.masks is not None else torch.empty((0, 0, 0))
        damage_classes_ids = damage_result.boxes.cls.cpu().numpy().astype(int) if damage_result.boxes is not None else np.array([])
        damage_boxes_xyxy = damage_result.boxes.xyxy.cpu() if damage_result.boxes is not None else torch.empty((0, 4))

        # --- 2. Predict Parts (YOLOv8) ---
        logger.info(f"Running Part Segmentation (Threshold: {PART_PRED_THRESHOLD})...")
        part_results = part_model.predict(image_np_bgr, verbose=False, device=DEVICE, conf=PART_PRED_THRESHOLD)
        part_result = part_results[0]
        logger.info(f"Found {len(part_result.boxes)} potential parts.")
        part_masks_raw = part_result.masks.data.cpu() if part_result.masks is not None else torch.empty((0, 0, 0))
        part_classes_ids = part_result.boxes.cls.cpu().numpy().astype(int) if part_result.boxes is not None else np.array([])
        part_boxes_xyxy = part_result.boxes.xyxy.cpu() if part_result.boxes is not None else torch.empty((0, 4))

        # --- 3. Resize Masks to the Input Resolution if Necessary ---
        def resize_masks(masks_tensor, target_h, target_w):
            if masks_tensor.shape[0] == 0:
                return np.array([])  # Handle an empty tensor
            # Check whether a resize is needed
            if masks_tensor.shape[1] == target_h and masks_tensor.shape[2] == target_w:
                return masks_tensor.numpy().astype(bool)  # No resize needed, just convert to numpy bool

            logger.info(f"Resizing {masks_tensor.shape[0]} masks from {masks_tensor.shape[1:]} to {(target_h, target_w)}")
            masks_np = masks_tensor.numpy()  # Convert to numpy first
            resized_masks_list = []
            for i in range(masks_np.shape[0]):
                mask = masks_np[i]
                mask_resized = cv2.resize(mask.astype(np.uint8), (target_w, target_h), interpolation=cv2.INTER_NEAREST)
                resized_masks_list.append(mask_resized.astype(bool))
            return np.array(resized_masks_list)

        # --- Perform the resizing ---
        damage_masks_np = resize_masks(damage_masks_raw, img_h, img_w)
        part_masks_np = resize_masks(part_masks_raw, img_h, img_w)

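        # INTER_NEAREST (used in resize_masks above) keeps the resized masks strictly
        # binary; a smoothing interpolation would blur mask edges and skew the pixel
        # counts used in the overlap computation below.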
        # --- 4. Calculate Overlap ---
        logger.info("Calculating overlap...")
        if damage_masks_np.shape[0] > 0 and part_masks_np.shape[0] > 0:
            overlap_threshold = 0.4  # Minimum overlap ratio for assigning a damage to a part
            for i in range(len(damage_masks_np)):
                damage_mask = damage_masks_np[i]
                damage_class_id = damage_classes_ids[i]
                try:
                    damage_name = DAMAGE_CLASSES[damage_class_id]
                except IndexError:
                    logger.warning(f"Invalid damage ID {damage_class_id}")
                    continue

                damage_area = np.sum(damage_mask)
                if damage_area < 10:
                    continue  # Skip tiny masks

                max_overlap = 0
                assigned_part_name = "Unknown / Outside Parts"
                for j in range(len(part_masks_np)):
                    part_mask = part_masks_np[j]
                    part_class_id = part_classes_ids[j]
                    try:
                        part_name = CAR_PART_CLASSES[part_class_id]
                    except IndexError:
                        logger.warning(f"Invalid part ID {part_class_id}")
                        continue

                    intersection = np.logical_and(damage_mask, part_mask)
                    overlap_ratio = np.sum(intersection) / damage_area if damage_area > 0 else 0
                    if overlap_ratio > max_overlap:
                        max_overlap = overlap_ratio
                        if max_overlap >= overlap_threshold:
                            assigned_part_name = part_name

                assignment_desc = f"{damage_name} in {assigned_part_name}"
                if assigned_part_name == "Unknown / Outside Parts":
                    assignment_desc += f" (Overlap < {overlap_threshold*100:.0f}%)"
                final_assignments.append(assignment_desc)

        # Handle cases with no damages or no parts found after thresholding
        elif damage_masks_np.shape[0] > 0:
            final_assignments.append(f"{len(damage_masks_np)} damages found, but no parts detected/matched above threshold.")
        elif part_masks_np.shape[0] > 0:
            final_assignments.append("No damages detected above threshold.")
        else:
            final_assignments.append("No damages or parts detected above thresholds.")
        logger.info(f"Assignment results: {final_assignments}")

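        # Worked example of the assignment rule above (toy 1-D masks, illustrative only):
        #   damage_mask  = [1, 1, 1, 1, 0]   -> damage_area = 4
        #   part_mask    = [0, 1, 1, 1, 1]
        #   intersection = [0, 1, 1, 1, 0]   -> 3 overlapping pixels
        #   overlap_ratio = 3 / 4 = 0.75 >= 0.4, so the damage is assigned to that part.
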
        # --- 5. Visualization using the YOLO Annotator ---
        logger.info("Visualizing results...")
        annotator = Annotator(annotated_image_bgr, line_width=2, example=part_model.names)  # Use names from the part model

        # Draw PART masks/boxes (greenish); use the raw masks for drawing coordinates
        if part_result.masks is not None:
            colors_part = [(0, random.randint(100, 200), 0) for _ in part_classes_ids]
            annotator.masks(part_result.masks.data, colors=colors_part, alpha=0.3)
            for box, cls_id in zip(part_boxes_xyxy, part_classes_ids):
                try:
                    annotator.box_label(box, label=f"{CAR_PART_CLASSES[cls_id]}", color=(0, 200, 0))
                except IndexError:
                    continue

        # Draw DAMAGE masks/boxes (reddish); use the raw masks for drawing coordinates
        if damage_result.masks is not None:
            colors_dmg = [(random.randint(100, 200), 0, 0) for _ in damage_classes_ids]
            annotator.masks(damage_result.masks.data, colors=colors_dmg, alpha=0.4)
            for box, cls_id in zip(damage_boxes_xyxy, damage_classes_ids):
                try:
                    annotator.box_label(box, label=f"{DAMAGE_CLASSES[cls_id]}", color=(200, 0, 0))
                except IndexError:
                    continue

        annotated_image_bgr = annotator.result()  # Get the final annotated BGR image

    except Exception as e:
        logger.error(f"Error during combined processing: {e}", exc_info=True)
        traceback.print_exc()
        final_assignments.append("Error during segmentation/processing.")

    # --- Prepare output ---
    assignment_text = "\n".join(final_assignments) if final_assignments else "No specific damage assignments."
    final_output_image_rgb = cv2.cvtColor(annotated_image_bgr, cv2.COLOR_BGR2RGB)  # Convert the final image to RGB

    return final_output_image_rgb, assignment_text

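# A typical assignment_text returned above might look like (illustrative values only):
#   Dent in Front-door
#   Scratch in Hood
#   Corrosion in Unknown / Outside Parts (Overlap < 40%)
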
# --- Main Gradio Function ---
def predict_pipeline(image_np_input):
    """
    Main pipeline: Classify -> Segment -> Assign -> Visualize
    """
    if image_np_input is None:
        return "Please upload an image.", {}, None, "N/A"

    logger.info("--- New Request ---")
    start_time = time.time()
    # Convert the Gradio input (assumed RGB) to BGR for the processing functions
    image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
    image_pil = Image.fromarray(image_np_input)  # PIL image for CLIP (already RGB)

    final_output_image = None
    assignment_text = "Processing..."
    classification_result = "Error"
    probabilities = {}

    # --- Stage 1: CLIP Classification ---
    try:
        classification_result, probabilities = classify_image_clip(image_pil)
    except Exception as e:
        logger.error(f"Error in CLIP stage: {e}", exc_info=True)
        assignment_text = f"Error during classification: {e}"
        # Show the original image in case of a classification error
        final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)

    # --- Stages 2 & 3: Segmentation and Assignment (only if classified as 'Car') ---
    if classification_result == "Car":
        logger.info("Image classified as Car. Running segmentation and assignment...")
        try:
            final_output_image, assignment_text = process_car_image(image_np_bgr)
        except Exception as e:
            logger.error(f"Error in segmentation/assignment stage: {e}", exc_info=True)
            assignment_text = f"Error during segmentation/assignment: {e}"
            # Show the original image in case of a processing error
            final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)

    elif classification_result == "Not Car":
        logger.info("Image classified as Not Car.")
        final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)  # Show the original image
        assignment_text = "Image classified as Not Car."
    # Else: a CLIP error occurred (already logged); make sure an image is still returned
    elif final_output_image is None:
        final_output_image = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)

    # --- Cleanup ---
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    end_time = time.time()
    logger.info(f"Total processing time: {end_time - start_time:.2f} seconds.")
    # Return all results in the order expected by the Gradio outputs list
    return classification_result, probabilities, final_output_image, assignment_text

# --- Gradio Interface ---
logger.info("Setting up Gradio interface...")
title = "🚗 Car Damage Analysis Pipeline (CLIP + YOLOv8 x2)"

# Define inputs and outputs
input_image = gr.Image(type="numpy", label="Upload Car Image")  # Input numpy array (RGB from Gradio)
output_classification = gr.Textbox(label="1. Classification Result")
output_probabilities = gr.Label(label="Classification Probabilities")  # gr.Label is well suited to dicts
output_image_display = gr.Image(type="numpy", label="3. Segmentation Visualization")  # Output numpy array (RGB)
output_assignment = gr.Textbox(label="2. Damage Assignments", lines=5, interactive=False)

# Build the interface
iface = gr.Interface(
    fn=predict_pipeline,
    inputs=input_image,
    outputs=[output_classification, output_probabilities, output_image_display, output_assignment],
    title=title,
    # description=description,  # Add a description back if needed
    # examples=examples,  # Add examples back if needed
    allow_flagging="never"
)

if __name__ == "__main__":
    logger.info("Launching Gradio app...")
    iface.launch()
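    # For a busy public Space, enabling the request queue is a reasonable option,
    # e.g. iface.queue().launch() (a sketch, not part of the original app).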