import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
import os
import cv2
import gc # Garbage collector
import logging
# --- Detectron2 Imports ---
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
# --- Setup Logging ---
# Reduce default Detectron2 logging noise if needed
logging.getLogger("detectron2").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Constants ---
# Damage segmentation classes (MUST match training order)
DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)
# Paths within the Hugging Face Space repository
CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
# CLIP_MODEL_WEIGHTS_PATH = "./clip_model/clip_vit_b16.pth" # Alt: Load state dict
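# NOTE (assumption): clip_text_features should hold two pre-computed, L2-normalized
# CLIP text embeddings in the order [Car, Not Car]. A minimal sketch of how such a
# file could be produced offline (prompt wording here is illustrative, not the original):
#   prompts = clip.tokenize(["a photo of a car", "a photo of something else"])
#   with torch.no_grad():
#       feats = clip_model.encode_text(prompts.to(DEVICE))
#       feats /= feats.norm(dim=-1, keepdim=True)
#   torch.save(feats, CLIP_TEXT_FEATURES_PATH)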
MASKRCNN_MODEL_WEIGHTS_PATH = "./model_best.pth" # Your best Mask R-CNN weights
MASKRCNN_BASE_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
# Prediction Thresholds
DAMAGE_PRED_THRESHOLD = 0.4 # Threshold for showing damage masks
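# Detectron2 drops instances scoring below ROI_HEADS.SCORE_THRESH_TEST at inference
# time, so raising this value trades recall for precision (and vice versa).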
# --- Device Setup ---
if torch.cuda.is_available():
    DEVICE = "cuda"
    logger.info("CUDA available, using GPU.")
else:
    DEVICE = "cpu"
    logger.info("CUDA not available, using CPU.")
# --- MODEL LOADING (Load models globally ONCE on startup) ---
print("Loading models...")
# --- Load CLIP Model ---
try:
    logger.info("Loading CLIP model...")
    clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE)
    # Optional: load a manually saved state dict instead
    # clip_model.load_state_dict(torch.load(CLIP_MODEL_WEIGHTS_PATH, map_location=DEVICE))
    clip_model.eval()
    logger.info("CLIP model loaded.")

    logger.info(f"Loading CLIP text features from {CLIP_TEXT_FEATURES_PATH}...")
    if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
        raise FileNotFoundError(f"CLIP text features not found at {CLIP_TEXT_FEATURES_PATH}. Make sure it's uploaded.")
    clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
    logger.info("CLIP text features loaded.")
except Exception as e:
    logger.error(f"Error loading CLIP model or features: {e}", exc_info=True)
    clip_model = None  # Set to None if loading fails
    clip_text_features = None  # Also define this, so later None-checks don't raise NameError
# --- Load Mask R-CNN Model ---
maskrcnn_predictor = None
maskrcnn_metadata = None
try:
    logger.info("Setting up Mask R-CNN configuration...")
    cfg_mrcnn = get_cfg()
    cfg_mrcnn.merge_from_file(model_zoo.get_config_file(MASKRCNN_BASE_CONFIG))
    # Manual configuration matching the original training setup
    cfg_mrcnn.defrost()
    cfg_mrcnn.MODEL.WEIGHTS = MASKRCNN_MODEL_WEIGHTS_PATH
    if not os.path.exists(MASKRCNN_MODEL_WEIGHTS_PATH):
        raise FileNotFoundError(f"Mask R-CNN weights not found at {MASKRCNN_MODEL_WEIGHTS_PATH}. Make sure it's uploaded.")
    cfg_mrcnn.MODEL.ROI_HEADS.NUM_CLASSES = NUM_DAMAGE_CLASSES
    cfg_mrcnn.MODEL.ROI_HEADS.SCORE_THRESH_TEST = DAMAGE_PRED_THRESHOLD
    cfg_mrcnn.MODEL.DEVICE = DEVICE
    # Apply norm settings that were changed during training
    cfg_mrcnn.MODEL.FPN.NORM = "GN"
    cfg_mrcnn.MODEL.ROI_HEADS.NORM = "GN"
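    # NOTE (assumption): these GroupNorm overrides mirror the original training config;
    # if the checkpoint was actually trained with the default norm layers, the
    # checkpointer will report mismatched keys and these overrides should be dropped.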
    cfg_mrcnn.freeze()
    logger.info("Mask R-CNN configuration loaded.")

    logger.info("Creating Mask R-CNN predictor...")
    maskrcnn_predictor = DefaultPredictor(cfg_mrcnn)
    logger.info("Mask R-CNN predictor created.")

    # Set up metadata for visualization
    metadata_name = "car_damage_inference_app"
    if metadata_name not in MetadataCatalog.list():
        MetadataCatalog.get(metadata_name).set(thing_classes=DAMAGE_CLASSES)
    maskrcnn_metadata = MetadataCatalog.get(metadata_name)
    logger.info("Mask R-CNN metadata prepared.")
except Exception as e:
    logger.error(f"Error setting up Mask R-CNN predictor: {e}", exc_info=True)
    maskrcnn_predictor = None  # Leave as None so the pipeline can degrade gracefully
print("Model loading complete.")
# --- Prediction Functions ---
def classify_image_clip(image_pil):
    """Classifies image using CLIP. Returns label and probabilities."""
    if clip_model is None or clip_text_features is None:
        return "Error: CLIP Model Not Loaded", {"Error": 1.0}
    try:
        # Basic preprocessing (CLIP's preprocess handles resizing and normalization)
        image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # Calculate similarity against the pre-computed text embeddings
            logit_scale = clip_model.logit_scale.exp()
            similarity = (image_features @ clip_text_features.T) * logit_scale
            probs = similarity.softmax(dim=-1).squeeze().cpu()  # Move probs to CPU
        # Get prediction
        # Index 0 = Car, Index 1 = Not Car (based on the text-feature creation order)
        predicted_label = "Car" if probs[0] > probs[1] else "Not Car"
        # gr.Label expects numeric confidences, so cast the tensors to plain floats
        prob_dict = {"Car": float(probs[0]), "Not Car": float(probs[1])}
        return predicted_label, prob_dict
    except Exception as e:
        logger.error(f"Error during CLIP prediction: {e}", exc_info=True)
        return "Error during CLIP processing", {"Error": 1.0}
def segment_damage(image_np_bgr):
    """Segments damage using Mask R-CNN. Returns visualized image (numpy RGB)."""
    if maskrcnn_predictor is None or maskrcnn_metadata is None:
        logger.error("Mask R-CNN predictor or metadata not available.")
        # Return None; the caller falls back to showing the original image.
        return None
    try:
        logger.info("Running Mask R-CNN inference...")
        outputs = maskrcnn_predictor(image_np_bgr)  # Predictor expects a BGR numpy array
        predictions = outputs["instances"].to("cpu")
        logger.info(f"Mask R-CNN detected {len(predictions)} instances.")
        # Visualize
        v = Visualizer(image_np_bgr[:, :, ::-1],  # Convert BGR to RGB for the Visualizer
                       metadata=maskrcnn_metadata,
                       scale=0.8,
                       instance_mode=ColorMode.SEGMENTATION)
        # Draw predictions only if any exist
        if len(predictions) > 0:
            out = v.draw_instance_predictions(predictions)
            output_image_np_rgb = out.get_image()  # Visualizer returns RGB
        else:
            # If no detections, return the original image (converted to RGB)
            logger.info("No damage instances detected above threshold.")
            output_image_np_rgb = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
        return output_image_np_rgb
    except Exception as e:
        logger.error(f"Error during Mask R-CNN prediction/visualization: {e}", exc_info=True)
        # Fall back to the original image (converted to RGB) on error
        return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
# --- Main Gradio Function ---
def predict_pipeline(image_np_input):
    """
    Takes an RGB numpy image from Gradio, runs CLIP, then optionally Mask R-CNN.
    Returns: classification text, probability dict, output image (numpy RGB)
    """
    if image_np_input is None:
        return "Please upload an image.", {}, None
    logger.info("Received image for processing...")

    # --- Stage 1: CLIP Classification ---
    # Gradio's numpy images are RGB, which is what PIL expects, so no conversion needed
    image_pil = Image.fromarray(image_np_input)
    classification_result, probabilities = classify_image_clip(image_pil)
    logger.info(f"CLIP Result: {classification_result}, Probs: {probabilities}")

    output_image = None  # Initialize output image

    # --- Stage 2: Damage Segmentation (if classified as 'Car') ---
    if classification_result == "Car":
        logger.info("Image classified as Car. Proceeding to damage segmentation...")
        # Detectron2's DefaultPredictor expects BGR input, so convert before segmenting
        image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
        output_image = segment_damage(image_np_bgr)
        if output_image is None:  # Handle potential error in segmentation
            logger.warning("Damage segmentation returned None. Displaying original image.")
            output_image = image_np_input
    elif classification_result == "Not Car":
        logger.info("Image classified as Not Car. Skipping damage segmentation.")
        # Show the original image if it's not a car
        output_image = image_np_input
    else:  # Handle CLIP error case
        logger.error("CLIP classification failed.")
        output_image = image_np_input

    # --- Cleanup ---
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return classification_result, probabilities, output_image
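
# Optional local smoke test (assumption: a file named test.jpg sits next to app.py):
#   img = np.array(Image.open("test.jpg").convert("RGB"))
#   label, probs, vis = predict_pipeline(img)
#   print(label, probs)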
# --- Gradio Interface ---
logger.info("Setting up Gradio interface...")
title = "Car Damage Segmentation Pipeline"
description = """
Upload an image.
1. The first model (CLIP) classifies if it's a car.
2. If it's a car, the second model (Mask R-CNN) segments potential damages.
"""
examples = [
    # Add paths to example images if you upload them to the repo
    # ["./example_car_damaged.jpg"],
    # ["./example_car_ok.jpg"],
    # ["./example_not_car.jpg"],
]
# Define Inputs and Outputs
input_image = gr.Image(type="numpy", label="Upload Car Image")
output_classification = gr.Textbox(label="Classification Result")
output_probabilities = gr.Label(label="Class Probabilities") # Label is good for dicts
output_segmentation = gr.Image(type="numpy", label="Damage Segmentation / Original Image")
# Launch the interface
iface = gr.Interface(
    fn=predict_pipeline,
    inputs=input_image,
    outputs=[output_classification, output_probabilities, output_segmentation],
    title=title,
    description=description,
    examples=examples,
    allow_flagging="never"  # Disable flagging unless needed
)
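# NOTE (assumption): on a CPU-only Space, inference may take several seconds per image;
# enabling Gradio's request queue can help under concurrent load, e.g.:
#   iface = iface.queue()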
if __name__ == "__main__":
    logger.info("Launching Gradio app...")
    iface.launch()  # share=True creates a public link (use with caution)