import gc  # Garbage collector
import logging
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

import clip

# --- Detectron2 Imports ---
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer

# --- Setup Logging ---
# Reduce default Detectron2 logging noise if needed
logging.getLogger("detectron2").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Constants ---
# Damage segmentation classes (MUST match training order)
DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part',
                  'Corrosion', 'Dent', 'Paint chip', 'Missing part']
NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)

# Paths within the Hugging Face Space repository
CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
# CLIP_MODEL_WEIGHTS_PATH = "./clip_model/clip_vit_b16.pth"  # Alt: load a saved state dict
MASKRCNN_MODEL_WEIGHTS_PATH = "./model_best.pth"  # Your best Mask R-CNN weights
MASKRCNN_BASE_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"

# Prediction Thresholds
DAMAGE_PRED_THRESHOLD = 0.4  # Minimum score for showing damage masks

# --- Device Setup ---
if torch.cuda.is_available():
    DEVICE = "cuda"
    logger.info("CUDA available, using GPU.")
else:
    DEVICE = "cpu"
    logger.info("CUDA not available, using CPU.")

# --- MODEL LOADING (load models globally ONCE on startup) ---
print("Loading models...")

# --- Load CLIP Model ---
clip_model = None
clip_preprocess = None
clip_text_features = None  # Initialize so a failed load cannot leave this name undefined
try:
    logger.info("Loading CLIP model...")
    clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE)
    # Optional: load a state dict if you saved one manually
    # clip_model.load_state_dict(torch.load(CLIP_MODEL_WEIGHTS_PATH, map_location=DEVICE))
    clip_model.eval()
    logger.info("CLIP model loaded.")

    logger.info(f"Loading CLIP text features from {CLIP_TEXT_FEATURES_PATH}...")
    if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
        raise FileNotFoundError(
            f"CLIP text features not found at {CLIP_TEXT_FEATURES_PATH}. "
            "Make sure it's uploaded."
        )
    clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
    logger.info("CLIP text features loaded.")
except Exception as e:
    logger.error(f"Error loading CLIP model or features: {e}", exc_info=True)
    clip_model = None  # Leave CLIP globals as None so prediction can bail out cleanly
    clip_text_features = None
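
# The text features loaded above were precomputed offline. For reference, here
# is a minimal sketch of how such a file could be produced. The prompt wording
# is an assumption (it is not stored in this repo); whatever prompts are used,
# index 0 must be the "Car" prompt and index 1 the "Not Car" prompt, and the
# features must be L2-normalized, to match classify_image_clip() below.
def build_clip_text_features(save_path=CLIP_TEXT_FEATURES_PATH):
    """Offline utility (not called at startup): encode class prompts and save features."""
    if clip_model is None:
        raise RuntimeError("CLIP model is not loaded.")
    prompts = ["a photo of a car", "a photo of something that is not a car"]  # assumed wording
    tokens = clip.tokenize(prompts).to(DEVICE)
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # Normalize once, at save time
    torch.save(text_features.cpu(), save_path)
    return text_features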

# --- Load Mask R-CNN Model ---
maskrcnn_predictor = None
maskrcnn_metadata = None
try:
    logger.info("Setting up Mask R-CNN configuration...")
    cfg_mrcnn = get_cfg()
    cfg_mrcnn.merge_from_file(model_zoo.get_config_file(MASKRCNN_BASE_CONFIG))

    # Manual configuration matching the previous working training setup
    cfg_mrcnn.defrost()
    if not os.path.exists(MASKRCNN_MODEL_WEIGHTS_PATH):
        raise FileNotFoundError(
            f"Mask R-CNN weights not found at {MASKRCNN_MODEL_WEIGHTS_PATH}. "
            "Make sure it's uploaded."
        )
    cfg_mrcnn.MODEL.WEIGHTS = MASKRCNN_MODEL_WEIGHTS_PATH
    cfg_mrcnn.MODEL.ROI_HEADS.NUM_CLASSES = NUM_DAMAGE_CLASSES
    cfg_mrcnn.MODEL.ROI_HEADS.SCORE_THRESH_TEST = DAMAGE_PRED_THRESHOLD
    cfg_mrcnn.MODEL.DEVICE = DEVICE
    # Apply the norm settings used during training; these must mirror the
    # training config exactly. NOTE: stock Detectron2 exposes per-head keys
    # (MODEL.ROI_BOX_HEAD.NORM / MODEL.ROI_MASK_HEAD.NORM); ROI_HEADS.NORM is
    # not a standard key and may be a no-op, so verify against your training config.
    cfg_mrcnn.MODEL.FPN.NORM = "GN"
    cfg_mrcnn.MODEL.ROI_HEADS.NORM = "GN"
    cfg_mrcnn.freeze()
    logger.info("Mask R-CNN configuration loaded.")

    logger.info("Creating Mask R-CNN predictor...")
    maskrcnn_predictor = DefaultPredictor(cfg_mrcnn)
    logger.info("Mask R-CNN predictor created.")

    # Set up metadata for visualization
    metadata_name = "car_damage_inference_app"
    if metadata_name not in MetadataCatalog.list():
        MetadataCatalog.get(metadata_name).set(thing_classes=DAMAGE_CLASSES)
    maskrcnn_metadata = MetadataCatalog.get(metadata_name)
    logger.info("Mask R-CNN metadata prepared.")
except Exception as e:
    logger.error(f"Error setting up Mask R-CNN predictor: {e}", exc_info=True)
    maskrcnn_predictor = None  # Leave as None so the pipeline can skip segmentation

print("Model loading complete.")

# --- Prediction Functions ---
def classify_image_clip(image_pil):
    """Classifies an image with CLIP. Returns a label and a probability dict."""
    if clip_model is None or clip_text_features is None:
        return "Error: CLIP Model Not Loaded", {"Error": 1.0}
    try:
        # clip_preprocess handles resizing, cropping and normalization
        image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # Cosine similarity scaled by CLIP's learned temperature
            logit_scale = clip_model.logit_scale.exp()
            similarity = (image_features @ clip_text_features.T) * logit_scale
            probs = similarity.softmax(dim=-1).squeeze().cpu()  # Move probs to CPU

        # Index 0 = Car, Index 1 = Not Car (must match the text feature order)
        predicted_label = "Car" if probs[0] > probs[1] else "Not Car"
        # gr.Label expects float confidences, not preformatted strings
        prob_dict = {"Car": float(probs[0]), "Not Car": float(probs[1])}
        return predicted_label, prob_dict
    except Exception as e:
        logger.error(f"Error during CLIP prediction: {e}", exc_info=True)
        return "Error during CLIP processing", {"Error": 1.0}

def segment_damage(image_np_bgr):
    """Segments damage with Mask R-CNN. Returns a visualized RGB image."""
    if maskrcnn_predictor is None or maskrcnn_metadata is None:
        logger.error("Mask R-CNN predictor or metadata not available.")
        # Return None and let the caller fall back to the original image
        return None
    try:
        logger.info("Running Mask R-CNN inference...")
        outputs = maskrcnn_predictor(image_np_bgr)  # DefaultPredictor expects a BGR numpy array
        predictions = outputs["instances"].to("cpu")
        logger.info(f"Mask R-CNN detected {len(predictions)} instances.")

        # Visualize
        v = Visualizer(image_np_bgr[:, :, ::-1],  # Convert BGR to RGB for the Visualizer
                       metadata=maskrcnn_metadata,
                       scale=0.8,
                       instance_mode=ColorMode.SEGMENTATION)
        # Draw predictions only if any exist
        if len(predictions) > 0:
            out = v.draw_instance_predictions(predictions)
            output_image_np_rgb = out.get_image()  # Visualizer returns RGB
        else:
            # If nothing was detected, return the original image (converted to RGB)
            logger.info("No damage instances detected above threshold.")
            output_image_np_rgb = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
        return output_image_np_rgb
    except Exception as e:
        logger.error(f"Error during Mask R-CNN prediction/visualization: {e}", exc_info=True)
        # Fall back to the original image on error
        return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
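
# Optional helper, not wired into the UI: a sketch of how per-instance labels
# and scores could be read out of the Detectron2 `Instances` object, e.g. to
# log a textual damage summary alongside the mask overlay.
def summarize_instances(predictions):
    """Return strings like 'Scratch: 0.87' for each instance (expects CPU Instances)."""
    classes = predictions.pred_classes.tolist()  # Class indices into DAMAGE_CLASSES
    scores = predictions.scores.tolist()         # Confidence scores in [0, 1]
    return [f"{DAMAGE_CLASSES[c]}: {s:.2f}" for c, s in zip(classes, scores)]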

# --- Main Gradio Function ---
def predict_pipeline(image_np_input):
    """
    Takes an RGB numpy image, runs CLIP, then optionally Mask R-CNN.
    Returns: classification text, probability dict, output image (numpy RGB).
    """
    if image_np_input is None:
        return "Please upload an image.", {}, None

    logger.info("Received image for processing...")

    # --- Stage 1: CLIP Classification ---
    # gr.Image(type="numpy") delivers an RGB array, so it can go straight to PIL
    image_pil = Image.fromarray(image_np_input)
    classification_result, probabilities = classify_image_clip(image_pil)
    logger.info(f"CLIP Result: {classification_result}, Probs: {probabilities}")

    output_image = None  # Initialize output image

    # --- Stage 2: Damage Segmentation (only if classified as 'Car') ---
    if classification_result == "Car":
        logger.info("Image classified as Car. Proceeding to damage segmentation...")
        # Detectron2's DefaultPredictor expects BGR, so convert before segmenting
        image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
        output_image = segment_damage(image_np_bgr)
        if output_image is None:  # Handle a failure inside segmentation
            logger.warning("Damage segmentation returned None. Displaying original image.")
            output_image = image_np_input
    elif classification_result == "Not Car":
        logger.info("Image classified as Not Car. Skipping damage segmentation.")
        output_image = image_np_input  # Show the original image unchanged
    else:
        # Handle the CLIP error case
        logger.error("CLIP classification failed.")
        output_image = image_np_input

    # --- Cleanup ---
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return classification_result, probabilities, output_image

# --- Gradio Interface ---
logger.info("Setting up Gradio interface...")

title = "Car Damage Segmentation Pipeline"
description = """
Upload an image.
1. The first model (CLIP) classifies whether the image shows a car.
2. If it does, the second model (Mask R-CNN) segments potential damage.
"""
examples = [
    # Add paths to example images if you upload them to the repo:
    # ["./example_car_damaged.jpg"],
    # ["./example_car_ok.jpg"],
    # ["./example_not_car.jpg"],
]

# Define inputs and outputs
input_image = gr.Image(type="numpy", label="Upload Car Image")
output_classification = gr.Textbox(label="Classification Result")
output_probabilities = gr.Label(label="Class Probabilities")  # gr.Label renders dicts nicely
output_segmentation = gr.Image(type="numpy", label="Damage Segmentation / Original Image")

iface = gr.Interface(
    fn=predict_pipeline,
    inputs=input_image,
    outputs=[output_classification, output_probabilities, output_segmentation],
    title=title,
    description=description,
    examples=examples if examples else None,  # Pass None when no examples are provided
    allow_flagging="never",  # Disable flagging unless needed
)

if __name__ == "__main__":
    logger.info("Launching Gradio app...")
    iface.launch()  # share=True creates a public link (use with caution)
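
# A minimal sketch of a programmatic smoke test (the image path below is a
# placeholder, not a file shipped with this repo); handy for exercising the
# pipeline from a Python shell without the Gradio UI:
#
#   img_rgb = np.array(Image.open("./example_car_damaged.jpg").convert("RGB"))
#   label, probs, vis = predict_pipeline(img_rgb)
#   print(label, probs, None if vis is None else vis.shape)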