ACDC Heart Segmentation: Attention U-Net Ensemble

A 5-fold cross-validated Attention U-Net ensemble trained on the ACDC cardiac MRI dataset for multi-class segmentation of cardiac structures.

Model Description

This model segments cardiac MRI short-axis slices into 4 classes:

  • Class 0: Background
  • Class 1: Right ventricular cavity (RV)
  • Class 2: Left ventricular myocardium (LVM)
  • Class 3: Left ventricular cavity (LVC)
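
For per-structure analysis, the predicted label map can be split into one binary mask per foreground class. The following is a minimal NumPy sketch. CLASS_NAMES and split_structures are hypothetical names that follow the index order listed above.

```python
import numpy as np

# Label indices in the order listed above; names are illustrative abbreviations
CLASS_NAMES = {0: "Background", 1: "RV", 2: "LVM", 3: "LVC"}


def split_structures(pred: np.ndarray) -> dict:
    """Turn an integer label map (H, W) into one boolean mask per foreground class."""
    return {name: pred == idx for idx, name in CLASS_NAMES.items() if idx != 0}


# Toy 2x3 label map, for illustration only
pred = np.array([[0, 1, 1],
                 [2, 3, 0]])
masks = split_structures(pred)
print({name: int(m.sum()) for name, m in masks.items()})
# {'RV': 2, 'LVM': 1, 'LVC': 1}
```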

Architecture

  • Base: U-Net with Attention Gates
  • Input: Single-channel grayscale MRI (256x256)
  • Output: 4-class segmentation map
  • Training: 5-fold cross-validation on the ACDC training set
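
The component that distinguishes an Attention U-Net from a plain U-Net is the attention gate on each skip connection. The following is a minimal PyTorch sketch of the standard additive gate. The class name AttentionGate and the parameters f_g, f_l and f_int are hypothetical and may not match the implementation in model.py.

```python
import torch
import torch.nn as nn


class AttentionGate(nn.Module):
    """Additive attention gate: re-weights skip features x using a gating signal g."""

    def __init__(self, f_g: int, f_l: int, f_int: int):
        super().__init__()
        # 1x1 convolutions project both inputs to a shared intermediate width
        self.w_g = nn.Sequential(nn.Conv2d(f_g, f_int, kernel_size=1), nn.BatchNorm2d(f_int))
        self.w_x = nn.Sequential(nn.Conv2d(f_l, f_int, kernel_size=1), nn.BatchNorm2d(f_int))
        # Collapse to one channel and squash to (0, 1) to get per-pixel attention weights
        self.psi = nn.Sequential(nn.Conv2d(f_int, 1, kernel_size=1), nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Assumes g was already upsampled to the spatial size of x by the decoder
        alpha = self.psi(self.relu(self.w_g(g) + self.w_x(x)))  # [B, 1, H, W]
        return x * alpha  # broadcast over channels to suppress irrelevant regions


gate = AttentionGate(f_g=64, f_l=64, f_int=32).eval()
g = torch.randn(1, 64, 32, 32)  # decoder (gating) features
x = torch.randn(1, 64, 32, 32)  # encoder skip features at the same resolution
print(gate(g, x).shape)  # torch.Size([1, 64, 32, 32])
```

Because the weights lie in (0, 1), the gate can only attenuate skip features, never amplify them. The decoder then concatenates the gated skip features with its upsampled features, as in a plain U-Net.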

Usage

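import torch
from model import AttentionUNet  # model.py from this repository

model = AttentionUNet(img_ch=1, output_ch=4)

# Loads a single fold; the ensemble uses all five fold checkpoints.
# weights_only=False unpickles arbitrary objects, so only load files you trust.
checkpoint = torch.load("fold_1_model.pth", map_location="cpu", weights_only=False)
# Some checkpoints wrap the weights under a 'model_state_dict' key
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    state_dict = checkpoint["model_state_dict"]
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()

# Input: [batch, 1, 256, 256], intensities normalized as (x - 0.5) / 0.5
# Random placeholder; replace with a preprocessed MRI slice
img_tensor = torch.randn(1, 1, 256, 256)
with torch.no_grad():
    output = model(img_tensor)          # logits: [batch, 4, 256, 256]
    pred = torch.argmax(output, dim=1)  # label map: [batch, 256, 256]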
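
The snippet above feeds random noise to the model. The following sketch turns a raw 2D slice into a model input. It assumes per-slice min-max scaling to [0, 1], because the card states only the (x - 0.5) / 0.5 step. It also assumes bilinear resizing to 256x256, since the resize or crop used in training is not stated. preprocess_slice is a hypothetical helper.

```python
import numpy as np
import torch
import torch.nn.functional as F


def preprocess_slice(slice_2d: np.ndarray, size: int = 256) -> torch.Tensor:
    """Scale a 2D slice to [0, 1], resize to size x size, then apply (x - 0.5) / 0.5."""
    x = slice_2d.astype(np.float32)
    # Per-slice min-max scaling (assumption); epsilon avoids division by zero on flat slices
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)
    t = torch.from_numpy(x)[None, None]  # add batch and channel dims: [1, 1, H, W]
    t = F.interpolate(t, size=(size, size), mode="bilinear", align_corners=False)
    return (t - 0.5) / 0.5  # maps [0, 1] to [-1, 1]
```

If predictions must be mapped back to the original resolution, resize the label map with nearest-neighbor interpolation so that class indices are not blended.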

Files

File                                   Description
model.py                               Model architecture (AttentionUNet)
fold_1_model.pth to fold_5_model.pth   Trained weights, one file per CV fold
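
The usage snippet loads only one fold. One common way to combine the five fold models is to average their softmax probabilities and take the per-pixel argmax. The card does not say how its ensemble combines folds, so this is a sketch under that assumption. load_fold_models and ensemble_predict are hypothetical helpers.

```python
import torch
import torch.nn as nn


def load_fold_models(model_cls, paths):
    """Build one model per checkpoint path and load its weights (CPU, eval mode)."""
    models = []
    for path in paths:
        model = model_cls(img_ch=1, output_ch=4)
        # weights_only=False mirrors the usage snippet; only load trusted files
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            ckpt = ckpt["model_state_dict"]
        model.load_state_dict(ckpt)
        models.append(model.eval())
    return models


@torch.no_grad()
def ensemble_predict(models, img: torch.Tensor) -> torch.Tensor:
    """Average per-fold softmax probabilities, then take the per-pixel argmax."""
    probs = torch.stack([torch.softmax(m(img), dim=1) for m in models])  # [K, B, C, H, W]
    return probs.mean(dim=0).argmax(dim=1)  # [B, H, W]


# Hypothetical usage with the files listed above:
# from model import AttentionUNet
# models = load_fold_models(AttentionUNet, [f"fold_{k}_model.pth" for k in range(1, 6)])
# pred = ensemble_predict(models, img_tensor)
```

Averaging probabilities keeps each fold's confidence. A per-pixel majority vote over the fold label maps is a simpler alternative.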

Training Details

  • Dataset: ACDC (Automated Cardiac Diagnosis Challenge)
  • Optimizer: Adam
  • Loss: Cross-Entropy + Dice Loss
  • Image Size: 256x256
  • Normalization: (pixel - 0.5) / 0.5
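
The card lists Cross-Entropy + Dice Loss without giving the weighting or the Dice variant. The following is a minimal sketch that assumes equal weighting and a soft Dice averaged over all four classes, including background. CEDiceLoss and its parameters are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CEDiceLoss(nn.Module):
    """Cross-entropy plus (1 - mean soft Dice over classes); equal weighting is an assumption."""

    def __init__(self, num_classes: int = 4, dice_weight: float = 1.0, eps: float = 1e-6):
        super().__init__()
        self.num_classes = num_classes
        self.dice_weight = dice_weight
        self.eps = eps

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: [B, C, H, W]; target: [B, H, W] int64 labels in [0, C)
        ce = F.cross_entropy(logits, target)
        probs = torch.softmax(logits, dim=1)
        one_hot = F.one_hot(target, self.num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)  # sum over batch and spatial dims, keep one score per class
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.eps) / (cardinality + self.eps)
        return ce + self.dice_weight * (1 - dice.mean())
```

Cross-entropy gives stable per-pixel gradients, while the Dice term directly targets region overlap and is less dominated by the large background class.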

Model repository: MohidAbdullah/ACDC-Heart-Segmentation (used by 1 Hugging Face Space)