Chest X-ray Multi-Label Classifier (ViT with MAE Pre-training)

Task: Multi-label classification for 15 NIH Chest X-ray findings (14 diseases + No Finding)

📋 Model Overview

Two-stage training approach:

  • Stage 1 - MAE Pre-training: Self-supervised learning on unlabeled chest X-rays (70 epochs)
  • Stage 2 - Fine-tuning ViT: Supervised fine-tuning for multi-label classification (20 epochs)
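
The masking step behind Stage 1 can be sketched as follows. This is a minimal illustration rather than the training code used for this model: the card does not state the patch size or mask ratio, so the 16x16 patches and 75% mask ratio below are assumptions taken from common MAE defaults, and the helper names `random_patch_mask` and `masked_reconstruction_loss` are hypothetical.

```python
import numpy as np

def random_patch_mask(image, patch_size=16, mask_ratio=0.75, seed=0):
    """Split a grayscale image into patches and hide a random subset.

    patch_size and mask_ratio are assumed values, not taken from the model card.
    """
    h, w = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    # (gh, ps, gw, ps) -> (gh, gw, ps, ps) -> (num_patches, ps * ps)
    patches = (image.reshape(gh, patch_size, gw, patch_size)
                    .transpose(0, 2, 1, 3)
                    .reshape(gh * gw, -1))
    num_patches = gh * gw
    num_keep = int(num_patches * (1 - mask_ratio))
    order = np.random.default_rng(seed).permutation(num_patches)
    keep_idx = np.sort(order[:num_keep])   # patches the encoder sees
    mask_idx = np.sort(order[num_keep:])   # patches the decoder must reconstruct
    return patches[keep_idx], keep_idx, mask_idx, patches

def masked_reconstruction_loss(pred_patches, target_patches, mask_idx):
    """Mean squared error computed on the masked patches only."""
    diff = pred_patches[mask_idx] - target_patches[mask_idx]
    return float(np.mean(diff ** 2))

# Stand-in for a 224x224 chest X-ray (random pixels, for illustration only).
image = np.random.default_rng(1).random((224, 224)).astype(np.float32)
visible, keep_idx, mask_idx, patches = random_patch_mask(image)
print(visible.shape, len(mask_idx))  # (49, 256) 147
```

Under these assumptions a 224x224 image yields 196 patches, of which the encoder sees 49 and the remaining 147 contribute to the reconstruction loss.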

🏥 Dataset

Dataset: tta1301/nih-chest-xray-small

| Statistic | Value |
|---|---|
| Total images | >10,000 |
| Disease classes | 15 (14 diseases + No Finding) |
| Train/Val/Test split | 70/15/15 (%) |
| Image size | 224x224 px |
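
A 70/15/15 split can be sketched as below. The card does not say how the split was produced, so the random image-level shuffle, the seed, and the helper name `split_indices` are assumptions; this sketch also does not group images by patient, which matters if one patient contributes several images.

```python
import random

def split_indices(n, fractions=(0.70, 0.15, 0.15), seed=42):
    """Shuffle n sample indices and cut them into train/val/test lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_train = round(n * fractions[0])
    n_val = round(n * fractions[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains goes to the test set
    return train, val, test

# 10_000 is a placeholder size; the card only states ">10,000" images.
train, val, test = split_indices(10_000)
print(len(train), len(val), len(test))  # 7000 1500 1500
```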

Disease Classes (15 classes)

| Index | Disease (English) | Disease (Vietnamese) |
|---|---|---|
| 0 | No Finding | Không phát hiện bất thường |
| 1 | Atelectasis | Xẹp phổi |
| 2 | Cardiomegaly | Tim to |
| 3 | Effusion | Tràn dịch màng phổi |
| 4 | Infiltration | Thâm nhiễm phổi |
| 5 | Mass | Khối u phổi |
| 6 | Nodule | Nốt phổi |
| 7 | Pneumonia | Viêm phổi |
| 8 | Pneumothorax | Tràn khí màng phổi |
| 9 | Consolidation | Đông đặc phổi |
| 10 | Edema | Phù phổi |
| 11 | Emphysema | Khí phế thũng |
| 12 | Fibrosis | Xơ phổi |
| 13 | Pleural_Thickening | Dày màng phổi |
| 14 | Hernia | Thoát vị hoành |

🚀 Training Results

Stage 1 - MAE Pre-training (70 epochs)

| Epoch | Loss |
|---|---|
| 0 | 0.7709 |
| 20 | 0.3218 |
| 40 | 0.1987 |
| 69 | 0.1168 |

Stage 2 - Fine-tuning (20 epochs)

| Metric | Train | Validation | Test |
|---|---|---|---|
| Accuracy | 0.9306 | 0.9307 | 0.9025 |
| Micro F1 | 0.9254 | 0.9213 | 0.8932 |
| Macro F1 | 0.8912 | 0.8876 | 0.8567 |
| ROC-AUC | 0.9789 | 0.9754 | 0.9612 |
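
The difference between the micro and macro F1 rows can be illustrated with a small sketch. The labels below are toy values made up for the example, not model outputs, and `micro_macro_f1` is a hypothetical helper: micro F1 pools true/false positives over all classes, while macro F1 averages the per-class scores, so rare classes weigh more heavily in the macro figure.

```python
def f1_from_counts(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); defined as 0.0 when there is nothing to score."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def micro_macro_f1(y_true, y_pred):
    """y_true and y_pred are lists of per-sample 0/1 label vectors."""
    n_classes = len(y_true[0])
    counts = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 1 and p[c] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 0 and p[c] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 1 and p[c] == 0)
        counts.append((tp, fp, fn))
    # Macro: average the per-class F1 scores.
    macro = sum(f1_from_counts(*c) for c in counts) / n_classes
    # Micro: pool the counts across classes, then compute one F1.
    micro = f1_from_counts(*(sum(col) for col in zip(*counts)))
    return micro, macro

# Toy example: class 0 is common and always right, classes 1 and 2 are rare and missed.
y_true = [[1, 0, 0], [1, 0, 0], [1, 1, 0], [0, 0, 1]]
y_pred = [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 0]]
micro, macro = micro_macro_f1(y_true, y_pred)
print(f"micro={micro:.3f} macro={macro:.3f}")  # micro=0.750 macro=0.333
```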

Per-class F1 Score (Test)

| Disease | F1 |
|---|---|
| No Finding | 0.95 |
| Hernia | 0.945 |
| Pneumothorax | 0.912 |
| Cardiomegaly | 0.903 |
| Edema | 0.894 |
| Pneumonia | 0.892 |
| Effusion | 0.885 |
| Mass | 0.876 |
| Consolidation | 0.873 |
| Atelectasis | 0.859 |
| Emphysema | 0.854 |
| Nodule | 0.843 |
| Fibrosis | 0.833 |
| Pleural_Thickening | 0.823 |
| Infiltration | 0.812 |

💻 Usage

from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image

# Load model
processor = AutoImageProcessor.from_pretrained("tta1301/xray-vit-classifier-v3")
model = AutoModelForImageClassification.from_pretrained("tta1301/xray-vit-classifier-v3")
model.eval()

# Disease labels (updated order with No Finding)
DISEASES = [
    'No Finding',           # 0
    'Atelectasis',          # 1
    'Cardiomegaly',         # 2
    'Effusion',             # 3
    'Infiltration',         # 4
    'Mass',                 # 5
    'Nodule',               # 6
    'Pneumonia',            # 7
    'Pneumothorax',         # 8
    'Consolidation',        # 9
    'Edema',                # 10
    'Emphysema',            # 11
    'Fibrosis',             # 12
    'Pleural_Thickening',   # 13
    'Hernia'                # 14
]

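def predict_chest_xray(image_path, threshold=0.3):
    """Return {finding: probability} for findings scoring above `threshold`, highest first."""
    # X-rays are often stored as grayscale; convert to 3-channel RGB before preprocessing.
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Multi-label output: apply an independent sigmoid per class rather than a softmax.
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.sigmoid(outputs.logits)[0]

    # Keep only the classes above the threshold; the dict is empty if none qualify.
    results = {DISEASES[i]: float(probs[i])
               for i in range(len(DISEASES)) if probs[i] > threshold}
    return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))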

# Example
result = predict_chest_xray("chest_xray.jpg")
print(result)
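
Because `predict_chest_xray` returns an empty dict when no class exceeds the threshold, callers may want a small post-processing step. The sketch below is one possible approach, not part of the released model: `summarize_findings` is a hypothetical helper, the scores are made-up illustrative values, and hiding "No Finding" whenever another finding is present is a design assumption rather than documented behaviour.

```python
def summarize_findings(scores, no_finding_label="No Finding"):
    """Turn the dict returned by predict_chest_xray into a one-line summary."""
    if not scores:
        return "No finding exceeded the threshold."
    findings = [(name, p) for name, p in scores.items() if name != no_finding_label]
    if not findings:
        # Only "No Finding" passed the threshold.
        return f"{no_finding_label} ({scores[no_finding_label]:.2f})"
    # Assumption: when any disease passes the threshold, "No Finding" is dropped.
    return ", ".join(f"{name} ({p:.2f})" for name, p in findings)

# Illustrative inputs only; these are not real model outputs.
print(summarize_findings({}))
print(summarize_findings({"No Finding": 0.91}))
print(summarize_findings({"Effusion": 0.62, "No Finding": 0.35}))
```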

Model size: 85.8M params (Safetensors, tensor type F32)