# Captured from a Hugging Face Spaces file viewer (Space status: Sleeping).
# File size: 25,285 bytes; revision 562f83f; 575 lines in the original layout.
import os
import gradio as gr
import torch
import numpy as np
import cv2
import time
import functools
from PIL import Image
from huggingface_hub import hf_hub_download
# --- Configuration ---
# Maps each user-facing model label (shown in the model dropdown) to the
# Hugging Face Hub repository and checkpoint filename it is downloaded from.
# Every variant currently lives in the same repository; only the file differs.
HF_MODEL_CONFIG = {
    label: {"repo_id": "astroanand/CoronarySAM2", "filename": checkpoint_file}
    for label, checkpoint_file in (
        ("SAM2 Hiera Tiny", "Coronary_Sam2_t.pt"),
        ("SAM2 Hiera Small", "Coronary_Sam2_s.pt"),
        ("SAM2 Hiera Base Plus", "Coronary_Sam2_b+.pt"),
        ("SAM2 Hiera Large", "Coronary_Sam2_l.pt"),
    )
}
# Download and cache models from Hugging Face
def _download_checkpoints(model_config):
    """Fetch every configured checkpoint into ./hf_cache, skipping failures.

    Returns a dict mapping model label -> local checkpoint path for each
    checkpoint that downloaded, in ``model_config`` order. A failed download is
    reported on stdout rather than raised, so it only drops that one variant.
    """
    downloaded = {}
    for label, spec in model_config.items():
        try:
            print(f"Downloading {label} from {spec['repo_id']}...")
            local_path = hf_hub_download(
                repo_id=spec["repo_id"],
                filename=spec["filename"],
                cache_dir="./hf_cache",
            )
            downloaded[label] = local_path
            print(f"✓ {label} downloaded successfully to {local_path}")
        except Exception as exc:
            print(f"✗ Warning: Failed to download {label} from {spec['repo_id']}: {exc}")
            print(f" {label} will not be available in the dropdown.")
    return downloaded


print("Checking and downloading models from Hugging Face...")
models_available = _download_checkpoints(HF_MODEL_CONFIG)
if not models_available:
    print("Error: No valid models could be downloaded. Please check your internet connection and HF repository access.")
    # Deliberately not exiting here: execution continues with an empty mapping.
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")
# Try importing SAM2 modules; the app cannot work without them.
try:
    from sam2.build_sam import build_sam2
    from sam2.sam2_image_predictor import SAM2ImagePredictor
except ImportError as err:
    print("Error: SAM2 modules not found. Make sure 'sam2' directory is in your Python path or installed.")
    # `exit()` is a site-module helper meant for the interactive interpreter
    # (absent under `python -S`) and exits with status 0, i.e. "success".
    # Raise SystemExit with a failure status and keep the original cause.
    raise SystemExit(1) from err
# --- Preprocessing Functions ---
# ...existing code...
def normalize_xray_image(image, kernel_size=(51, 51), sigma=0, blur_iterations=5):
    """Normalize X-ray intensity by dividing out a heavily blurred background estimate.

    The grayscale image is Gaussian-blurred ``blur_iterations`` times to get a
    smooth background. Every pixel is then multiplied by
    ``mean(background) / background``.

    Args:
        image: numpy image, either 2-D grayscale or 3-D RGB (converted with
            ``cv2.COLOR_RGB2GRAY`` to compute the background), or None.
        kernel_size: Kernel size passed to ``cv2.GaussianBlur``.
        sigma: Sigma passed to ``cv2.GaussianBlur``.
        blur_iterations: Number of successive blur passes. Defaults to 5, the
            value that was previously hard-coded.

    Returns:
        uint8 array with the input's shape, with values clipped to [0, 255]
        (the fractional part is truncated by the uint8 cast). Returns None if
        ``image`` is None.
    """
    if image is None:
        return None
    is_color = image.ndim == 3
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if is_color else image.copy()
    gray = gray.astype(float)

    background = gray.copy()
    for _ in range(blur_iterations):
        background = cv2.GaussianBlur(background, kernel_size, sigma)

    # The tiny epsilon avoids division by zero on fully dark regions.
    factor = np.mean(background) / (background + 1e-10)

    if is_color:
        normalized = image.astype(float)
        # Scale the first three channels by the per-pixel factor in one
        # broadcast operation (any extra channel is left untouched, as before).
        normalized[:, :, :3] *= factor[:, :, np.newaxis]
    else:
        normalized = gray * factor
    return np.clip(normalized, 0, 255).astype(np.uint8)
def apply_clahe(image_uint8, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply CLAHE (contrast-limited adaptive histogram equalization) for vessel contrast.

    Args:
        image_uint8: uint8 image, 2-D grayscale or 3-D RGB, or None. For 3-D
            input the image is converted to LAB and CLAHE is applied to the L
            channel only; the a/b channels are passed through unchanged before
            converting back to RGB.
        clip_limit: ``clipLimit`` for ``cv2.createCLAHE``. Lower values weaken
            the contrast enhancement. Defaults to 2.0, the previous hard-coded value.
        tile_grid_size: ``tileGridSize`` for ``cv2.createCLAHE``. Defaults to (8, 8).

    Returns:
        uint8 image (RGB for 3-D input, single-channel for 2-D input), suitable
        for ``predictor.set_image``. Returns None if ``image_uint8`` is None.
    """
    if image_uint8 is None:
        return None
    print(f" Applying CLAHE with clipLimit={clip_limit}, tileGridSize={tile_grid_size}")
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    if image_uint8.ndim == 3:
        lab = cv2.cvtColor(image_uint8, cv2.COLOR_RGB2LAB)
        l_chan, a_chan, b_chan = cv2.split(lab)
        lab_equalized = cv2.merge((clahe.apply(l_chan), a_chan, b_chan))
        return cv2.cvtColor(lab_equalized, cv2.COLOR_LAB2RGB)
    return clahe.apply(image_uint8)
def preprocess_image_for_sam2(image_rgb_numpy):
    """Run the full SAM2 input pipeline: background normalization, then CLAHE.

    Non-uint8 input is clipped to [0, 255] and cast to uint8, and 2-D
    grayscale input is expanded to RGB before processing. Timing for each stage
    is printed.

    Returns:
        The preprocessed uint8 image, or None if the input is None or a stage fails.
    """
    if image_rgb_numpy is None:
        print("Preprocessing: Input image is None.")
        return None

    t_begin = time.time()
    print("Preprocessing Step 1: Normalizing X-ray image...")
    rgb = image_rgb_numpy
    if rgb.dtype != np.uint8:
        rgb = np.clip(rgb, 0, 255).astype(np.uint8)
    if rgb.ndim == 2:
        rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)

    normalized = normalize_xray_image(rgb)
    if normalized is None:
        print("Preprocessing failed at normalization step.")
        return None
    print(f"Normalization done in {time.time() - t_begin:.2f}s")

    t_clahe = time.time()
    print("Preprocessing Step 2: Applying CLAHE...")
    enhanced = apply_clahe(normalized)
    if enhanced is None:
        print("Preprocessing failed at CLAHE step.")
        return None
    print(f"CLAHE done in {time.time() - t_clahe:.2f}s")
    print(f"Total preprocessing time: {time.time() - t_begin:.2f}s")
    return enhanced
# --- Model Loading ---
# ...existing code...
class _ModelLoadError(Exception):
    """Expected, reportable model-loading failure (printed without a traceback)."""


@functools.lru_cache(maxsize=4)  # One slot per variant in HF_MODEL_CONFIG.
def _build_predictor(model_name):
    """Load ``model_name``'s fine-tuned checkpoint and wrap it in a SAM2ImagePredictor.

    Only successful loads are cached. ``lru_cache`` does not store calls that
    raise, so a failed load is retried on the next request instead of being
    remembered as ``None`` for the life of the process.

    Raises:
        _ModelLoadError: ``model_name`` is not in ``models_available``, or the
            checkpoint lacks the ``'model_cfg'`` / ``'model_state_dict'`` entries.
        Exception: anything raised by ``torch.load``, ``build_sam2`` or
            ``load_state_dict``.
    """
    print(f"\nAttempting to load model: {model_name}")
    if model_name not in models_available:
        raise _ModelLoadError(f"Model name '{model_name}' not found or checkpoint missing.")
    checkpoint_path = models_available[model_name]

    print(f" Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    # Validate the checkpoint layout before spending time building the model.
    if 'model_cfg' not in checkpoint:
        raise _ModelLoadError(f"'model_cfg' key not found in checkpoint {checkpoint_path}.")
    if 'model_state_dict' not in checkpoint:
        raise _ModelLoadError(f"'model_state_dict' not found in checkpoint {checkpoint_path}.")

    model_cfg_name = checkpoint['model_cfg']
    print(f" Using model config from checkpoint: {model_cfg_name}")
    # Build the bare architecture (no checkpoint); the fine-tuned weights are
    # loaded from the state dict below. `ckpt_path` is build_sam2's parameter name.
    sam2_model = build_sam2(model_cfg_name, ckpt_path=None, device=DEVICE)

    # Strip a leading 'module.' from parameter names (presumably saved from a
    # DataParallel-wrapped model during training) so the keys match.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in checkpoint['model_state_dict'].items()
    }
    sam2_model.load_state_dict(state_dict)
    print(" Successfully loaded fine-tuned model state_dict.")
    sam2_model.to(DEVICE)
    sam2_model.eval()
    predictor = SAM2ImagePredictor(sam2_model)
    print(f"Model '{model_name}' loaded successfully on {DEVICE}.")
    return predictor


def load_model(model_name):
    """Return a SAM2ImagePredictor for ``model_name``, or None if it cannot be loaded.

    Successful loads are cached (see ``_build_predictor``). Failure reasons are
    printed: expected problems as a one-line error, anything else with a traceback.
    """
    try:
        return _build_predictor(model_name)
    except _ModelLoadError as err:
        print(f"Error: {err}")
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        import traceback
        traceback.print_exc()
    return None


# Keep the cache-introspection API that load_model exposed as an lru_cache wrapper.
load_model.cache_clear = _build_predictor.cache_clear
load_model.cache_info = _build_predictor.cache_info
# --- Utility Functions ---
# ...existing code...
def resize_image_fixed(image, target_size=1024):
    """Return ``image`` resized (bilinear) to a ``target_size`` x ``target_size`` square.

    The aspect ratio is not preserved. Returns None when ``image`` is None.
    """
    if image is not None:
        square = (target_size, target_size)
        return cv2.resize(image, square, interpolation=cv2.INTER_LINEAR)
    return None
def draw_points_on_image(image, points_state):
    """Return a copy of ``image`` with each prompt point drawn as a filled dot.

    Points are ``[x, y, label]``. Label 1 is drawn green and any other label red
    (RGB order), each on a black ring one pixel wider. The dot radius is 0.6% of
    the image's shorter side, but at least 4 px. The input image is never
    modified. If ``image`` is None it is returned as-is.
    """
    if image is None:
        return image
    canvas = image.copy()
    if not points_state:
        return canvas

    inner_radius = max(4, int(min(image.shape[:2]) * 0.006))
    outer_radius = inner_radius + 1  # 1 px black border around each dot
    for px, py, lbl in points_state:
        fill_colour = (0, 255, 0) if lbl == 1 else (255, 0, 0)
        centre = (int(px), int(py))
        cv2.circle(canvas, centre, outer_radius, (0, 0, 0), cv2.FILLED)
        cv2.circle(canvas, centre, inner_radius, fill_colour, cv2.FILLED)
    return canvas
# --- Gradio UI Interaction Functions ---
def get_point_counts_text(points_state):
    """Build the Markdown summary of positive (label 1) and negative (label 0) points."""
    labels = [label for _, _, label in points_state]
    positives, negatives = labels.count(1), labels.count(0)
    return (
        f"**Points Added:** <font color='green'>{positives} Positive</font>, "
        f"<font color='red'>{negatives} Negative</font>"
    )
def add_point(preprocessed_image, points_state, point_type, evt: gr.SelectData):
    """Gradio ``select`` handler: record a click on the preprocessed image as a prompt point.

    The click position ``evt.index`` is stored as ``[x, y, label]`` in the
    preprocessed image's pixel coordinates (label 1 for "Positive", else 0).
    ``points_state`` is appended to in place and also returned.

    Returns:
        (image with all points drawn, points_state, point-count Markdown). If no
        image is loaded, a warning is shown and the inputs are returned unchanged.
    """
    if preprocessed_image is None:
        gr.Warning("Please upload and preprocess an image first.")
        return preprocessed_image, points_state, get_point_counts_text(points_state)

    click_x, click_y = evt.index[0], evt.index[1]
    label = 1 if point_type == "Positive" else 0
    points_state.append([click_x, click_y, label])
    kind = 'Positive' if label == 1 else 'Negative'
    print(f"Added point: ({click_x}, {click_y}), Type: {kind}, Total Points: {len(points_state)}")

    annotated = draw_points_on_image(preprocessed_image, points_state)
    return annotated, points_state, get_point_counts_text(points_state)
def undo_last_point(preprocessed_image, points_state):
    """Remove the most recently added point (in place) and redraw the remaining ones.

    Returns:
        (display image, points_state, point-count Markdown). The display image is
        None when no image is loaded. When there are no points, it is the
        un-annotated image.
    """
    if preprocessed_image is None:
        return None, points_state, get_point_counts_text(points_state)

    if points_state:
        removed_point = points_state.pop()
        print(f"Removed point: {removed_point}, Remaining Points: {len(points_state)}")
        display = draw_points_on_image(preprocessed_image, points_state)
    else:
        print("No points to undo.")
        display = preprocessed_image
    return display, points_state, get_point_counts_text(points_state)
def clear_points_and_display(preprocessed_image_state):
    """Drop all prompt points and show the clean preprocessed image again.

    Returns:
        (clean preprocessed image, empty points list, None to clear the mask,
        point-count Markdown for zero points).
    """
    print("Clearing points and resetting preprocessed display.")
    no_points = []
    return preprocessed_image_state, no_points, None, get_point_counts_text(no_points)
def _build_point_prompts(points_state, image_shape):
    """Convert UI points into SAM prompt arrays.

    Args:
        points_state: list of ``[x, y, label]`` in predictor-image pixel
            coordinates (label 1 = positive, 0 = negative).
        image_shape: shape of the image given to the predictor; its first two
            entries (height, width) locate the fallback centre point.

    Returns:
        ``(coords, labels)`` with shapes (1, N, 2) and (N,). With no points, a
        single positive point at the image centre is used.
    """
    if not points_state:
        height, width = image_shape[:2]
        center_x, center_y = width // 2, height // 2
        print(f" No points provided. Using center point ({center_x}, {center_y}).")
        return np.array([[[center_x, center_y]]]), np.array([1])
    coords = np.array([[[x, y] for x, y, _ in points_state]])
    labels = np.array([label for _, _, label in points_state])
    print(f" Using {len(points_state)} provided points (coords relative to 1024x1024).")
    return coords, labels


def _predict_best_mask(predictor, point_coords, point_labels):
    """Decode masks for the prompts and return the highest-scoring one as a boolean array.

    Expects ``predictor.set_image`` to have been called. Runs the prompt encoder
    and mask decoder directly with multimask output. The logits are
    post-processed to the image size, the mask with the highest predicted score
    is chosen, and its sigmoid is thresholded at 0.5.

    Returns:
        A 2-D boolean mask, or None when the predictor has no image features.
    """
    features = predictor._features
    if features is None:
        return None

    coords_t = torch.tensor(point_coords, dtype=torch.float32).to(DEVICE)
    # Give labels a leading batch axis so they pair with coords' (1, N, 2).
    labels_t = torch.tensor(point_labels, dtype=torch.float32).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(coords_t, labels_t), boxes=None, masks=None
        )
        image_embed = features["image_embed"][-1].unsqueeze(0)
        image_pe = predictor.model.sam_prompt_encoder.get_dense_pe()

        # High-resolution features are optional; fall back to None if absent/odd.
        high_res_features = None
        if "high_res_feats" in features and features["high_res_feats"]:
            try:
                high_res_features = [level[-1].unsqueeze(0) for level in features["high_res_feats"]]
            except IndexError:
                print("Warning: Index error accessing high_res_feats. Proceeding without them.")
            except Exception as e:
                print(f"Warning: Error processing high_res_features: {e}. Proceeding without them.")

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=image_embed,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=False,  # single image, so no per-prompt repetition
            high_res_features=high_res_features,
        )
        # Upsample logits to the image size the predictor recorded (presumably
        # set by set_image; 1024x1024 for images prepared by process_upload).
        masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
        best_idx = torch.argmax(prd_scores[0]).item()
        best_prob = torch.sigmoid(masks[:, best_idx])
        return (best_prob > 0.5).cpu().numpy().squeeze()


def run_segmentation(preprocessed_image_state, original_image_state, model_name, points_state):
    """Gradio handler: segment the preprocessed image with SAM2 using the clicked points.

    Args:
        preprocessed_image_state: preprocessed 1024x1024 uint8 image fed to SAM2.
        original_image_state: the uploaded image. Only its height and width are
            used, to resize the mask back to the upload's resolution.
        model_name: key of ``models_available`` chosen in the dropdown.
        points_state: list of ``[x, y, label]`` in preprocessed-image coordinates.

    Returns:
        ``(mask, points_state)``. ``mask`` is a uint8 RGB image (255 = artery,
        0 = background) at the original resolution, or None on failure.
        ``points_state`` is returned unchanged.
    """
    start_total_time = time.time()
    if preprocessed_image_state is None or original_image_state is None:
        gr.Warning("Please upload an image first.")
        return None, points_state

    print("\n--- Running Segmentation ---")
    print(f" Model Selected: {model_name}")
    print(f" Number of points: {len(points_state)}")

    # --- 1. Load model (cached after the first successful load) ---
    predictor = load_model(model_name)
    if predictor is None:
        # gr.Error is an exception class: constructing it without `raise`
        # showed nothing. Surface the failure as a toast, like the other handlers.
        gr.Warning(f"Failed to load model '{model_name}'. Check logs and paths.")
        return None, points_state

    # --- 2. Compute image features on the already-preprocessed image ---
    original_h, original_w = original_image_state.shape[:2]
    print(" Using preprocessed image (1024x1024) for predictor.")
    print(f" Original image size for final mask resize: {original_w}x{original_h}")
    print(" Setting preprocessed image in predictor...")
    start_set_image = time.time()
    predictor.set_image(preprocessed_image_state)
    print(f" predictor.set_image took {time.time() - start_set_image:.2f}s")

    # --- 3. Prompts are already in predictor-image coordinates (no scaling) ---
    point_coords, point_labels = _build_point_prompts(points_state, preprocessed_image_state.shape)

    # --- 4. Run the prompt encoder + mask decoder ---
    print(" Running model inference...")
    start_inference_time = time.time()
    binary_mask = _predict_best_mask(predictor, point_coords, point_labels)
    if binary_mask is None:
        gr.Warning("Image features not computed. Predictor might not have been set correctly.")
        return None, points_state
    print(f" Model inference took {time.time() - start_inference_time:.2f}s")

    # --- 5. Back to the original resolution; nearest-neighbour keeps it binary ---
    print(" Resizing mask to original dimensions...")
    final_mask = cv2.resize(
        binary_mask.astype(np.uint8), (original_w, original_h), interpolation=cv2.INTER_NEAREST
    )

    # --- 6. 0/255 RGB image for display and PNG download ---
    output_mask_display = (final_mask * 255).astype(np.uint8)
    if output_mask_display.ndim == 2:
        output_mask_display = cv2.cvtColor(output_mask_display, cv2.COLOR_GRAY2RGB)

    total_time = time.time() - start_total_time
    print(f"--- Segmentation Complete (Total time: {total_time:.2f}s) ---")
    return output_mask_display, points_state
def process_upload(uploaded_image):
    """Handle an upload: resize to 1024x1024, preprocess, and reset the session state.

    Returns a 6-tuple that matches the ``outputs`` wired to this handler:
    (preprocessed display image, preprocessed image state, original image state,
    points state, mask display, point-count Markdown). Every path returns all six
    values. A None upload or a processing failure resets them all to empty.
    """
    def _empty_outputs():
        # A fresh, empty points list on every call.
        no_points = []
        return None, None, None, no_points, None, get_point_counts_text(no_points)

    if uploaded_image is None:
        return _empty_outputs()

    print("Image uploaded. Processing...")
    # Keep the untouched upload; its size is needed to resize the final mask.
    original_image = uploaded_image.copy()

    image_resized_1024 = resize_image_fixed(original_image, 1024)
    if image_resized_1024 is None:
        # gr.Error is an exception class; constructing it without `raise`
        # displays nothing, so show a warning toast instead.
        gr.Warning("Failed to resize image.")
        return _empty_outputs()

    preprocessed_1024 = preprocess_image_for_sam2(image_resized_1024)
    if preprocessed_1024 is None:
        gr.Warning("Image preprocessing failed.")
        return _empty_outputs()

    # The display component expects RGB.
    if preprocessed_1024.ndim == 2:
        preprocessed_1024_display = cv2.cvtColor(preprocessed_1024, cv2.COLOR_GRAY2RGB)
    else:
        preprocessed_1024_display = preprocessed_1024.copy()

    print("Image processed successfully.")
    points_state = []  # A new image invalidates any previous points.
    return (
        preprocessed_1024_display,
        preprocessed_1024,
        original_image,
        points_state,
        None,
        get_point_counts_text(points_state),
    )
def clear_all_outputs():
    """Reset every wired field: upload, preprocessed display/state, points, mask, counts.

    Returns a 6-tuple in the order of the Clear All button's outputs.
    """
    print("Clearing all inputs and outputs.")
    empty_points = []
    counts_markdown = get_point_counts_text(empty_points)
    return None, None, None, empty_points, None, counts_markdown
# --- Build Gradio Interface ---
# Custom CSS for the Blocks layout. It fixes the heights of the three image
# panels (selected via their elem_id values) and shows a crosshair cursor over
# the clickable preprocessed image. It also sets minimum widths for the columns
# tagged "control-col" / "output-col" via elem_classes.
css = """
#mask_display_container .gradio-image { height: 450px !important; object-fit: contain; }
#preprocessed_image_container .gradio-image { height: 450px !important; object-fit: contain; cursor: crosshair !important; }
#upload_container .gradio-image { height: 150px !important; object-fit: contain; } /* Smaller upload preview */
.output-col img { max-height: 450px; object-fit: contain; }
.control-col { min-width: 500px; } /* Wider control column */
.output-col { min-width: 500px; }
"""
with gr.Blocks(css=css, title="Coronary Artery Segmentation (Fine-tuned SAM2)") as demo:
    gr.Markdown("# Coronary Artery Segmentation using Fine-tuned SAM2")
    gr.Markdown(
        "**Let's find those arteries!**\n\n"
        "1. Upload your Coronary X-ray Image.\n"
        "2. The preprocessed image appears on the left. Time to guide the AI! Click directly on the image to add **Positive** (artery) or **Negative** (background) points.\n"
        "3. Choose your fine-tuned SAM2 model.\n"
        "4. Hit 'Run Segmentation' and watch the magic happen!\n"
        "5. Download your predicted mask (the white area) using the download button on the mask image."
    )

    # --- States ---
    # Prompt points as [x, y, label] in preprocessed-image pixel coordinates.
    points_state = gr.State([])
    # State to store the original uploaded image (needed for final mask resizing)
    original_image_state = gr.State(None)
    # State to store the preprocessed 1024x1024 image data (used for drawing points and predictor input)
    preprocessed_image_state = gr.State(None)

    with gr.Row():
        # --- Left Column (Controls & Preprocessed Image Interaction) ---
        with gr.Column(scale=1, elem_classes="control-col"):
            gr.Markdown("## 1. Upload & Controls")
            # Keep upload separate and smaller
            upload_image = gr.Image(
                type="numpy", label="Upload Coronary X-ray Image",
                height=150, elem_id="upload_container"
            )
            gr.Markdown("## 2. Add Points on Preprocessed Image")
            # Interactive preprocessed image: clicks on it become prompt points.
            preprocessed_image_display = gr.Image(
                type="numpy", label="Click on Image to Add Points",
                interactive=True,
                height=450, elem_id="preprocessed_image_container"
            )
            # Point counter display
            point_counter_display = gr.Markdown(get_point_counts_text([]))
            # Defaults to the last successfully downloaded variant, if any.
            model_selector = gr.Dropdown(
                choices=list(models_available.keys()),
                label="Select SAM2 Model Variant",
                value=list(models_available.keys())[-1] if models_available else None
            )
            prompt_type = gr.Radio(
                ["Positive", "Negative"], label="Point Prompt Type", value="Positive"
            )
            with gr.Row():
                clear_button = gr.Button("Clear Points")
                undo_button = gr.Button("Undo Last Point")
            run_button = gr.Button("Run Segmentation", variant="primary")
            clear_all_button = gr.Button("Clear All")

        # --- Right Column (Output Mask) ---
        with gr.Column(scale=1, elem_classes="output-col"):
            gr.Markdown("## 3. Predicted Mask")
            final_mask_display = gr.Image(
                type="numpy", label="Predicted Binary Mask (White = Artery)",
                interactive=False, height=450, elem_id="mask_display_container",
                format="png"  # PNG keeps the binary mask lossless on download
            )

    # --- Define Interactions ---
    upload_outputs = [
        preprocessed_image_display,  # Update interactive display
        preprocessed_image_state,    # Update state
        original_image_state,        # Update state
        points_state,                # Clear points
        final_mask_display,          # Clear mask display
        point_counter_display,       # Clear point counts
    ]

    # 1. Upload triggers preprocessing and display.
    upload_image.upload(
        fn=process_upload,
        inputs=[upload_image],
        outputs=upload_outputs,
    )

    # 2. Removing the upload (the widget's clear button) fires `clear`, not
    #    `upload`. Route it through process_upload, whose None branch resets
    #    every field, so no stale image/points state survives.
    upload_image.clear(
        fn=process_upload,
        inputs=[upload_image],
        outputs=upload_outputs,
    )

    # 3. Clicking on the preprocessed image adds a point.
    preprocessed_image_display.select(
        fn=add_point,
        inputs=[preprocessed_image_state, points_state, prompt_type],
        outputs=[
            preprocessed_image_display,  # Update display with points
            points_state,                # Update points state
            point_counter_display,       # Update point counts
        ]
    )

    # 4. Clear Points resets the points and the preprocessed display.
    clear_button.click(
        fn=clear_points_and_display,
        inputs=[preprocessed_image_state],  # Needs the clean preprocessed image
        outputs=[
            preprocessed_image_display,  # Reset display
            points_state,                # Clear points
            final_mask_display,          # Clear mask
            point_counter_display,       # Reset point counts
        ]
    )

    # 5. Undo removes the last point and redraws the preprocessed display.
    undo_button.click(
        fn=undo_last_point,
        inputs=[preprocessed_image_state, points_state],  # Needs the clean preprocessed image
        outputs=[
            preprocessed_image_display,  # Update display
            points_state,                # Update points state
            point_counter_display,       # Update point counts
        ]
    )

    # 6. Run segmentation (point counts are unaffected).
    run_button.click(
        fn=run_segmentation,
        inputs=[
            preprocessed_image_state,  # Use preprocessed image data
            original_image_state,      # Needed for final resize dim
            model_selector,
            points_state,              # Points are relative to preprocessed
        ],
        outputs=[
            final_mask_display,  # Show the final mask
            points_state,        # Returned unchanged by run_segmentation
        ]
    )

    # 7. Clear All button
    clear_all_button.click(
        fn=clear_all_outputs,
        inputs=[],
        outputs=[
            upload_image,
            preprocessed_image_display,
            preprocessed_image_state,
            points_state,
            final_mask_display,
            point_counter_display,  # Reset point counts
        ]
    )
# --- Launch the App ---
if __name__ == "__main__":
    # Entry point when run as a script: start the Gradio server for `demo`.
    print("Launching Gradio App...")
    demo.launch()