Federated Learning for Glaucoma Segmentation: Model Checkpoints
Overview
This repository contains the trained model checkpoints from the research project "A Federated Learning-based Optic Disc and Cup Segmentation Model for Glaucoma Monitoring in Color Fundus Photographs".
Key Information
- Task: Automated optic disc and cup segmentation for glaucoma assessment
- Architecture: Mask2Former with Swin Transformer backbone
- Pre-training: ADE20K semantic segmentation dataset
- Training Data: 5,550 color fundus photographs from 9 datasets across 7 countries
- Approach: Privacy-preserving federated learning with site-specific fine-tuning
Clinical Context
Glaucoma is a leading cause of irreversible blindness worldwide, affecting 3.54% of the population aged 40-80 and projected to impact 111.8 million people by 2040. A key indicator of glaucoma severity is the vertical cup-to-disc ratio (CDR), with ratios ≥0.6 suggestive of glaucoma.
This work addresses the need for accurate automated segmentation while keeping patient data at each clinical site. Because only model weights are shared, institutions can collaborate without pooling images, which supports compliance with privacy regulations such as HIPAA and GDPR.
Models Included
This repository contains 22 trained models, grouped into baseline models and federated learning models:
Baseline Models
- Central Model (1 model): Trained on pooled data from all sites; serves as the upper-bound reference for performance
- Local Models (9 models): One model per site, each trained only on that site's dataset; serve as the lower-bound reference for performance
Federated Learning Models
- Pipeline 1 (1 model): Global Validation using unweighted FedAvg
- Pipeline 2 (1 model): Weighted Global Validation using dataset-size weighted FedAvg
- Pipeline 3 (1 model): Onsite Validation with local early stopping before aggregation
- Pipeline 4 (9 models): Fine-Tuned Onsite Validation with site-specific fine-tuning
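The difference between Pipelines 1 and 2 comes down to the aggregation weights. The following is a minimal sketch of that idea, not the project's training code: it assumes each site's model is represented as a dict of NumPy arrays, and the function name `fedavg` and the toy values are hypothetical.

```python
import numpy as np

def fedavg(site_states, site_sizes=None):
    """Average per-site parameter dicts.

    site_states: list of dicts mapping parameter name -> np.ndarray.
    site_sizes:  optional list of training-set sizes. If given, each site is
                 weighted by its share of the total (Pipeline 2 style);
                 otherwise all sites count equally (Pipeline 1 style).
    """
    n_sites = len(site_states)
    if site_sizes is None:
        weights = np.full(n_sites, 1.0 / n_sites)
    else:
        sizes = np.asarray(site_sizes, dtype=float)
        weights = sizes / sizes.sum()
    return {
        name: sum(w * state[name] for w, state in zip(weights, site_states))
        for name in site_states[0]
    }

# Toy example: two sites, one parameter (made-up values and sizes).
states = [{"w": np.array([0.0])}, {"w": np.array([1.0])}]
unweighted = fedavg(states)                       # both sites count equally
weighted = fedavg(states, site_sizes=[100, 300])  # larger site dominates
```

With dataset-size weighting, large sites pull the global model toward their own solution, whereas the unweighted variant gives small sites the same influence as large ones.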
Usage
Download Specific Model

    from huggingface_hub import hf_hub_download

    # Download central model
    model_path = hf_hub_download(
        repo_id="sud11111/Federated-Learning-Glaucoma",
        filename="models/central/best_model.pt",
    )

    # Download fine-tuned model for specific dataset
    model_path = hf_hub_download(
        repo_id="sud11111/Federated-Learning-Glaucoma",
        filename="models/pipeline4/chaksu/best_model.pt",
    )

Download All Models

    from huggingface_hub import snapshot_download

    # Download entire models directory
    local_dir = snapshot_download(
        repo_id="sud11111/Federated-Learning-Glaucoma",
        allow_patterns="models/**",
    )
    print(f"Models downloaded to: {local_dir}")

Load and Perform Inference

    import torch
    from PIL import Image
    from transformers import (
        Mask2FormerForUniversalSegmentation,
        Mask2FormerImageProcessor,
    )

    # Load preprocessor
    processor = Mask2FormerImageProcessor.from_pretrained(
        "facebook/mask2former-swin-base-ade-semantic"
    )

    # Load model architecture; the ADE20K head has a different number of classes,
    # so its mismatched weights are skipped and replaced by the checkpoint below
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        "facebook/mask2former-swin-base-ade-semantic",
        num_labels=4,  # background, unlabeled, optic disc, optic cup
        ignore_mismatched_sizes=True,
    )

    # Load trained weights (model_path comes from hf_hub_download;
    # this assumes the file stores a plain state_dict)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    # Perform inference on fundus image
    image = Image.open("fundus_image.jpg").convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process segmentation: a (height, width) tensor of class indices
    predicted_segmentation = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]

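The vertical cup-to-disc ratio described under Clinical Context can be derived from a label map like the one produced by the post-processing step. The sketch below is illustrative only. It assumes the class indices follow the order in which the classes are listed (0 background, 1 unlabeled, 2 optic disc, 3 optic cup), which should be verified against the checkpoint, and it treats cup pixels as part of the disc extent. The function name `vertical_cdr` is hypothetical.

```python
import numpy as np

DISC_ID, CUP_ID = 2, 3  # assumed indices; verify against the checkpoint's label mapping

def vertical_cdr(label_map):
    """Return vertical cup height / vertical disc height, or None if no disc is found."""
    label_map = np.asarray(label_map)
    # Rows that contain at least one cup pixel.
    cup_rows = np.flatnonzero((label_map == CUP_ID).any(axis=1))
    # Cup pixels lie inside the disc, so the disc extent covers both classes.
    disc_rows = np.flatnonzero(np.isin(label_map, (DISC_ID, CUP_ID)).any(axis=1))
    if disc_rows.size == 0:
        return None
    disc_height = disc_rows[-1] - disc_rows[0] + 1
    cup_height = cup_rows[-1] - cup_rows[0] + 1 if cup_rows.size else 0
    return cup_height / disc_height

# Toy 10x10 mask: disc spans rows 1-8, cup spans rows 3-6.
toy = np.zeros((10, 10), dtype=int)
toy[1:9, 2:8] = DISC_ID
toy[3:7, 3:7] = CUP_ID
cdr = vertical_cdr(toy)
```

For a real prediction, the tensor would first be converted to an array, for example `vertical_cdr(predicted_segmentation.cpu().numpy())`.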
Datasets
Training was performed across 9 public datasets spanning 7 countries, comprising a total of 5,550 color fundus photographs from at least 917 patients:
| Dataset | Total Images | Test Images | Country | Characteristics |
|---|---|---|---|---|
| Chaksu | 1,345 | 135 | India | Multi-center research dataset |
| REFUGE | 1,200 | 120 | China | Glaucoma challenge dataset |
| G1020 | 1,020 | 102 | Germany | Benchmark retinal fundus dataset |
| RIM-ONE DL | 485 | 49 | Spain | Glaucoma assessment dataset |
| MESSIDOR | 460 | 46 | France | Diabetic retinopathy screening |
| ORIGA | 650 | 65 | Singapore | Multi-ethnic Asian population |
| Bin Rushed | 195 | 20 | Saudi Arabia | RIGA dataset collection |
| DRISHTI-GS | 101 | 11 | India | Optic nerve head segmentation |
| Magrabi | 94 | 10 | Saudi Arabia | RIGA dataset collection |
Data Split: Each dataset was divided into training (80%), validation (10%), and testing (10%) subsets. For datasets with multiple expert annotations, the STAPLE (Simultaneous Truth and Performance Level Estimation) method was used to generate consensus segmentation labels.
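As a rough illustration of the 80/10/10 split, the sketch below shuffles image indices and holds out 10% each for testing and validation. Rounding the 10% up reproduces the test counts in the table (for example, 135 of 1,345 for Chaksu and 11 of 101 for DRISHTI-GS), but the rounding used for the validation subset, the seed, and the function name `split_indices` are assumptions. The sketch splits by image and does not group images by patient.

```python
import random

def split_indices(n_images, seed=0):
    """Shuffle image indices and split them roughly 80/10/10.

    Test and validation sizes are 10% rounded up; the remainder is used
    for training. The rounding rule and seed are assumptions.
    """
    indices = list(range(n_images))
    random.Random(seed).shuffle(indices)
    n_holdout = (n_images + 9) // 10  # integer ceiling of 10%
    test = indices[:n_holdout]
    val = indices[n_holdout:2 * n_holdout]
    train = indices[2 * n_holdout:]
    return train, val, test

train, val, test = split_indices(1345)  # Chaksu-sized example
```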
Model Architecture
- Base Model: Mask2Former
- Backbone: Swin Transformer (Swin-Base)
- Pre-training: ADE20K semantic segmentation dataset
- Input Resolution: 512×512 pixels
- Output Classes: 4 (background, unlabeled, optic disc, optic cup)
- Optimizer: AdamW (learning rate: 2×10⁻⁵)
- Loss Function: Multi-class cross-entropy
- Early Stopping: Patience of 7 epochs/rounds
Training Configuration
Common Hyperparameters
- Batch size: 8
- Learning rate: 2×10⁻⁵
- Optimizer: AdamW with weight decay
- Maximum epochs: 100 (with early stopping)
- Early stopping patience: 7 epochs/rounds
- Input size: 512×512 pixels (normalized)
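The early stopping rule with a patience of 7 can be sketched as a small counter. This is a generic illustration, not the project's implementation: it assumes the monitored quantity is a validation loss where lower is better, and the class name `EarlyStopping` and the loss values are made up.

```python
class EarlyStopping:
    """Stop when the monitored loss has not improved for `patience` checks in a row."""

    def __init__(self, patience=7):
        self.patience = patience
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience

# Made-up loss curve: improves twice, then plateaus.
stopper = EarlyStopping(patience=7)
losses = [0.9, 0.8] + [0.85] * 10
stopped_at = next(i for i, loss in enumerate(losses, start=1) if stopper.step(loss))
```

In the centralized and local settings one check corresponds to an epoch; in the federated pipelines with global validation it would correspond to a round.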
Federated Learning Specifications
- Pipeline 1 & 2: 1 epoch per site per round, up to 100 rounds
- Pipeline 3: Up to 20 local epochs per site per round, 10 FL rounds
- Pipeline 4: Up to 20 fine-tuning epochs per site with local early stopping
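The round structure behind these settings can be illustrated with a toy problem in which each site's "model" is a single number and its local loss is a squared distance to a site-specific target. Everything here is a simplified sketch under stated assumptions: the function names, learning rate, and targets are hypothetical, aggregation is a plain unweighted mean, and the fine-tuning step is assumed to start from the final aggregated weight.

```python
def local_train(w, target, max_epochs=20, patience=7, lr=0.1):
    """Run up to `max_epochs` gradient steps on the toy loss (w - target)**2,
    stopping early once the loss fails to improve `patience` times in a row."""
    best_w, best_loss, bad = w, (w - target) ** 2, 0
    for _ in range(max_epochs):
        w = w - lr * 2 * (w - target)  # one gradient step on the toy loss
        loss = (w - target) ** 2
        if loss < best_loss:
            best_w, best_loss, bad = w, loss, 0
        else:
            bad += 1
            if bad >= patience:
                break
    return best_w

def run_rounds(site_targets, n_rounds=10, max_local_epochs=20):
    """Each round: every site trains from the global weight, then the
    site results are averaged (unweighted) into the next global weight."""
    global_w = 0.0
    for _ in range(n_rounds):
        local_ws = [local_train(global_w, t, max_local_epochs) for t in site_targets]
        global_w = sum(local_ws) / len(local_ws)
    return global_w

targets = [1.0, 3.0]  # made-up per-site optima

# Pipeline 1/2 style schedule: 1 local epoch per round, up to 100 rounds.
w_short = run_rounds(targets, n_rounds=100, max_local_epochs=1)
# Pipeline 3 style schedule: up to 20 local epochs per round, 10 rounds.
w_long = run_rounds(targets, n_rounds=10, max_local_epochs=20)
# Pipeline 4 style: each site fine-tunes the aggregated weight locally (assumed starting point).
fine_tuned = [local_train(w_long, t, max_epochs=20) for t in targets]
```

In this toy setting both schedules settle near the midpoint of the site optima, while fine-tuning moves each site's copy back toward its own optimum, which mirrors the motivation for site-specific models.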