Chest X-ray Multi-Label Classifier (ViT with MAE Pre-training)
Task: Multi-label classification for 15 NIH Chest X-ray findings (14 diseases + No Finding)
📋 Model Overview
Two-stage training approach:
- Stage 1 - MAE Pre-training: Self-supervised learning on unlabeled chest X-rays (70 epochs)
- Stage 2 - Fine-tuning ViT: Supervised fine-tuning for multi-label classification (20 epochs)
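Stage 1 is built around random patch masking. Each image is cut into fixed-size patches, most of them are hidden, and only the visible ones pass through the ViT encoder. The decoder then reconstructs the hidden patches. Below is a minimal pure-Python sketch of the masking step. It assumes 16x16 patches and the 75% mask ratio from the original MAE recipe, and this card states neither value. `random_patch_mask` is a hypothetical helper and is not taken from the actual training code.

```python
import random


def random_patch_mask(num_patches, mask_ratio=0.75, seed=None):
    """Split patch indices into visible and masked sets, MAE-style.

    Assumption: mask_ratio=0.75 follows the original MAE recipe; the ratio
    actually used for this model is not stated in the model card.
    """
    rng = random.Random(seed)
    indices = list(range(num_patches))
    rng.shuffle(indices)                         # random permutation of patch indices
    num_keep = int(num_patches * (1 - mask_ratio))
    visible = sorted(indices[:num_keep])         # patches the encoder sees
    masked = sorted(indices[num_keep:])          # patches the decoder must reconstruct
    return visible, masked


# A 224x224 image cut into 16x16 patches (patch size is an assumption)
image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2    # 14 * 14 = 196
visible, masked = random_patch_mask(num_patches, mask_ratio=0.75, seed=0)
print(len(visible), len(masked))                 # 49 147
```

With a 75% ratio the encoder processes only 49 of 196 patches. This is the main reason MAE pre-training is cheap compared with running the full sequence.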
🏥 Dataset
Dataset: tta1301/nih-chest-xray-small
| Statistic | Value |
|---|---|
| Total images | >10,000 |
| Disease classes | 15 (14 diseases + No Finding) |
| Train/Val/Test split | 70/15/15 (%) |
| Image size | 224x224 pixels |
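In practice, a 70/15/15 split comes down to shuffling once with a fixed seed and then cutting the list at the 70% and 85% marks. The sketch below works on a list of image IDs. `split_ids` and the seed value are hypothetical. The card does not say whether the split was made per image or per patient. NIH chest X-ray data contains several images of the same patient, so a per-image split like this one can place one patient in both the training and test sets. Grouping by patient ID first avoids that.

```python
import random


def split_ids(ids, train=0.70, val=0.15, seed=42):
    """Shuffle IDs reproducibly and cut them into train/val/test lists.

    The test list receives whatever remains after train and val (~15%).
    """
    ids = list(ids)
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * train)
    n_val = round(len(ids) * val)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]


train_ids, val_ids, test_ids = split_ids(range(10_000))
print(len(train_ids), len(val_ids), len(test_ids))  # 7000 1500 1500
```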
Disease Classes (15 classes)
| Index | Disease (English) | Disease (Vietnamese) |
|---|---|---|
| 0 | No Finding | Không phát hiện bất thường |
| 1 | Atelectasis | Xẹp phổi |
| 2 | Cardiomegaly | Tim to |
| 3 | Effusion | Tràn dịch màng phổi |
| 4 | Infiltration | Thâm nhiễm phổi |
| 5 | Mass | Khối u phổi |
| 6 | Nodule | Nốt phổi |
| 7 | Pneumonia | Viêm phổi |
| 8 | Pneumothorax | Tràn khí màng phổi |
| 9 | Consolidation | Đông đặc phổi |
| 10 | Edema | Phù phổi |
| 11 | Emphysema | Khí phế thũng |
| 12 | Fibrosis | Xơ phổi |
| 13 | Pleural_Thickening | Dày màng phổi |
| 14 | Hernia | Thoát vị hoành |
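Multi-label training needs a 15-dimensional 0/1 target vector for each image, with indices ordered as in the table above. The sketch below assumes the labels arrive as a pipe-separated string, which is how the original NIH `Data_Entry_2017.csv` stores its "Finding Labels" column. This card does not confirm that tta1301/nih-chest-xray-small uses the same format. `encode_findings` is a hypothetical helper.

```python
DISEASES = [
    "No Finding", "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax", "Consolidation",
    "Edema", "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia",
]
LABEL_TO_INDEX = {name: i for i, name in enumerate(DISEASES)}


def encode_findings(finding_labels):
    """Turn a pipe-separated label string (assumed format) into a multi-hot list."""
    target = [0] * len(DISEASES)
    for name in finding_labels.split("|"):
        target[LABEL_TO_INDEX[name.strip()]] = 1   # KeyError on unknown labels
    return target


print(encode_findings("Effusion|Mass"))
# [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```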
🚀 Training Results
Stage 1 - MAE Pre-training (70 epochs, numbered 0-69)
| Epoch | Loss |
|---|---|
| 0 | 0.7709 |
| 20 | 0.3218 |
| 40 | 0.1987 |
| 69 | 0.1168 |
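The card does not name the Stage 1 loss. Standard MAE training uses the mean squared error between reconstructed and original pixels, computed per patch and averaged over the masked patches only. The sketch below implements that objective in pure Python under the assumption that this model followed the standard recipe. `masked_mse` is a hypothetical helper.

```python
def masked_mse(pred_patches, target_patches, mask):
    """Per-patch MSE, averaged over masked patches only (standard MAE objective).

    pred_patches / target_patches: lists of patches, each a list of pixel values.
    mask: one 0/1 flag per patch, where 1 means the patch was hidden from the encoder.
    """
    total, count = 0.0, 0
    for pred, target, m in zip(pred_patches, target_patches, mask):
        if not m:
            continue  # visible patches do not contribute to the loss
        total += sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
        count += 1
    return total / count if count else 0.0


pred = [[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]
target = [[0.0, 0.0], [0.0, 0.0], [0.5, 1.5]]
mask = [0, 1, 1]
print(masked_mse(pred, target, mask))  # 0.75
```

Because visible patches are excluded, the model cannot lower the loss by copying its input. It has to infer the hidden anatomy from the surrounding context.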
Stage 2 - Fine-tuning (20 epochs)
| Metric | Train | Validation | Test |
|---|---|---|---|
| Accuracy | 0.9306 | 0.9307 | 0.9025 |
| Micro F1 | 0.9254 | 0.9213 | 0.8932 |
| Macro F1 | 0.8912 | 0.8876 | 0.8567 |
| ROC-AUC | 0.9789 | 0.9754 | 0.9612 |
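Micro and macro F1 weight the 15 classes differently. Micro F1 pools the true-positive, false-positive, and false-negative counts across all classes, so frequent findings dominate it. Macro F1 is the unweighted mean of the per-class F1 scores, so rare findings count just as much as common ones. Both depend on the decision threshold applied to the sigmoid outputs, and ROC-AUC does not. The pure-Python sketch below shows the two averages. `multilabel_f1` is a hypothetical helper. `sklearn.metrics.f1_score` with `average="micro"` or `average="macro"` computes the same quantities.

```python
def f1(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); defined as 0.0 when there are no counts."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def multilabel_f1(y_true, y_pred):
    """Return (micro_f1, macro_f1, per_class_f1) for 0/1 label matrices."""
    num_classes = len(y_true[0])
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for true_row, pred_row in zip(y_true, y_pred):
        for c, (t, p) in enumerate(zip(true_row, pred_row)):
            if t and p:
                tp[c] += 1
            elif p:
                fp[c] += 1
            elif t:
                fn[c] += 1
    per_class = [f1(tp[c], fp[c], fn[c]) for c in range(num_classes)]
    micro = f1(sum(tp), sum(fp), sum(fn))   # pool counts, then compute F1
    macro = sum(per_class) / num_classes    # compute F1 per class, then average
    return micro, macro, per_class


y_true = [[1, 0, 1], [0, 1, 0], [1, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 1], [1, 0, 0]]
micro, macro, per_class = multilabel_f1(y_true, y_pred)
print(round(micro, 4), round(macro, 4))  # 0.6667 0.5556
```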
Per-class F1 Score (Test)
| Disease | F1 |
|---|---|
| No Finding | 0.950 |
| Hernia | 0.945 |
| Pneumothorax | 0.912 |
| Cardiomegaly | 0.903 |
| Edema | 0.894 |
| Pneumonia | 0.892 |
| Effusion | 0.885 |
| Mass | 0.876 |
| Consolidation | 0.873 |
| Atelectasis | 0.859 |
| Emphysema | 0.854 |
| Nodule | 0.843 |
| Fibrosis | 0.833 |
| Pleural_Thickening | 0.823 |
| Infiltration | 0.812 |
💻 Usage
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image

# Load the image processor and the fine-tuned ViT classifier from the Hub
processor = AutoImageProcessor.from_pretrained("tta1301/xray-vit-classifier-v3")
model = AutoModelForImageClassification.from_pretrained("tta1301/xray-vit-classifier-v3")
model.eval()

# Class names in the index order of the Disease Classes table
DISEASES = [
    'No Finding',
    'Atelectasis',
    'Cardiomegaly',
    'Effusion',
    'Infiltration',
    'Mass',
    'Nodule',
    'Pneumonia',
    'Pneumothorax',
    'Consolidation',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Pleural_Thickening',
    'Hernia'
]

def predict_chest_xray(image_path, threshold=0.3):
    """Return {finding: probability} for every class above `threshold`, highest first."""
    # X-rays are often single-channel; convert to the 3-channel RGB input the processor expects
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Multi-label: an independent sigmoid per class, not a softmax across classes
    probs = torch.sigmoid(outputs.logits)[0]
    results = {DISEASES[i]: float(probs[i])
               for i in range(len(DISEASES)) if probs[i] > threshold}
    return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))

# "chest_xray.jpg" is a placeholder path to a local image
result = predict_chest_xray("chest_xray.jpg")
print(result)
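The classes are not mutually exclusive, so each logit goes through its own sigmoid, and several findings can clear the threshold at the same time. The thresholding step can be checked without downloading the model or installing torch. `logits_to_findings` below is a hypothetical helper that reproduces the post-processing in `predict_chest_xray` on plain Python lists. It uses a numerically stable sigmoid so that very large negative logits do not overflow `math.exp`.

```python
import math


def _sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_findings(logits, labels, threshold=0.3):
    """Apply a per-class sigmoid, keep labels above threshold, sort by probability."""
    hits = {label: _sigmoid(z) for label, z in zip(labels, logits) if _sigmoid(z) > threshold}
    return dict(sorted(hits.items(), key=lambda kv: kv[1], reverse=True))


labels = ["No Finding", "Effusion", "Mass"]
print(logits_to_findings([-2.0, 1.5, 0.0], labels))  # keys: Effusion, Mass
```

Raising `threshold` trades recall for precision. The 0.3 default here matches the usage example, but the card does not state which threshold produced the reported F1 scores.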