ACDC Heart Segmentation: Attention U-Net Ensemble
An ensemble of five Attention U-Nets, one trained per fold of 5-fold cross-validation on the ACDC cardiac MRI dataset, for multi-class segmentation of cardiac structures.
Model Description
This model segments cardiac MRI short-axis slices into 4 classes:
- Class 0: Background
- Class 1: Right Ventricle cavity (RV)
- Class 2: Left Ventricular Myocardium (LVM)
- Class 3: Left Ventricle cavity (LVC)
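The label indices above can be mapped back to structure names when you inspect a prediction. The sketch below assumes the prediction is an integer NumPy array of class indices, such as `pred.cpu().numpy()` from the usage example. `CLASS_NAMES` and `class_pixel_counts` are hypothetical helpers and are not part of `model.py`.

```python
import numpy as np

# Hypothetical mapping built from the class list on this card
CLASS_NAMES = {0: "Background", 1: "RV", 2: "LVM", 3: "LVC"}


def class_pixel_counts(pred: np.ndarray) -> dict:
    """Count predicted pixels per structure in a [H, W] or [B, H, W] label map."""
    # bincount needs non-negative integers; minlength keeps absent classes at 0
    counts = np.bincount(pred.ravel().astype(np.int64), minlength=len(CLASS_NAMES))
    return {CLASS_NAMES[i]: int(counts[i]) for i in CLASS_NAMES}
```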
Architecture
- Base: U-Net with Attention Gates
- Input: Single-channel grayscale MRI (256x256)
- Output: 4-class segmentation map
- Training: 5-fold cross-validation on the ACDC training set
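Attention gates reweight the encoder features passed through each skip connection, using a gating signal from the decoder. The sketch below shows the additive formulation commonly used in Attention U-Nets. `AttentionGate`, its channel arguments, the batch-norm layers, and the assumption that `g` has already been upsampled to the size of `x` are all illustrative choices. They may not match the gate implemented in `model.py`.

```python
import torch
import torch.nn as nn


class AttentionGate(nn.Module):
    """Additive attention gate (hypothetical sketch; model.py may differ)."""

    def __init__(self, f_g: int, f_x: int, f_int: int):
        super().__init__()
        # 1x1 projections of the gating signal g (decoder) and skip features x (encoder)
        self.w_g = nn.Sequential(nn.Conv2d(f_g, f_int, kernel_size=1), nn.BatchNorm2d(f_int))
        self.w_x = nn.Sequential(nn.Conv2d(f_x, f_int, kernel_size=1), nn.BatchNorm2d(f_int))
        # Collapse to one attention coefficient per pixel, squashed into [0, 1]
        self.psi = nn.Sequential(nn.Conv2d(f_int, 1, kernel_size=1), nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Assumes g and x share spatial size (g already upsampled)
        alpha = self.psi(self.relu(self.w_g(g) + self.w_x(x)))  # [B, 1, H, W]
        # Scale skip features so that irrelevant regions are suppressed
        return x * alpha
```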
Usage
```python
import torch
from model import AttentionUNet

model = AttentionUNet(img_ch=1, output_ch=4)
state_dict = torch.load("fold_1_model.pth", map_location="cpu", weights_only=False)
# Checkpoints saved as a training dict keep the weights under 'model_state_dict'
if 'model_state_dict' in state_dict:
    state_dict = state_dict['model_state_dict']
model.load_state_dict(state_dict)
model.eval()

# Input: [batch, 1, 256, 256], normalized as (pixel - 0.5) / 0.5 (mean=0.5, std=0.5)
img_tensor = torch.randn(1, 1, 256, 256)  # placeholder; replace with a real slice
with torch.no_grad():
    output = model(img_tensor)  # per-class scores: [batch, 4, 256, 256]
    pred = torch.argmax(output, dim=1)  # class indices: [batch, 256, 256]
```
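To use a real slice instead of the random placeholder, you need to scale it, resize it, and apply the card's normalization. This sketch assumes a 2D NumPy slice, for example one slice of an ACDC volume. The following are assumptions, not taken from the training code:

- per-slice min-max scaling to [0, 1] before applying (pixel - 0.5) / 0.5
- bilinear resizing
- the name `preprocess_slice`

```python
import numpy as np
import torch
import torch.nn.functional as F


def preprocess_slice(slice_2d: np.ndarray, size: int = 256) -> torch.Tensor:
    """Turn one 2D MRI slice into a [1, 1, size, size] model input (hypothetical sketch)."""
    x = torch.as_tensor(slice_2d, dtype=torch.float32)
    # Assumption: per-slice min-max scaling to [0, 1]; constant slices become all zeros
    lo, hi = x.min(), x.max()
    x = (x - lo) / (hi - lo) if hi > lo else torch.zeros_like(x)
    # Resize to the 256x256 input size; add batch and channel dims first
    x = F.interpolate(x[None, None], size=(size, size), mode="bilinear", align_corners=False)
    # Normalization from this card: (pixel - 0.5) / 0.5 maps [0, 1] to [-1, 1]
    return (x - 0.5) / 0.5
```

The resulting `pred` is 256x256. To map it back to the original slice size, resize it with nearest-neighbour interpolation so the labels stay integer class indices.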
Files
| File | Description |
|---|---|
| `model.py` | Model architecture (`AttentionUNet`) |
| `fold_1_model.pth` to `fold_5_model.pth` | Trained weights, one per CV fold |
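The usage example loads only one fold. The five fold checkpoints can be combined into the ensemble by averaging per-pixel softmax probabilities across folds and then taking the argmax. This card does not say how the folds are combined, so soft voting is an assumption here. The helper names `load_fold_models` and `ensemble_predict` are hypothetical. The loading logic mirrors the single-fold example above.

```python
import torch
import torch.nn as nn


def load_fold_models(model_cls, paths):
    """Build one model per checkpoint path, mirroring the single-fold loading code."""
    models = []
    for path in paths:
        model = model_cls(img_ch=1, output_ch=4)
        state_dict = torch.load(path, map_location="cpu", weights_only=False)
        if "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        model.load_state_dict(state_dict)
        model.eval()
        models.append(model)
    return models


def ensemble_predict(models, img_tensor: torch.Tensor):
    """Soft-voting ensemble: mean softmax over models, then argmax per pixel."""
    if not models:
        raise ValueError("need at least one model")
    probs = None
    with torch.no_grad():
        for model in models:
            model.eval()
            p = torch.softmax(model(img_tensor), dim=1)  # [B, 4, H, W]
            probs = p if probs is None else probs + p
    probs = probs / len(models)
    return probs.argmax(dim=1), probs  # labels [B, H, W], mean probs [B, 4, H, W]
```

Assuming the checkpoints are in the working directory, a typical call is:

- `models = load_fold_models(AttentionUNet, [f"fold_{k}_model.pth" for k in range(1, 6)])`
- `pred, probs = ensemble_predict(models, img_tensor)`

Averaging probabilities rather than taking a majority vote over labels keeps each fold's confidence. It also avoids ties between folds.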
Training Details
- Dataset: ACDC (Automated Cardiac Diagnosis Challenge)
- Optimizer: Adam
- Loss: Cross-Entropy + Dice Loss
- Image Size: 256x256
- Normalization: (pixel - 0.5) / 0.5
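The combined training loss listed above can be sketched as a cross-entropy term plus a soft Dice term. The card does not give the relative weighting, the smoothing constant, or whether background is included in the Dice average. The following are therefore assumptions:

- the class name `CEDiceLoss`
- `dice_weight=1.0`
- `smooth=1e-5`
- averaging Dice over all four classes

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CEDiceLoss(nn.Module):
    """Cross-entropy + soft Dice loss (hypothetical sketch; weights and smoothing assumed)."""

    def __init__(self, num_classes: int = 4, dice_weight: float = 1.0, smooth: float = 1e-5):
        super().__init__()
        self.num_classes = num_classes
        self.dice_weight = dice_weight
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: [B, C, H, W] raw scores; target: [B, H, W] int64 class labels
        ce = F.cross_entropy(logits, target)
        probs = torch.softmax(logits, dim=1)
        onehot = F.one_hot(target, self.num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)  # sum over batch and spatial dims, keep one term per class
        inter = (probs * onehot).sum(dims)
        denom = probs.sum(dims) + onehot.sum(dims)
        dice = (2 * inter + self.smooth) / (denom + self.smooth)
        # Dice loss is 1 - mean per-class Dice coefficient
        return ce + self.dice_weight * (1 - dice.mean())
```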