import os
import gradio as gr
import torch
import numpy as np
import cv2
import time
import functools
from PIL import Image
from huggingface_hub import hf_hub_download

# --- Configuration ---

# Hugging Face model repositories and filenames
HF_MODEL_CONFIG = {
    "SAM2 Hiera Tiny": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_t.pt"
    },
    "SAM2 Hiera Small": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_s.pt"
    },
    "SAM2 Hiera Base Plus": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_b+.pt"
    },
    "SAM2 Hiera Large": {
        "repo_id": "astroanand/CoronarySAM2",
        "filename": "Coronary_Sam2_l.pt"
    }
}
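
# Note on checkpoint format: load_model() below expects each downloaded .pt file to be a
# dict with a 'model_cfg' key (the SAM2 config name passed to build_sam2) and a
# 'model_state_dict' key (the fine-tuned weights). Roughly, a compatible file could be
# saved as (the config name here is only an illustrative assumption):
#   torch.save({"model_cfg": "sam2_hiera_t.yaml", "model_state_dict": model.state_dict()}, path)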

# Download and cache models from Hugging Face
print("Checking and downloading models from Hugging Face...")
models_available = {}
for name, config in HF_MODEL_CONFIG.items():
    try:
        print(f"Downloading {name} from {config['repo_id']}...")
        model_path = hf_hub_download(
            repo_id=config["repo_id"],
            filename=config["filename"],
            cache_dir="./hf_cache"
        )
        models_available[name] = model_path
        print(f"✓ {name} downloaded successfully to {model_path}")
    except Exception as e:
        print(f"✗ Warning: Failed to download {name} from {config['repo_id']}: {e}")
        print(f"  {name} will not be available in the dropdown.")

if not models_available:
    print("Error: No valid models could be downloaded. Please check your internet connection and HF repository access.")
    # exit() # Or handle gracefully

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Try importing SAM2 modules
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except ImportError:
    print("Error: SAM2 modules not found. Make sure 'sam2' directory is in your Python path or installed.")
    exit()
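
# If this import fails, the 'sam2' package is not available on the Python path. It is
# typically installed from Meta's SAM2 repository (installation detail assumed, not part
# of this app):
#   pip install "git+https://github.com/facebookresearch/sam2.git"
# or by cloning that repo and running `pip install -e .` inside it.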

# --- Preprocessing Functions ---
def normalize_xray_image(image, kernel_size=(51,51), sigma=0):
    """Normalize X-ray image by applying Gaussian blur and intensity normalization."""
    if image is None: return None
    is_color = len(image.shape) == 3
    if is_color:
        gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray_image = image.copy()
    gray_image = gray_image.astype(float)
    blurred = gray_image.copy()
    for _ in range(5): # Reduced iterations
        blurred = cv2.GaussianBlur(blurred, kernel_size, sigma)
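    # After repeated blurring, 'blurred' approximates the slowly varying background
    # illumination field; dividing by it (via factor_image below) flattens large-scale
    # brightness variations while keeping vessel detail.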
    mean_intensity = np.mean(blurred)
    factor_image = mean_intensity / (blurred + 1e-10)
    if is_color:
        normalized_image = image.copy().astype(float)
        for i in range(3):
            normalized_image[:,:,i] = normalized_image[:,:,i] * factor_image
    else:
        normalized_image = gray_image * factor_image
    return np.clip(normalized_image, 0, 255).astype(np.uint8)

def apply_clahe(image_uint8):
    """Apply CLAHE for better vessel contrast. Expects uint8 input."""
    if image_uint8 is None: return None
    is_color = len(image_uint8.shape) == 3

    # --- ADJUST CLAHE STRENGTH HERE ---
    # Lower clipLimit reduces the contrast enhancement effect.
    # Original was 2.0. Try values like 1.5, 1.0, or even disable by setting it very low.
    clahe_clip_limit = 2.0
    clahe_tile_grid_size = (8, 8)
    print(f" Applying CLAHE with clipLimit={clahe_clip_limit}, tileGridSize={clahe_tile_grid_size}")
    # ---------------------------------

    clahe = cv2.createCLAHE(clipLimit=clahe_clip_limit, tileGridSize=clahe_tile_grid_size)

    if is_color:
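        # Apply CLAHE only to the L (lightness) channel in LAB space so local contrast
        # is enhanced without shifting hue, then convert back to RGB.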
        lab = cv2.cvtColor(image_uint8, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        l_clahe = clahe.apply(l)
        lab_clahe = cv2.merge((l_clahe, a, b))
        clahe_image_uint8 = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2RGB)
    else:
         clahe_image_uint8 = clahe.apply(image_uint8)

    # Return uint8 [0, 255] suitable for predictor.set_image
    return clahe_image_uint8


def preprocess_image_for_sam2(image_rgb_numpy):
    """Combined preprocessing: normalization + CLAHE for SAM2 input."""
    if image_rgb_numpy is None:
        print("Preprocessing: Input image is None.")
        return None

    start_time = time.time()
    print("Preprocessing Step 1: Normalizing X-ray image...")
    if image_rgb_numpy.dtype != np.uint8:
         image_rgb_numpy = np.clip(image_rgb_numpy, 0, 255).astype(np.uint8)
    if len(image_rgb_numpy.shape) == 2:
        image_rgb_numpy = cv2.cvtColor(image_rgb_numpy, cv2.COLOR_GRAY2RGB)

    normalized_uint8 = normalize_xray_image(image_rgb_numpy)
    if normalized_uint8 is None:
        print("Preprocessing failed at normalization step.")
        return None
    print(f"Normalization done in {time.time() - start_time:.2f}s")

    start_time_clahe = time.time()
    print("Preprocessing Step 2: Applying CLAHE...")
    preprocessed_uint8 = apply_clahe(normalized_uint8) # CLAHE applied here
    if preprocessed_uint8 is None:
        print("Preprocessing failed at CLAHE step.")
        return None
    print(f"CLAHE done in {time.time() - start_time_clahe:.2f}s")
    print(f"Total preprocessing time: {time.time() - start_time:.2f}s")

    return preprocessed_uint8 # Return the image after all steps
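
# Illustrative standalone use of the preprocessing pipeline (not executed by the app;
# the file path is a placeholder, not a real asset of this repo):
#   img_bgr = cv2.imread("sample_angiogram.png")
#   img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
#   sam_input = preprocess_image_for_sam2(cv2.resize(img_rgb, (1024, 1024)))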

# --- Model Loading ---
@functools.lru_cache(maxsize=4) # Cache up to 4 models (one for each variant)
def load_model(model_name):
    """Loads the specified SAM2 model and creates a predictor."""
    print(f"\nAttempting to load model: {model_name}")
    if model_name not in models_available: # Check against available models
        print(f"Error: Model name '{model_name}' not found or checkpoint missing.")
        return None

    checkpoint_path = models_available[model_name] # Get path from available dict

    try:
        print(f" Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        if 'model_cfg' not in checkpoint:
            print(f"Error: 'model_cfg' key not found in checkpoint {checkpoint_path}.")
            return None
        model_cfg_name = checkpoint['model_cfg']
        print(f" Using model config from checkpoint: {model_cfg_name}")
        sam2_model = build_sam2(model_cfg_name, checkpoint_path=None, device=DEVICE)
        if 'model_state_dict' not in checkpoint:
             print(f"Error: 'model_state_dict' not found in checkpoint {checkpoint_path}.")
             return None
        state_dict = checkpoint['model_state_dict']
        new_state_dict = {}
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        sam2_model.load_state_dict(new_state_dict)
        print(" Successfully loaded fine-tuned model state_dict.")
        sam2_model.to(DEVICE)
        sam2_model.eval()
        predictor = SAM2ImagePredictor(sam2_model)
        print(f"Model '{model_name}' loaded successfully on {DEVICE}.")
        return predictor
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        import traceback
        traceback.print_exc()
        return None

# --- Utility Functions ---
def resize_image_fixed(image, target_size=1024):
    """Resizes image to a fixed square size (1024x1024)."""
    if image is None: return None
    return cv2.resize(image, (target_size, target_size), interpolation=cv2.INTER_LINEAR)
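
# Note: this resize is not aspect-ratio preserving; the image is stretched into a
# 1024x1024 square. Point prompts are collected in this 1024x1024 space, and the
# predicted mask is resized back to the original dimensions in run_segmentation().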

def draw_points_on_image(image, points_state):
    """Draws points (green positive, red negative) on a copy of the image."""
    if image is None: return image # Return original if no image
    draw_image = image.copy()
    if not points_state: return draw_image # Return copy if no points

    # Make points slightly larger and add a black border
    base_radius = max(4, int(min(image.shape[:2]) * 0.006)) # Slightly larger base radius
    border_thickness = 1 # Thickness of the black border
    radius_with_border = base_radius + border_thickness
    thickness = -1 # Filled circle

    for x, y, label in points_state:
        color = (0, 255, 0) if label == 1 else (255, 0, 0)
        center = (int(x), int(y))
        # Draw black border circle first
        cv2.circle(draw_image, center, radius_with_border, (0, 0, 0), thickness)
        # Draw colored circle on top
        cv2.circle(draw_image, center, base_radius, color, thickness)

    return draw_image

# --- Gradio UI Interaction Functions ---

def get_point_counts_text(points_state):
    """Helper function to generate the point count markdown string."""
    pos_count = sum(1 for _, _, label in points_state if label == 1)
    neg_count = sum(1 for _, _, label in points_state if label == 0)
    return f"**Points Added:** <font color='green'>{pos_count} Positive</font>, <font color='red'>{neg_count} Negative</font>"

def add_point(preprocessed_image, points_state, point_type, evt: gr.SelectData):
    """Callback function when user clicks on the preprocessed image."""
    if preprocessed_image is None:
        gr.Warning("Please upload and preprocess an image first.")
        # Return original image, points state, and existing counts text
        return preprocessed_image, points_state, get_point_counts_text(points_state)
    x, y = evt.index[0], evt.index[1]
    label = 1 if point_type == "Positive" else 0
    # Store coordinates relative to the preprocessed image (1024x1024)
    points_state.append([x, y, label])
    print(f"Added point: ({x}, {y}), Type: {'Positive' if label==1 else 'Negative'}, Total Points: {len(points_state)}")
    image_with_points = draw_points_on_image(preprocessed_image, points_state)
    # Return updated image, points state, and updated counts text
    return image_with_points, points_state, get_point_counts_text(points_state)

def undo_last_point(preprocessed_image, points_state):
    """Removes the last added point and updates the preprocessed display image."""
    if preprocessed_image is None:  # Handle case where image is cleared
        # Return None image, points state, and counts text
        return None, points_state, get_point_counts_text(points_state)
    if not points_state:
        print("No points to undo.")
        # Return the current preprocessed image without changes if no points
        return preprocessed_image, points_state, get_point_counts_text(points_state)

    removed_point = points_state.pop()
    print(f"Removed point: {removed_point}, Remaining Points: {len(points_state)}")
    image_with_points = draw_points_on_image(preprocessed_image, points_state)
    # Return updated image, points state, and updated counts text
    return image_with_points, points_state, get_point_counts_text(points_state)

def clear_points_and_display(preprocessed_image_state):
    """Clears points and resets the preprocessed display image."""
    print("Clearing points and resetting preprocessed display.")
    points_state = []  # Clear points
    # Return the stored preprocessed image without points, cleared points state, cleared mask, and cleared counts text
    return preprocessed_image_state, points_state, None, get_point_counts_text(points_state)

def run_segmentation(preprocessed_image_state, original_image_state, model_name, points_state):
    """Runs SAM2 segmentation using points on the preprocessed image."""
    start_total_time = time.time()
    # Initialize return values
    output_mask_display = None

    if preprocessed_image_state is None or original_image_state is None:
        gr.Warning("Please upload an image first.")
        return output_mask_display, points_state

    print(f"\n--- Running Segmentation ---")
    print(f" Model Selected: {model_name}")
    print(f" Number of points: {len(points_state)}")

    # --- 1. Load Model ---
    predictor = load_model(model_name)
    if predictor is None:
        gr.Error(f"Failed to load model '{model_name}'. Check logs and paths.")
        return output_mask_display, points_state

    # --- 2. Use Preprocessed Image ---
    # The image is already preprocessed and resized to 1024x1024
    image_for_predictor = preprocessed_image_state
    original_h, original_w = original_image_state.shape[:2] # Get original dims for final resize
    print(f" Using preprocessed image (1024x1024) for predictor.")
    print(f" Original image size for final mask resize: {original_w}x{original_h}")

    print(" Setting preprocessed image in predictor...")
    start_set_image = time.time()
    # Feed the preprocessed image (which is already 1024x1024 uint8) to SAM
    predictor.set_image(image_for_predictor)
    print(f" predictor.set_image took {time.time() - start_set_image:.2f}s")

    # --- 3. Prepare Prompts (No Scaling Needed) ---
    if not points_state:
        # Use center point if no points provided
        center_x, center_y = 512, 512
        point_coords = np.array([[[center_x, center_y]]])
        point_labels = np.array([1])
        print(" No points provided. Using center point (512, 512).")
    else:
        # Points are already relative to the 1024x1024 preprocessed image
        point_coords_list = [[x, y] for x, y, label in points_state]
        labels_list = [label for x, y, label in points_state]
        point_coords = np.array([point_coords_list])
        point_labels = np.array(labels_list)
        print(f" Using {len(points_state)} provided points (coords relative to 1024x1024).")

    point_coords_torch = torch.tensor(point_coords, dtype=torch.float32).to(DEVICE)
    point_labels_torch = torch.tensor(point_labels, dtype=torch.float32).unsqueeze(0).to(DEVICE) # Add batch dim
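    # Shapes produced above: point_coords_torch is (1, num_points, 2) in (x, y) pixel
    # coordinates of the 1024x1024 preprocessed image; point_labels_torch is
    # (1, num_points) with 1 = foreground (artery) and 0 = background.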

    # --- 4. Run Model Inference ---
    print(" Running model inference...")
    start_inference_time = time.time()
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(point_coords_torch, point_labels_torch), boxes=None, masks=None
        )
        if predictor._features is None:
             gr.Error("Image features not computed. Predictor might not have been set correctly.")
             return output_mask_display, points_state
        # Ensure features are accessed correctly
        image_embed = predictor._features["image_embed"][-1].unsqueeze(0)
        image_pe = predictor.model.sam_prompt_encoder.get_dense_pe()
        # Handle potential missing high_res_features key gracefully
        high_res_features = None
        if "high_res_feats" in predictor._features and predictor._features["high_res_feats"]:
             try:
                 high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
             except IndexError:
                 print("Warning: Index error accessing high_res_feats. Proceeding without them.")
             except Exception as e:
                 print(f"Warning: Error processing high_res_features: {e}. Proceeding without them.")

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=image_embed, image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings,
            multimask_output=True, repeat_image=False, # repeat_image should be False for single image prediction
            high_res_features=high_res_features, # Pass None if not available
        )
        # Postprocess masks to 1024x1024
        prd_masks_1024 = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1]) # predictor._orig_hw should be (1024, 1024)
        # Select the best mask based on predicted score
        best_mask_idx = torch.argmax(prd_scores[0]).item()
        # Apply sigmoid and thresholding
        best_mask_1024_prob = torch.sigmoid(prd_masks_1024[:, best_mask_idx])
        binary_mask_1024 = (best_mask_1024_prob > 0.5).cpu().numpy().squeeze() # Squeeze to get (H, W)
    print(f" Model inference took {time.time() - start_inference_time:.2f}s")

    # --- 5. Resize Mask to Original Dimensions ---
    print(" Resizing mask to original dimensions...")
    final_mask_resized = cv2.resize(
        binary_mask_1024.astype(np.uint8), (original_w, original_h), interpolation=cv2.INTER_NEAREST
    )

    # --- 6. Format Mask for Display ---
    # Mask for display (RGB)
    output_mask_display = (final_mask_resized * 255).astype(np.uint8)
    if len(output_mask_display.shape) == 2: # Ensure RGB for display consistency
        output_mask_display = cv2.cvtColor(output_mask_display, cv2.COLOR_GRAY2RGB)

    total_time = time.time() - start_total_time
    print(f"--- Segmentation Complete (Total time: {total_time:.2f}s) ---")

    # Return: mask for display, points state (unchanged)
    return output_mask_display, points_state # No change needed here as it doesn't modify points


def process_upload(uploaded_image):
    """Handles image upload: preprocesses, resizes, stores states."""
    if uploaded_image is None:
        # Clear everything including point counts
        return None, None, None, [], None, get_point_counts_text([])

    print("Image uploaded. Processing...")
    # 1. Store original image
    original_image = uploaded_image.copy()

    # 2. Resize to 1024x1024 for preprocessing
    image_resized_1024 = resize_image_fixed(original_image, 1024)
    if image_resized_1024 is None:
        gr.Error("Failed to resize image.")
        return None, None, None, [], None, get_point_counts_text([])

    # 3. Preprocess the 1024x1024 image
    preprocessed_1024 = preprocess_image_for_sam2(image_resized_1024)
    if preprocessed_1024 is None:
        gr.Error("Image preprocessing failed.")
        return None, None, None, [], None, get_point_counts_text([])

    # Ensure preprocessed image is RGB for display
    if len(preprocessed_1024.shape) == 2:
        preprocessed_1024_display = cv2.cvtColor(preprocessed_1024, cv2.COLOR_GRAY2RGB)
    else:
        preprocessed_1024_display = preprocessed_1024.copy()

    print("Image processed successfully.")
    points_state = [] # Clear points on new upload
    # Return:
    # 1. Preprocessed image for display (interactive)
    # 2. Preprocessed image for state
    # 3. Original image for state
    # 4. Cleared points state
    # 5. Cleared mask display
    # 6. Cleared point counts text
    return preprocessed_1024_display, preprocessed_1024, original_image, points_state, None, get_point_counts_text(points_state)


def clear_all_outputs():
    """Clears all input/output fields and states."""
    print("Clearing all inputs and outputs.")
    points_state = [] # Clear points
    # Clear everything including point counts
    return None, None, None, points_state, None, get_point_counts_text(points_state)


# --- Build Gradio Interface ---
css = """
    #mask_display_container .gradio-image { height: 450px !important; object-fit: contain; }
    #preprocessed_image_container .gradio-image { height: 450px !important; object-fit: contain; cursor: crosshair !important; }
    #upload_container .gradio-image { height: 150px !important; object-fit: contain; } /* Smaller upload preview */
    .output-col img { max-height: 450px; object-fit: contain; }
    .control-col { min-width: 500px; } /* Wider control column */
    .output-col { min-width: 500px; }
"""

with gr.Blocks(css=css, title="Coronary Artery Segmentation (Fine-tuned SAM2)") as demo:
    gr.Markdown("# Coronary Artery Segmentation using Fine-tuned SAM2")
    gr.Markdown(
        "**Let's find those arteries!**\n\n"
        "1. Upload your Coronary X-ray Image.\n"
        "2. The preprocessed image appears on the left. Time to guide the AI! Click directly on the image to add **Positive** (artery) or **Negative** (background) points.\n"
        "3. Choose your fine-tuned SAM2 model.\n"
        "4. Hit 'Run Segmentation' and watch the magic happen!\n"
        "5. Download your predicted mask (the white area) using the download button on the mask image."
    )

    # --- States ---
    points_state = gr.State([])
    # State to store the original uploaded image (needed for final mask resizing)
    original_image_state = gr.State(None)
    # State to store the preprocessed 1024x1024 image data (used for drawing points and predictor input)
    preprocessed_image_state = gr.State(None)
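    # gr.State values are stored per user session and passed to callbacks without being
    # rendered; they keep the clean preprocessed image and the original image available
    # between clicks.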


    with gr.Row():
        # --- Left Column (Controls & Preprocessed Image Interaction) ---
        with gr.Column(scale=1, elem_classes="control-col"):
            gr.Markdown("## 1. Upload & Controls")
            # Keep upload separate and smaller
            upload_image = gr.Image(
                type="numpy", label="Upload Coronary X-ray Image",
                height=150, elem_id="upload_container"
            )
            gr.Markdown("## 2. Add Points on Preprocessed Image")
            # Interactive Preprocessed Image Display
            preprocessed_image_display = gr.Image(
                type="numpy", label="Click on Image to Add Points",
                interactive=True, # Make this interactive
                height=450, elem_id="preprocessed_image_container"
            )
            # Add Point Counter Display
            point_counter_display = gr.Markdown(get_point_counts_text([]))

            model_selector = gr.Dropdown(
                choices=list(models_available.keys()),
                label="Select SAM2 Model Variant",
                value=list(models_available.keys())[-1] if models_available else None
            )
            prompt_type = gr.Radio(
                ["Positive", "Negative"], label="Point Prompt Type", value="Positive"
            )
            with gr.Row():
                clear_button = gr.Button("Clear Points")
                undo_button = gr.Button("Undo Last Point")
            run_button = gr.Button("Run Segmentation", variant="primary")
            clear_all_button = gr.Button("Clear All") # Added Clear All

        # --- Right Column (Output Mask) ---
        with gr.Column(scale=1, elem_classes="output-col"):
            gr.Markdown("## 3. Predicted Mask")
            final_mask_display = gr.Image(
                type="numpy", label="Predicted Binary Mask (White = Artery)",
                interactive=False, height=450, elem_id="mask_display_container",
                format="png" # Specify PNG format for download
            )


    # --- Define Interactions ---

    # 1. Upload triggers preprocessing and display
    upload_image.upload(
        fn=process_upload,
        inputs=[upload_image],
        outputs=[
            preprocessed_image_display, # Update interactive display
            preprocessed_image_state,   # Update state
            original_image_state,       # Update state
            points_state,               # Clear points
            final_mask_display,         # Clear mask display
            point_counter_display       # Clear point counts
        ]
    )

    # 2. Clicking on preprocessed image adds points
    preprocessed_image_display.select(
        fn=add_point,
        inputs=[preprocessed_image_state, points_state, prompt_type],
        outputs=[
            preprocessed_image_display, # Update display with points
            points_state,               # Update points state
            point_counter_display       # Update point counts
            ]
    )

    # 3. Clear points button resets points and preprocessed display
    clear_button.click(
        fn=clear_points_and_display,
        inputs=[preprocessed_image_state], # Needs the clean preprocessed image
        outputs=[
            preprocessed_image_display, # Reset display
            points_state,               # Clear points
            final_mask_display,         # Clear mask
            point_counter_display       # Reset point counts
            ]
    )

    # 4. Undo button removes last point and updates preprocessed display
    undo_button.click(
        fn=undo_last_point,
        inputs=[preprocessed_image_state, points_state], # Needs current preprocessed image
        outputs=[
            preprocessed_image_display, # Update display
            points_state,               # Update points state
            point_counter_display       # Update point counts
            ]
    )

    # 5. Run segmentation (Outputs don't change point counts)
    run_button.click(
        fn=run_segmentation,
        inputs=[
            preprocessed_image_state, # Use preprocessed image data
            original_image_state,     # Needed for final resize dim
            model_selector,
            points_state              # Points are relative to preprocessed
            ],
        outputs=[
            final_mask_display,       # Show the final mask
            points_state              # Pass points state (might be needed if run modifies it - currently doesn't)
            ]
    )

    # 6. Clear All button
    clear_all_button.click(
        fn=clear_all_outputs,
        inputs=[],
        outputs=[
            upload_image,
            preprocessed_image_display,
            preprocessed_image_state,
            points_state,
            final_mask_display,
            point_counter_display # Reset point counts
        ]
    )


# --- Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio App...")
    demo.launch()