import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
import os
import cv2
import gc  # Garbage collector
import logging

# --- Detectron2 Imports ---
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog

# --- Setup Logging ---
# Reduce default Detectron2 logging noise if needed
logging.getLogger("detectron2").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Constants ---
# Damage segmentation classes (MUST match training order)
DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)

# Paths within the Hugging Face Space repository
CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
# CLIP_MODEL_WEIGHTS_PATH = "./clip_model/clip_vit_b16.pth"  # Alt: load a saved state dict
MASKRCNN_MODEL_WEIGHTS_PATH = "./model_best.pth"  # Your best Mask R-CNN weights
MASKRCNN_BASE_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"

# Prediction thresholds
DAMAGE_PRED_THRESHOLD = 0.4  # Minimum score for a damage mask to be kept
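# Note: this threshold is applied inside Detectron2 at test time via
# MODEL.ROI_HEADS.SCORE_THRESH_TEST (set below). Raising it keeps fewer,
# higher-confidence masks; lowering it surfaces more tentative detections.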
# --- Device Setup ---
if torch.cuda.is_available():
    DEVICE = "cuda"
    logger.info("CUDA available, using GPU.")
else:
    DEVICE = "cpu"
    logger.info("CUDA not available, using CPU.")
# --- MODEL LOADING (load models globally ONCE on startup) ---
print("Loading models...")

# --- Load CLIP Model ---
# Pre-initialize so downstream None-checks work even if loading fails partway
clip_model = None
clip_preprocess = None
clip_text_features = None
try:
    logger.info("Loading CLIP model...")
    clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE)
    # Optional: load a state dict if you saved one manually
    # clip_model.load_state_dict(torch.load(CLIP_MODEL_WEIGHTS_PATH, map_location=DEVICE))
    clip_model.eval()
    logger.info("CLIP model loaded.")

    logger.info(f"Loading CLIP text features from {CLIP_TEXT_FEATURES_PATH}...")
    if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
        raise FileNotFoundError(f"CLIP text features not found at {CLIP_TEXT_FEATURES_PATH}. Make sure it's uploaded.")
    clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
    logger.info("CLIP text features loaded.")
except Exception as e:
    logger.error(f"Error loading CLIP model or features: {e}", exc_info=True)
    clip_model = None  # Leave everything None so prediction functions fail gracefully
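# For reference, a sketch of how clip_text_features.pt is presumably generated
# offline (the exact prompts are an assumption -- they must match the uploaded
# file, with index 0 corresponding to "Car" and index 1 to "Not Car"):
#
#   prompts = clip.tokenize(["a photo of a car", "a photo of something else"]).to(DEVICE)
#   with torch.no_grad():
#       text_features = clip_model.encode_text(prompts)
#       text_features /= text_features.norm(dim=-1, keepdim=True)
#   torch.save(text_features, CLIP_TEXT_FEATURES_PATH)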
# --- Load Mask R-CNN Model ---
maskrcnn_predictor = None
maskrcnn_metadata = None
try:
    logger.info("Setting up Mask R-CNN configuration...")
    cfg_mrcnn = get_cfg()
    cfg_mrcnn.merge_from_file(model_zoo.get_config_file(MASKRCNN_BASE_CONFIG))
    # Manual configuration based on the training setup
    cfg_mrcnn.defrost()
    cfg_mrcnn.MODEL.WEIGHTS = MASKRCNN_MODEL_WEIGHTS_PATH
    if not os.path.exists(MASKRCNN_MODEL_WEIGHTS_PATH):
        raise FileNotFoundError(f"Mask R-CNN weights not found at {MASKRCNN_MODEL_WEIGHTS_PATH}. Make sure it's uploaded.")
    cfg_mrcnn.MODEL.ROI_HEADS.NUM_CLASSES = NUM_DAMAGE_CLASSES
    cfg_mrcnn.MODEL.ROI_HEADS.SCORE_THRESH_TEST = DAMAGE_PRED_THRESHOLD
    cfg_mrcnn.MODEL.DEVICE = DEVICE
    # Norm settings from training. Note: MODEL.ROI_HEADS.NORM is not a key in
    # Detectron2's default config (setting it is a silent no-op); the per-head keys
    # below are the ones that take effect. GN is assumed here -- it must match the
    # training config or the checkpoint will not load cleanly.
    cfg_mrcnn.MODEL.FPN.NORM = "GN"
    cfg_mrcnn.MODEL.ROI_BOX_HEAD.NORM = "GN"
    cfg_mrcnn.MODEL.ROI_MASK_HEAD.NORM = "GN"
    cfg_mrcnn.freeze()
    logger.info("Mask R-CNN configuration loaded.")

    logger.info("Creating Mask R-CNN predictor...")
    maskrcnn_predictor = DefaultPredictor(cfg_mrcnn)
    logger.info("Mask R-CNN predictor created.")

    # Set up metadata for visualization
    metadata_name = "car_damage_inference_app"
    if metadata_name not in MetadataCatalog.list():
        MetadataCatalog.get(metadata_name).set(thing_classes=DAMAGE_CLASSES)
    maskrcnn_metadata = MetadataCatalog.get(metadata_name)
    logger.info("Mask R-CNN metadata prepared.")
except Exception as e:
    logger.error(f"Error setting up Mask R-CNN predictor: {e}", exc_info=True)
    maskrcnn_predictor = None  # Predictor stays None if setup fails

print("Model loading complete.")
# --- Prediction Functions ---
def classify_image_clip(image_pil):
    """Classifies an image with CLIP. Returns a label and class probabilities."""
    if clip_model is None or clip_text_features is None:
        return "Error: CLIP Model Not Loaded", {"Error": 1.0}
    try:
        # clip_preprocess handles resizing and normalization
        image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # Scaled cosine similarity against the precomputed text features
            logit_scale = clip_model.logit_scale.exp()
            similarity = (image_features @ clip_text_features.T) * logit_scale
            probs = similarity.softmax(dim=-1).squeeze().cpu()  # Move probs to CPU
        # Index 0 = Car, Index 1 = Not Car (matches the text feature order)
        predicted_label = "Car" if probs[0] > probs[1] else "Not Car"
        # gr.Label expects float confidences, not preformatted strings
        prob_dict = {"Car": float(probs[0]), "Not Car": float(probs[1])}
        return predicted_label, prob_dict
    except Exception as e:
        logger.error(f"Error during CLIP prediction: {e}", exc_info=True)
        return "Error during CLIP processing", {"Error": 1.0}
def segment_damage(image_np_bgr):
    """Segments damage with Mask R-CNN. Returns the visualized image (RGB numpy)."""
    if maskrcnn_predictor is None or maskrcnn_metadata is None:
        logger.error("Mask R-CNN predictor or metadata not available.")
        # Return None and let the caller fall back to the original image
        return None
    try:
        logger.info("Running Mask R-CNN inference...")
        outputs = maskrcnn_predictor(image_np_bgr)  # Predictor expects a BGR numpy array
        predictions = outputs["instances"].to("cpu")
        logger.info(f"Mask R-CNN detected {len(predictions)} instances.")
        # Visualize
        v = Visualizer(image_np_bgr[:, :, ::-1],  # Convert BGR to RGB for the Visualizer
                       metadata=maskrcnn_metadata,
                       scale=0.8,
                       instance_mode=ColorMode.SEGMENTATION)
        # Draw predictions only if any exist
        if len(predictions) > 0:
            out = v.draw_instance_predictions(predictions)
            output_image_np_rgb = out.get_image()  # Visualizer returns RGB
        else:
            # No detections above threshold: return the original image (as RGB)
            logger.info("No damage instances detected above threshold.")
            output_image_np_rgb = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
        return output_image_np_rgb
    except Exception as e:
        logger.error(f"Error during Mask R-CNN prediction/visualization: {e}", exc_info=True)
        # Fall back to the original image on error
        return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
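# If a textual damage summary is wanted alongside the overlay, the detections can
# be read off the Instances object directly -- a minimal sketch:
#
#   classes = predictions.pred_classes.tolist()
#   scores = predictions.scores.tolist()
#   summary = [f"{DAMAGE_CLASSES[c]}: {s:.2f}" for c, s in zip(classes, scores)]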
# --- Main Gradio Function ---
def predict_pipeline(image_np_input):
    """
    Takes a numpy image, runs CLIP, then optionally Mask R-CNN.
    Returns: classification text, probability dict, output image (numpy RGB).
    """
    if image_np_input is None:
        return "Please upload an image.", {}, None
    logger.info("Received image for processing...")

    # --- Stage 1: CLIP Classification ---
    # gr.Image(type="numpy") delivers an RGB array, so it can go straight to PIL
    image_pil = Image.fromarray(image_np_input)
    classification_result, probabilities = classify_image_clip(image_pil)
    logger.info(f"CLIP Result: {classification_result}, Probs: {probabilities}")

    output_image = None  # Initialize output image

    # --- Stage 2: Damage Segmentation (if classified as 'Car') ---
    if classification_result == "Car":
        logger.info("Image classified as Car. Proceeding to damage segmentation...")
        # Detectron2's predictor expects BGR, so convert before segmenting
        image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
        output_image = segment_damage(image_np_bgr)
        if output_image is None:  # Segmentation failed; fall back to the original
            logger.warning("Damage segmentation returned None. Displaying original image.")
            output_image = image_np_input
    elif classification_result == "Not Car":
        logger.info("Image classified as Not Car. Skipping damage segmentation.")
        # Show the original image if it's not a car
        output_image = image_np_input
    else:  # CLIP error case
        logger.error("CLIP classification failed.")
        output_image = image_np_input

    # --- Cleanup ---
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return classification_result, probabilities, output_image
# --- Gradio Interface ---
logger.info("Setting up Gradio interface...")
title = "Car Damage Segmentation Pipeline"
description = """
Upload an image.
1. The first model (CLIP) classifies whether it shows a car.
2. If it does, the second model (Mask R-CNN) segments potential damage.
"""
examples = [
    # Add paths to example images if you upload them to the repo
    # ["./example_car_damaged.jpg"],
    # ["./example_car_ok.jpg"],
    # ["./example_not_car.jpg"],
]

# Define inputs and outputs
input_image = gr.Image(type="numpy", label="Upload Car Image")
output_classification = gr.Textbox(label="Classification Result")
output_probabilities = gr.Label(label="Class Probabilities")  # gr.Label renders a confidence dict nicely
output_segmentation = gr.Image(type="numpy", label="Damage Segmentation / Original Image")
# Build the interface
iface = gr.Interface(
    fn=predict_pipeline,
    inputs=input_image,
    outputs=[output_classification, output_probabilities, output_segmentation],
    title=title,
    description=description,
    examples=examples or None,  # Pass None rather than an empty list
    allow_flagging="never",  # Disable flagging unless needed
)
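# On a CPU-only Space, each request can take a while; Gradio's built-in request
# queue (iface.queue() before launch) is an optional way to keep the UI responsive
# under concurrent use.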
if __name__ == "__main__":
    logger.info("Launching Gradio app...")
    iface.launch()  # share=True creates a public link (use with caution)