ViT-Large for Alzheimer's Detection from Brain MRI Scans

Model Description

This model is a fine-tuned version of google/vit-large-patch16-224 for four-class classification of brain MRI scans by Alzheimer's disease stage.

Key Features:

  • Base Model: Vision Transformer Large (304M parameters)
  • Fine-tuning Strategy: only the last 6 transformer encoder layers and the classifier head are trained; earlier layers stay frozen
  • Class Imbalance Handling: Weighted cross-entropy loss
  • Data Augmentation: Rotation, flip, brightness/contrast adjustments
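The "75M" in the repository name appears to refer to the trainable budget left by this strategy. As a rough check, the sketch below counts the parameters in the last six encoder layers plus a four-way linear head. It assumes the standard ViT-Large/16 dimensions (hidden size 1024, MLP size 4096). It also assumes each layer has two LayerNorms, separate query/key/value projections, an attention output projection and a two-layer MLP, all with biases. The card does not say whether any other module (for example the final LayerNorm) was also unfrozen, so treat the total as an estimate.

```python
# Estimate trainable parameters when only the last N encoder layers and the
# classifier head of ViT-Large/16 are fine-tuned.
# ASSUMPTION: standard ViT-Large dimensions (hidden 1024, MLP 4096).

HIDDEN = 1024        # embedding width of ViT-Large
MLP = 4096           # intermediate (feed-forward) width
NUM_CLASSES = 4      # labels 0-3 in the model config


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus bias of a dense layer."""
    return n_in * n_out + n_out


def encoder_layer_params(hidden: int = HIDDEN, mlp: int = MLP) -> int:
    """Parameters in one Transformer encoder layer."""
    layernorms = 2 * (2 * hidden)                      # two LayerNorms (scale + shift)
    qkv = 3 * linear_params(hidden, hidden)            # query, key, value projections
    attn_out = linear_params(hidden, hidden)           # attention output projection
    ffn = linear_params(hidden, mlp) + linear_params(mlp, hidden)  # two-layer MLP
    return layernorms + qkv + attn_out + ffn


def trainable_params(num_layers: int = 6) -> int:
    """Last `num_layers` encoder layers plus the linear classifier head."""
    return num_layers * encoder_layer_params() + linear_params(HIDDEN, NUM_CLASSES)


if __name__ == "__main__":
    print(f"per layer: {encoder_layer_params():,}")   # per layer: 12,596,224
    print(f"trainable: {trainable_params(6):,}")      # trainable: 75,581,444
```

The estimate comes to about 75.6M trainable parameters, roughly a quarter of the 304M in the full model.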

Model Performance

  • Accuracy: 0.9375
  • Precision: 0.9374
  • Recall: 0.9375
  • F1 Score: 0.9373
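Recall equals accuracy to four decimals. Support-weighted ("weighted") averaging always produces this result. Weighting each class's recall by its share of samples simply sums the correct predictions over all samples. The card does not state which averaging was used, so this is a likely reading rather than a confirmed one. The sketch below checks the identity on a small confusion matrix with hypothetical counts that are not this model's results.

```python
# Support-weighted averaging of per-class metrics from a confusion matrix.
# ASSUMPTION: the matrix below is HYPOTHETICAL, chosen only for illustration.

from typing import List


def per_class_recall(cm: List[List[int]]) -> List[float]:
    """cm[i][j] = number of samples with true class i predicted as class j."""
    return [row[i] / sum(row) for i, row in enumerate(cm)]


def per_class_precision(cm: List[List[int]]) -> List[float]:
    col_sums = [sum(row[j] for row in cm) for j in range(len(cm))]
    return [cm[j][j] / col_sums[j] for j in range(len(cm))]


def weighted_average(values: List[float], support: List[int]) -> float:
    """Average per-class values, weighting each class by its true-sample count."""
    total = sum(support)
    return sum(v * s for v, s in zip(values, support)) / total


def accuracy(cm: List[List[int]]) -> float:
    return sum(cm[i][i] for i in range(len(cm))) / sum(map(sum, cm))


if __name__ == "__main__":
    cm = [
        [50, 8, 2, 0],
        [3, 10, 0, 0],
        [4, 1, 90, 5],
        [0, 0, 6, 44],
    ]
    support = [sum(row) for row in cm]
    print(f"accuracy           = {accuracy(cm):.4f}")
    print(f"weighted recall    = {weighted_average(per_class_recall(cm), support):.4f}")
    print(f"weighted precision = {weighted_average(per_class_precision(cm), support):.4f}")
```

Weighted recall matches accuracy for any confusion matrix. Weighted precision generally does not, which fits the reported precision (0.9374) differing slightly from accuracy.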

Per-Class F1 Scores

  • Class 0: 0.9108
  • Class 1: 1.0000
  • Class 2: 0.9536
  • Class 3: 0.9231
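The unweighted (macro) mean of the four per-class F1 scores can be computed directly from the list above. It comes out higher than the overall F1 of 0.9373. If the overall figure is support-weighted, that gap would mean the lower-scoring classes are the larger ones. The card does not give class sizes, so this is an inference.

```python
# Macro-averaged F1 from the per-class F1 scores reported on this card.
per_class_f1 = {0: 0.9108, 1: 1.0000, 2: 0.9536, 3: 0.9231}

# Every class counts equally, regardless of how many samples it has.
macro_f1 = sum(per_class_f1.values()) / len(per_class_f1)
print(f"macro F1 = {macro_f1:.4f}")  # macro F1 = 0.9469
```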

Training Details

  • Training Data: 4352 brain MRI scans
  • Validation Data: 768 brain MRI scans
  • Epochs: 25
  • Batch Size: 4 (effective: 16)
  • Learning Rate: 1e-05
  • Optimizer: AdamW with cosine learning rate schedule
  • Loss Function: Weighted Cross-Entropy (for class imbalance)
  • Training Time: 53.3 minutes
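The card does not say how the class weights for the loss were chosen. A common choice, assumed here, is inverse-frequency ("balanced") weighting, w_c = N / (K · n_c), where N is the number of samples, K the number of classes and n_c the count of class c. The sketch below computes such weights from hypothetical class counts. It then applies them in a pure-Python weighted cross-entropy normalized by the summed weights of the target classes. That is the same convention PyTorch's CrossEntropyLoss uses with a weight tensor and reduction='mean'.

```python
# Weighted cross-entropy with inverse-frequency class weights.
# ASSUMPTIONS: the weighting scheme and the class counts are hypothetical;
# the card only states that a weighted cross-entropy loss was used.

import math
from typing import List, Sequence


def balanced_class_weights(counts: Sequence[int]) -> List[float]:
    """w_c = N / (K * n_c): rarer classes receive larger weights."""
    total, k = sum(counts), len(counts)
    return [total / (k * n) for n in counts]


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def weighted_cross_entropy(
    batch_logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    weights: Sequence[float],
) -> float:
    """Weighted negative log-likelihood divided by the summed target weights."""
    num = sum(-weights[t] * log_softmax(z)[t] for z, t in zip(batch_logits, targets))
    den = sum(weights[t] for t in targets)
    return num / den


if __name__ == "__main__":
    counts = [100, 10, 400, 300]  # hypothetical per-class training counts
    weights = balanced_class_weights(counts)
    logits = [[2.0, 0.1, 0.3, -1.0], [0.2, 1.5, 0.1, 0.0]]
    targets = [0, 1]
    print([round(w, 3) for w in weights])
    print(f"loss = {weighted_cross_entropy(logits, targets, weights):.4f}")
```

Dividing by the summed target weights keeps the loss on the same scale as the unweighted mean, so uniform weights reduce it to ordinary cross-entropy.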

Intended Use

This model is designed for research purposes in medical image classification, specifically for:

  • Alzheimer's disease detection from brain MRI scans
  • Multi-stage classification of cognitive decline
  • Research and educational purposes in medical AI

Note: This model is NOT intended for clinical diagnosis. Always consult qualified medical professionals.

How to Use

from transformers import AutoImageProcessor, ViTForImageClassification
from PIL import Image
import torch

# Load model and processor
processor = AutoImageProcessor.from_pretrained("NotIshaan/vit-large-alzheimer-6layers-75M-final")
model = ViTForImageClassification.from_pretrained("NotIshaan/vit-large-alzheimer-6layers-75M-final")

# Load the scan and convert it to 3-channel RGB (MRI images are often grayscale)
image = Image.open("brain_mri.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Make prediction (from_pretrained returns the model in eval mode)
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.softmax(logits, dim=-1)
predicted_class = probs.argmax(-1).item()
confidence = probs[0, predicted_class].item()

print(f"Predicted class: {model.config.id2label[predicted_class]}")
print(f"Confidence: {confidence:.2%}")

Label Mapping

The id2label entry in the model config maps each class index to its digit string:

{0: '0', 1: '1', 2: '2', 3: '3'}

Limitations

  • Trained on a specific brain MRI dataset; may not generalize to other MRI acquisition protocols
  • Class imbalance in the training data may reduce performance on minority classes
  • Designed for grayscale MRI images, which are converted to 3-channel RGB before being passed to the model
  • Input resolution: 224x224 pixels (the native resolution of the vit-large-patch16-224 base model)

Citation

If you use this model, please cite:

@misc{vit-large-alzheimer,
  author = {Your Name},
  title = {ViT-Large for Alzheimer's Detection},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/NotIshaan/vit-large-alzheimer-6layers-75M-final}}
}

Acknowledgments
