import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
import os
import cv2
import gc # Garbage collector
import logging
# --- Detectron2 Imports ---
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
# --- Setup Logging ---
# Reduce default Detectron2 logging noise if needed
logging.getLogger("detectron2").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Constants ---
# Damage segmentation classes (MUST match training order)
DAMAGE_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
NUM_DAMAGE_CLASSES = len(DAMAGE_CLASSES)
# Paths within the Hugging Face Space repository
CLIP_TEXT_FEATURES_PATH = "./clip_text_features.pt"
# CLIP_MODEL_WEIGHTS_PATH = "./clip_model/clip_vit_b16.pth" # Alt: Load state dict
MASKRCNN_MODEL_WEIGHTS_PATH = "./model_best.pth" # Your best Mask R-CNN weights
MASKRCNN_BASE_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
# Prediction Thresholds
DAMAGE_PRED_THRESHOLD = 0.4 # Threshold for showing damage masks
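# Note: this threshold is applied below as cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST,
# so instances scoring under 0.4 are dropped inside the predictor rather than filtered here.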
# --- Device Setup ---
if torch.cuda.is_available():
DEVICE = "cuda"
logger.info("CUDA available, using GPU.")
else:
DEVICE = "cpu"
logger.info("CUDA not available, using CPU.")
# --- MODEL LOADING (Load models globally ONCE on startup) ---
print("Loading models...")
# --- Load CLIP Model ---
try:
logger.info("Loading CLIP model...")
clip_model, clip_preprocess = clip.load("ViT-B/16", device=DEVICE)
# Optional: Load state dict if you saved it manually
# clip_model.load_state_dict(torch.load(CLIP_MODEL_WEIGHTS_PATH, map_location=DEVICE))
clip_model.eval()
logger.info("CLIP model loaded.")
logger.info(f"Loading CLIP text features from {CLIP_TEXT_FEATURES_PATH}...")
if not os.path.exists(CLIP_TEXT_FEATURES_PATH):
raise FileNotFoundError(f"CLIP text features not found at {CLIP_TEXT_FEATURES_PATH}. Make sure it's uploaded.")
clip_text_features = torch.load(CLIP_TEXT_FEATURES_PATH, map_location=DEVICE)
logger.info("CLIP text features loaded.")
except Exception as e:
    logger.error(f"Error loading CLIP model or features: {e}", exc_info=True)
    clip_model = None  # Signal failure to downstream code
    clip_text_features = None  # Must also be defined, or the None-check in classify_image_clip raises NameError
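# A minimal sketch of how clip_text_features.pt could have been generated offline
# (assumption: two prompts in the order Car, Not Car, encoded with the same ViT-B/16
# model and L2-normalized; the actual prompts used for this file may differ):
#
#   prompts = ["a photo of a car", "a photo of something that is not a car"]  # hypothetical prompts
#   with torch.no_grad():
#       tokens = clip.tokenize(prompts).to(DEVICE)
#       text_features = clip_model.encode_text(tokens)
#       text_features /= text_features.norm(dim=-1, keepdim=True)
#   torch.save(text_features, CLIP_TEXT_FEATURES_PATH)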
# --- Load Mask R-CNN Model ---
maskrcnn_predictor = None
maskrcnn_metadata = None
try:
logger.info("Setting up Mask R-CNN configuration...")
cfg_mrcnn = get_cfg()
cfg_mrcnn.merge_from_file(model_zoo.get_config_file(MASKRCNN_BASE_CONFIG))
# Manual configuration based on your previous working setup
cfg_mrcnn.defrost()
cfg_mrcnn.MODEL.WEIGHTS = MASKRCNN_MODEL_WEIGHTS_PATH
if not os.path.exists(MASKRCNN_MODEL_WEIGHTS_PATH):
raise FileNotFoundError(f"Mask R-CNN weights not found at {MASKRCNN_MODEL_WEIGHTS_PATH}. Make sure it's uploaded.")
cfg_mrcnn.MODEL.ROI_HEADS.NUM_CLASSES = NUM_DAMAGE_CLASSES
cfg_mrcnn.MODEL.ROI_HEADS.SCORE_THRESH_TEST = DAMAGE_PRED_THRESHOLD
cfg_mrcnn.MODEL.DEVICE = DEVICE
    # Norm settings must match training: weights trained with GroupNorm ("GN") will not load cleanly into the default BatchNorm/FrozenBN layers
cfg_mrcnn.MODEL.FPN.NORM = "GN"
cfg_mrcnn.MODEL.ROI_HEADS.NORM = "GN"
cfg_mrcnn.freeze()
logger.info("Mask R-CNN configuration loaded.")
logger.info("Creating Mask R-CNN predictor...")
maskrcnn_predictor = DefaultPredictor(cfg_mrcnn)
logger.info("Mask R-CNN predictor created.")
# Setup metadata for visualization
metadata_name = "car_damage_inference_app"
if metadata_name not in MetadataCatalog.list():
MetadataCatalog.get(metadata_name).set(thing_classes=DAMAGE_CLASSES)
maskrcnn_metadata = MetadataCatalog.get(metadata_name)
logger.info("Mask R-CNN metadata prepared.")
except Exception as e:
logger.error(f"Error setting up Mask R-CNN predictor: {e}", exc_info=True)
maskrcnn_predictor = None # Set to None if loading fails
print("Model loading complete.")
# --- Prediction Functions ---
def classify_image_clip(image_pil):
"""Classifies image using CLIP. Returns label and probabilities."""
if clip_model is None or clip_text_features is None:
return "Error: CLIP Model Not Loaded", {"Error": 1.0}
try:
# Basic preprocessing (CLIP handles resizing)
image_input = clip_preprocess(image_pil).unsqueeze(0).to(DEVICE)
with torch.no_grad():
image_features = clip_model.encode_image(image_input)
image_features /= image_features.norm(dim=-1, keepdim=True)
# Calculate similarity
logit_scale = clip_model.logit_scale.exp()
similarity = (image_features @ clip_text_features.T) * logit_scale
probs = similarity.softmax(dim=-1).squeeze().cpu() # Move probs to CPU
        # Get prediction
        # Index 0 = Car, Index 1 = Not Car (must match the order used when the text features were built)
        predicted_label = "Car" if probs[0] > probs[1] else "Not Car"
        # gr.Label expects float confidences, not pre-formatted strings
        prob_dict = {"Car": probs[0].item(), "Not Car": probs[1].item()}
return predicted_label, prob_dict
except Exception as e:
logger.error(f"Error during CLIP prediction: {e}", exc_info=True)
return "Error during CLIP processing", {"Error": 1.0}
def segment_damage(image_np_bgr):
"""Segments damage using Mask R-CNN. Returns visualized image."""
    if maskrcnn_predictor is None or maskrcnn_metadata is None:
        logger.error("Mask R-CNN predictor or metadata not available.")
        # Return None; the caller (predict_pipeline) falls back to showing the original image
        return None
try:
logger.info("Running Mask R-CNN inference...")
outputs = maskrcnn_predictor(image_np_bgr) # Predictor expects BGR numpy array
predictions = outputs["instances"].to("cpu")
logger.info(f"Mask R-CNN detected {len(predictions)} instances.")
# Visualize
v = Visualizer(image_np_bgr[:, :, ::-1], # Convert BGR to RGB for Visualizer
metadata=maskrcnn_metadata,
scale=0.8,
instance_mode=ColorMode.SEGMENTATION)
# Draw predictions only if any exist
if len(predictions) > 0:
out = v.draw_instance_predictions(predictions)
output_image_np_rgb = out.get_image() # Visualizer gives RGB
else:
# If no detections, return the original image (converted to RGB)
logger.info("No damage instances detected above threshold.")
output_image_np_rgb = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
return output_image_np_rgb
except Exception as e:
logger.error(f"Error during Mask R-CNN prediction/visualization: {e}", exc_info=True)
        # Fall back to the original image (converted to RGB) so the UI still shows something
        return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
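# Example usage (hedged; OpenCV loads images as BGR, which is what segment_damage expects):
#   bgr = cv2.imread("damaged_car.jpg")  # hypothetical local file
#   vis_rgb = segment_damage(bgr)        # RGB numpy array with masks drawn, or None on setup failure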
# --- Main Gradio Function ---
def predict_pipeline(image_np_input):
"""
Takes numpy image input, runs CLIP, then optionally Mask R-CNN.
Returns: classification text, probability dict, output image (numpy RGB)
"""
if image_np_input is None:
return "Please upload an image.", {}, None
logger.info("Received image for processing...")
    # --- Stage 1: CLIP Classification ---
    # Gradio's Image component (type="numpy") delivers an RGB array, so it can go straight to PIL
    image_pil = Image.fromarray(image_np_input)
classification_result, probabilities = classify_image_clip(image_pil)
logger.info(f"CLIP Result: {classification_result}, Probs: {probabilities}")
output_image = None # Initialize output image
    # --- Stage 2: Damage Segmentation (if classified as 'Car') ---
    if classification_result == "Car":
        logger.info("Image classified as Car. Proceeding to damage segmentation...")
        # Detectron2's DefaultPredictor expects BGR by default, so convert from Gradio's RGB
        image_np_bgr = cv2.cvtColor(image_np_input, cv2.COLOR_RGB2BGR)
        output_image = segment_damage(image_np_bgr)
        if output_image is None:  # Handle a failure inside segmentation
            logger.warning("Damage segmentation returned None. Displaying original image.")
            output_image = image_np_input
    elif classification_result == "Not Car":
        logger.info("Image classified as Not Car. Skipping damage segmentation.")
        # Show the original image unchanged if it is not a car
        output_image = image_np_input
    else:  # CLIP failed or returned an error string
        logger.error("CLIP classification failed.")
        output_image = image_np_input
# --- Cleanup ---
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return classification_result, probabilities, output_image
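# Example (hedged): driving the pipeline programmatically with an RGB numpy array,
# mirroring what Gradio's Image component delivers:
#   img_rgb = np.array(Image.open("car.jpg").convert("RGB"))  # hypothetical file
#   label, probs, vis = predict_pipeline(img_rgb)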
# --- Gradio Interface ---
logger.info("Setting up Gradio interface...")
title = "Car Damage Segmentation Pipeline"
description = """
Upload an image.
1. The first model (CLIP) classifies if it's a car.
2. If it's a car, the second model (Mask R-CNN) segments potential damages.
"""
examples = [
# Add paths to example images if you upload them to the repo
# ["./example_car_damaged.jpg"],
# ["./example_car_ok.jpg"],
# ["./example_not_car.jpg"],
]
# Define Inputs and Outputs
input_image = gr.Image(type="numpy", label="Upload Car Image")  # delivers an RGB numpy array
output_classification = gr.Textbox(label="Classification Result")
output_probabilities = gr.Label(label="Class Probabilities")  # gr.Label renders {label: float confidence} dicts
output_segmentation = gr.Image(type="numpy", label="Damage Segmentation / Original Image")
# Launch the interface
iface = gr.Interface(
fn=predict_pipeline,
inputs=input_image,
outputs=[output_classification, output_probabilities, output_segmentation],
title=title,
description=description,
examples=examples,
allow_flagging="never" # Disable flagging unless needed
)
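# Optional (hedged): enable request queuing so long GPU inferences don't hit request timeouts.
# iface.queue()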
if __name__ == "__main__":
logger.info("Launching Gradio app...")
iface.launch() # share=True to create public link (use with caution)