ishanprogs committed
Commit 4ce9de2 · verified · 1 Parent(s): b1a336a

Upload 5 files

Files changed (5)
  1. app.py +228 -0
  2. clip_text_features.pt +3 -0
  3. clip_vit_b16.pth +3 -0
  4. requirements.txt +13 -0
  5. yolobest.pt +3 -0
app.py ADDED
@@ -0,0 +1,228 @@
+ import gradio as gr
+ import torch
+ import clip
+ from PIL import Image
+ import cv2
+ import numpy as np
+ import os
+ import time  # portable timing for CPU and GPU (used in validate_and_segment)
+ from ultralytics import YOLO
+ import gc
+
+ # --- Configuration & Model Loading ---
+
+ # Device Setup
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # --- CLIP Model Setup ---
+ print("Loading CLIP model...")
+ try:
+     clip_model, clip_preprocess = clip.load("ViT-B/16", device=device, jit=False)  # jit=False can sometimes help compatibility
+     # Load saved visual backbone weights (optional, but useful if they were saved separately)
+     # clip_model_path = "clip_vit_b16.pth"  # uploaded at the repo root in this commit
+     # if os.path.exists(clip_model_path):
+     #     clip_model.load_state_dict(torch.load(clip_model_path, map_location=device))
+     #     print("Loaded custom CLIP visual weights.")
+     clip_model.eval()
+
+     # Load saved text features
+     clip_text_features_path = "clip_text_features.pt"
+     if not os.path.exists(clip_text_features_path):
+         raise FileNotFoundError("CLIP text features file 'clip_text_features.pt' not found.")
+     clip_text_features = torch.load(clip_text_features_path, map_location=device)
+     print("CLIP model and text features loaded.")
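+     # NOTE (assumption): 'clip_text_features.pt' is expected to hold two L2-normalized CLIP text
+     # embeddings in this order: index 0 = a "car" prompt, index 1 = a "not a car" prompt,
+     # generated roughly like:
+     #     tokens = clip.tokenize(["a photo of a car", "a photo of something that is not a car"]).to(device)
+     #     with torch.no_grad():
+     #         feats = clip_model.encode_text(tokens)
+     #         feats /= feats.norm(dim=-1, keepdim=True)
+     #     torch.save(feats, "clip_text_features.pt")
+     # The exact prompts are an assumption; validate_image_with_clip() below only relies on the ordering.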
+ except Exception as e:
+     print(f"Error loading CLIP model or features: {e}")
+     # Handle the error by disabling the CLIP validation step
+     clip_model = None
+
+ # --- YOLOv8 Model Setup ---
+ print("Loading YOLOv8 model...")
+ # Define class names EXACTLY as used during YOLO training
+ YOLO_CLASSES = ['Cracked', 'Scratch', 'Flaking', 'Broken part', 'Corrosion', 'Dent', 'Paint chip', 'Missing part']
+ YOLO_NUM_CLASSES = len(YOLO_CLASSES)
+
+ # Path to the best YOLOv8 weights (uploaded in this commit as 'yolobest.pt')
+ yolo_weights_path = "yolobest.pt"
+
+ if not os.path.exists(yolo_weights_path):
+     raise FileNotFoundError(f"YOLOv8 weights file '{yolo_weights_path}' not found.")
+
+ try:
+     yolo_model = YOLO(yolo_weights_path)
+     # Set model parameters manually if needed (especially if the config wasn't saved),
+     # so the internal model state matches training.
+     # yolo_model.model.yaml['nc'] = YOLO_NUM_CLASSES  # Usually loaded from weights/yaml, but good to verify
+     # Force class names in case they don't load correctly from the weights.
+     # Setting them on the underlying module avoids issues when Model.names is a read-only property.
+     yolo_model.model.names = {i: name for i, name in enumerate(YOLO_CLASSES)}
+
+     # Move model to device explicitly
+     yolo_model.to(device)
+     print("YOLOv8 model loaded.")
+     print(f"YOLOv8 Class Names: {yolo_model.names}")
+ except Exception as e:
+     print(f"Error loading YOLOv8 model: {e}")
+     yolo_model = None
+
+ # --- Prediction Functions ---
+
+ def validate_image_with_clip(image_pil):
+     """Checks if the PIL image is likely a car using CLIP."""
+     if clip_model is None:
+         print("CLIP model not loaded, skipping validation.")
+         return "Car", 1.0  # Assume it's a car if CLIP failed to load
+
+     print("Running CLIP validation...")
+     try:
+         # Use simple preprocessing for the validation check
+         image_input = clip_preprocess(image_pil).unsqueeze(0).to(device)
+
+         with torch.no_grad():
+             image_features = clip_model.encode_image(image_input)
+             image_features /= image_features.norm(dim=-1, keepdim=True)
+
+             logit_scale = clip_model.logit_scale.exp()
+             similarity = (image_features @ clip_text_features.T) * logit_scale
+             probs = similarity.softmax(dim=-1).squeeze()  # Probabilities over [car, not car]
+
+         car_prob = probs[0].item()
+         not_car_prob = probs[1].item()
+         predicted_label = "Car" if car_prob > not_car_prob else "Not Car"
+
+         print(f"CLIP Result: {predicted_label} (Car Prob: {car_prob:.4f}, Not Car Prob: {not_car_prob:.4f})")
+         return predicted_label, car_prob
+     except Exception as e:
+         print(f"Error during CLIP prediction: {e}")
+         return "Error", 0.0
+
+
+ def predict_damage_with_yolo(image_np_bgr, confidence_threshold=0.4):
+     """Runs YOLOv8 segmentation on the OpenCV image (BGR)."""
+     if yolo_model is None:
+         print("YOLOv8 model not loaded, skipping damage prediction.")
+         # Return the original image with an error message
+         cv2.putText(image_np_bgr, "YOLOv8 model failed to load", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+         return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)  # Return RGB for Gradio
+
+     print(f"Running YOLOv8 prediction with conf: {confidence_threshold}...")
+     try:
+         # Perform prediction
+         results = yolo_model.predict(
+             source=image_np_bgr,  # Pass BGR numpy array
+             conf=confidence_threshold,
+             verbose=False,  # Less console output
+             device=device
+         )
+
+         if not results or len(results) == 0:
+             print("YOLOv8 predict() returned no results.")
+             # Return the original image with a message
+             cv2.putText(image_np_bgr, "No results from YOLOv8", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 128, 255), 2)
+             return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
+
+         result = results[0]  # Results for the first image
+
+         # Use the built-in plot function to draw boxes/masks on the image.
+         # Because the source was a BGR array, result.plot() returns a BGR array,
+         # so convert it to RGB before handing it to Gradio.
+         annotated_image_bgr = result.plot(conf=True, boxes=True, masks=True)
+         annotated_image_rgb = cv2.cvtColor(annotated_image_bgr, cv2.COLOR_BGR2RGB)
+
+         print(f"YOLOv8 found {len(result.boxes)} instances above threshold.")
+         return annotated_image_rgb  # Return the annotated RGB image
+
+     except Exception as e:
+         print(f"Error during YOLOv8 prediction or plotting: {e}")
+         # Return the original image with an error message
+         cv2.putText(image_np_bgr, f"YOLO Error: {e}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
+         return cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
+
+
+ # --- Main Gradio Function ---
+
+ def validate_and_segment(input_image_pil, clip_threshold, yolo_threshold):
+     """
+     Main function called by the Gradio interface.
+     Takes a PIL image, runs CLIP validation, then YOLOv8 segmentation if valid.
+     """
+     if input_image_pil is None:
+         return None, "Please upload an image."
+
+     # 1. Validate using CLIP
+     clip_label, clip_prob = validate_image_with_clip(input_image_pil)
+
+     if clip_label == "Error":
+         return None, "Error during CLIP validation."
+     if clip_label == "Not Car" or clip_prob < clip_threshold:
+         status_message = f"Image rejected by validator. Classified as '{clip_label}' (Confidence: {clip_prob:.2f}). Required > {clip_threshold:.2f}."
+         print(status_message)
+         # Draw the message on a BGR copy, then convert back to RGB for display
+         img_display_bgr = cv2.cvtColor(np.array(input_image_pil), cv2.COLOR_RGB2BGR)
+         cv2.putText(img_display_bgr, status_message, (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
+         img_display_rgb = cv2.cvtColor(img_display_bgr, cv2.COLOR_BGR2RGB)
+         return img_display_rgb, status_message  # Display the original image with the message
+
+     # 2. If validation passes, run YOLOv8 segmentation
+     status_message = f"Image validated as 'Car' (Confidence: {clip_prob:.2f}). Running damage segmentation..."
+     print(status_message)
+
+     # Convert the PIL image to OpenCV format (BGR NumPy array) for YOLOv8
+     image_np_bgr = cv2.cvtColor(np.array(input_image_pil), cv2.COLOR_RGB2BGR)
+
+     # Time the YOLO prediction with time.perf_counter(), which also works on
+     # CPU-only machines (torch.cuda.Event requires CUDA and would crash there).
+     start_time = time.perf_counter()
+
+     # Run YOLO prediction
+     annotated_image_rgb = predict_damage_with_yolo(image_np_bgr, yolo_threshold)
+
+     # Wait for pending GPU work before reading the timer, then compute the duration
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+     prediction_time = time.perf_counter() - start_time  # Time in seconds
+
+     status_message += f"\nDamage segmentation complete (Time: {prediction_time:.2f}s)."
+     print(status_message)
+
+     # Clear memory after prediction
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return annotated_image_rgb, status_message
+
+
+ # --- Create Gradio Interface ---
+ print("Creating Gradio interface...")
+
+ # Define input and output components
+ image_input = gr.Image(type="pil", label="Upload Car Image")  # Input PIL image
+ image_output = gr.Image(type="numpy", label="Segmentation Result")  # Output NumPy array (RGB)
+ status_output = gr.Textbox(label="Status & Validation Result")
+ clip_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="CLIP Car Confidence Threshold")
+ yolo_slider = gr.Slider(minimum=0.05, maximum=0.95, step=0.05, value=0.4, label="YOLO Damage Confidence Threshold")
+
+ # Load example images if available
+ example_image_folder = "examples"
+ example_list = []
+ if os.path.isdir(example_image_folder):
+     for img_name in os.listdir(example_image_folder):
+         if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
+             example_list.append(os.path.join(example_image_folder, img_name))
+
+ # Build the interface
+ iface = gr.Interface(
+     fn=validate_and_segment,
+     inputs=[image_input, clip_slider, yolo_slider],
+     outputs=[image_output, status_output],
+     title="🚗 Car Damage Validation & Segmentation",
+     description="Upload an image of a car. The system first validates if it's a car using CLIP. If validated, it runs YOLOv8 to segment damage.",
+     examples=example_list if example_list else None,
+     allow_flagging='never'  # Disable flagging
+ )
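+ # Note: newer Gradio releases rename 'allow_flagging' to 'flagging_mode'; if the installed
+ # gradio version rejects the argument above, switch to flagging_mode='never'.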
+
+ # --- Launch the Interface ---
+ print("Launching Gradio interface...")
+ # share=True creates a public link (valid for ~72h) if running locally outside HF Spaces
+ # Use auth for basic protection if needed: auth=("username", "password")
+ iface.launch(share=False)  # Set share=True if running locally and public access is needed
clip_text_features.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28315215c9429a04e5aafd99cf8a0292a489bf2937d44d580a3cf1c78ee84f94
+ size 3283
clip_vit_b16.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6a3b5ff65477fcc1bbaa1fcaa249a6f9745269e6e06e751f9eea6efeb521bb7b
+ size 350463888
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch
+ torchvision
+ torchaudio
+ gradio
+ ultralytics
+ opencv-python-headless  # Use headless version for servers
+ matplotlib  # Used by ultralytics plotting
+ ftfy  # CLIP dependency
+ regex  # CLIP dependency
+ git+https://github.com/openai/CLIP.git  # Install CLIP directly
+ Pillow  # PIL dependency for CLIP/images
+ # Add any other project-specific libraries here
+ pyyaml  # Needed indirectly by ultralytics
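+ # Note: versions above are left unpinned; pinning tested releases (e.g. a specific ultralytics
+ # and gradio version) is advisable for reproducible Spaces builds — exact pins depend on the
+ # environment the models were trained and tested in.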
yolobest.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae345bfb159676f6343daf72c1912bb374fa4997e6788e84d930b9bb28751d27
+ size 92296829