Federated Learning for Glaucoma Segmentation: Model Checkpoints

Overview

This repository contains the trained model checkpoints from the research project "A Federated Learning-based Optic Disc and Cup Segmentation Model for Glaucoma Monitoring in Color Fundus Photographs".

Key Information

  • Task: Automated optic disc and cup segmentation for glaucoma assessment
  • Architecture: Mask2Former with Swin Transformer backbone
  • Pre-training: ADE20K semantic segmentation dataset
  • Training Data: 5,550 color fundus photographs from 9 datasets across 7 countries
  • Approach: Privacy-preserving federated learning with site-specific fine-tuning

Clinical Context

Glaucoma is a leading cause of irreversible blindness worldwide, affecting 3.54% of the population aged 40-80 and projected to impact 111.8 million people by 2040. A key indicator of glaucoma severity is the vertical cup-to-disc ratio (CDR), with ratios ≥0.6 suggestive of glaucoma.

This work addresses the need for accurate automated segmentation while preserving patient data privacy across multiple clinical sites, enabling HIPAA/GDPR-compliant multi-institutional collaboration.

Models Included

This repository contains 22 trained models organized into four categories:

Baseline Models

  • Central Model (1 model): Trained on pooled multi-site data; serves as the upper-bound reference for performance
  • Local Models (9 models): Site-specific models, each trained on a single dataset; serve as the lower-bound reference for performance

Federated Learning Models

  • Pipeline 1 (1 model): Global Validation using unweighted FedAvg
  • Pipeline 2 (1 model): Weighted Global Validation using dataset-size weighted FedAvg
  • Pipeline 3 (1 model): Onsite Validation with local early stopping before aggregation
  • Pipeline 4 (9 models): Fine-Tuned Onsite Validation with site-specific fine-tuning
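
The difference between the unweighted FedAvg of Pipeline 1 and the dataset-size weighted FedAvg of Pipeline 2 can be illustrated with a small sketch. This is not the project's training code: the function name `fed_avg`, the toy parameter dictionaries, and the site sizes are all hypothetical, and NumPy arrays stand in for model weights.

```python
import numpy as np

def fed_avg(site_states, site_sizes=None):
    """Average per-site parameter dicts.

    With site_sizes=None every site gets the same weight (unweighted FedAvg).
    Otherwise each site is weighted by its share of the total sample count.
    """
    n_sites = len(site_states)
    if site_sizes is None:
        weights = np.full(n_sites, 1.0 / n_sites)
    else:
        sizes = np.asarray(site_sizes, dtype=float)
        weights = sizes / sizes.sum()

    averaged = {}
    for name in site_states[0]:
        # Stack the same parameter from every site: shape (n_sites, ...)
        stacked = np.stack([state[name] for state in site_states])
        # Weighted sum over the site axis
        averaged[name] = np.tensordot(weights, stacked, axes=1)
    return averaged

# Toy example with two hypothetical sites and one parameter each
site_a = {"w": np.array([1.0, 2.0])}
site_b = {"w": np.array([3.0, 6.0])}

unweighted = fed_avg([site_a, site_b])            # equal weights 0.5 / 0.5
weighted = fed_avg([site_a, site_b], [300, 100])  # weights 0.75 / 0.25
print(unweighted["w"], weighted["w"])
```

With size weighting, the larger site pulls the global parameters toward its own values, which is the practical distinction between the two aggregation rules.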

Usage

Download Specific Model

    from huggingface_hub import hf_hub_download

    # Download the central model
    model_path = hf_hub_download(
        repo_id="sud11111/Federated-Learning-Glaucoma",
        filename="models/central/best_model.pt",
    )

    # Download the fine-tuned model for a specific dataset
    model_path = hf_hub_download(
        repo_id="sud11111/Federated-Learning-Glaucoma",
        filename="models/pipeline4/chaksu/best_model.pt",
    )

Download All Models

    from huggingface_hub import snapshot_download

    # Download the entire models directory
    local_dir = snapshot_download(
        repo_id="sud11111/Federated-Learning-Glaucoma",
        allow_patterns="models/**",
    )
    print(f"Models downloaded to: {local_dir}")

Load and Perform Inference

    import torch
    from PIL import Image
    from transformers import (
        Mask2FormerForUniversalSegmentation,
        Mask2FormerImageProcessor,
    )

    # Load the preprocessor
    processor = Mask2FormerImageProcessor.from_pretrained(
        "facebook/mask2former-swin-base-ade-semantic"
    )

    # Load the model architecture with a 4-class head
    # (background, unlabeled, optic disc, optic cup).
    # ignore_mismatched_sizes=True is needed because the ADE20K checkpoint
    # has a different number of classes than this head.
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        "facebook/mask2former-swin-base-ade-semantic",
        num_labels=4,
        ignore_mismatched_sizes=True,
    )

    # Load the trained weights (assumes the checkpoint is a plain state_dict)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    # Perform inference on a fundus image
    image = Image.open("fundus_image.jpg").convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process to a (height, width) label map at the original image size
    predicted_segmentation = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
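
Once a label map is available, the vertical cup-to-disc ratio described under Clinical Context can be estimated from it. The sketch below is illustrative only. It assumes that class indices follow the order in which the classes are listed (2 = optic disc, 3 = optic cup), which the source does not confirm. The helper names `vertical_extent` and `vertical_cdr` are hypothetical. The disc is taken as the union of disc and cup pixels so that the result is the same whether the disc label covers the whole disc or only the rim.

```python
import numpy as np

DISC_LABEL = 2  # assumed index for "optic disc"
CUP_LABEL = 3   # assumed index for "optic cup"

def vertical_extent(mask):
    """Number of rows spanned by the True pixels of a 2D boolean mask."""
    rows = np.flatnonzero(mask.any(axis=1))
    return 0 if rows.size == 0 else int(rows[-1] - rows[0] + 1)

def vertical_cdr(label_map):
    """Vertical cup height divided by vertical disc height, or None if no disc."""
    label_map = np.asarray(label_map)
    cup = label_map == CUP_LABEL
    disc = (label_map == DISC_LABEL) | cup  # the cup lies inside the disc
    disc_height = vertical_extent(disc)
    if disc_height == 0:
        return None
    return vertical_extent(cup) / disc_height

# Toy label map: a disc spanning 10 rows with a cup spanning 6 rows
toy = np.zeros((12, 12), dtype=int)
toy[1:11, 2:10] = DISC_LABEL
toy[3:9, 4:8] = CUP_LABEL

cdr = vertical_cdr(toy)
print(f"vertical CDR: {cdr:.2f}")
```

For a real prediction, the tensor returned by the post-processing step could be passed as `vertical_cdr(predicted_segmentation.cpu().numpy())`.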

Datasets

Training was performed across 9 public datasets spanning 7 countries, comprising a total of 5,550 color fundus photographs from at least 917 patients:

| Dataset | Total Images | Test Images | Country | Characteristics |
|---|---|---|---|---|
| Chaksu | 1,345 | 135 | India | Multi-center research dataset |
| REFUGE | 1,200 | 120 | China | Glaucoma challenge dataset |
| G1020 | 1,020 | 102 | Germany | Benchmark retinal fundus dataset |
| RIM-ONE DL | 485 | 49 | Spain | Glaucoma assessment dataset |
| MESSIDOR | 460 | 46 | France | Diabetic retinopathy screening |
| ORIGA | 650 | 65 | Singapore | Multi-ethnic Asian population |
| Bin Rushed | 195 | 20 | Saudi Arabia | RIGA dataset collection |
| DRISHTI-GS | 101 | 11 | India | Optic nerve head segmentation |
| Magrabi | 94 | 10 | Saudi Arabia | RIGA dataset collection |

Data Split: Each dataset was divided into training (80%), validation (10%), and testing (10%) subsets. For datasets with multiple expert annotations, the STAPLE (Simultaneous Truth and Performance Level Estimation) method was used to generate consensus segmentation labels.
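
An 80/10/10 split of this kind can be sketched as follows. The function `split_indices` and the seed are hypothetical, and the rounding rule is an assumption: the test counts in the table are consistent with rounding 10% up to the next whole image, and the same rule is applied here to the validation set, which the source does not specify.

```python
import random

def split_indices(n_images, seed=0):
    """Shuffle image indices and split them roughly 80/10/10.

    Test and validation sizes are ceil(n_images / 10), computed with
    integer arithmetic; the training set takes the remainder.
    """
    n_test = (n_images + 9) // 10
    n_val = (n_images + 9) // 10
    indices = list(range(n_images))
    random.Random(seed).shuffle(indices)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]
    return train, val, test

# Example: the Chaksu dataset has 1,345 images
train, val, test = split_indices(1345)
print(len(train), len(val), len(test))
```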

Model Architecture

  • Base Model: Mask2Former
  • Backbone: Swin Transformer (Swin-Base)
  • Pre-training: ADE20K semantic segmentation dataset
  • Input Resolution: 512×512 pixels
  • Output Classes: 4 (background, unlabeled, optic disc, optic cup)
  • Optimizer: AdamW (learning rate: 2×10⁻⁵)
  • Loss Function: Multi-class cross-entropy
  • Early Stopping: Patience of 7 epochs/rounds

Training Configuration

Common Hyperparameters

  • Batch size: 8
  • Learning rate: 2×10⁻⁵
  • Optimizer: AdamW with weight decay
  • Maximum epochs: 100 (with early stopping)
  • Early stopping patience: 7 epochs/rounds
  • Input size: 512×512 pixels (normalized)
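
Early stopping with a patience of 7 can be sketched as a small counter. The class below is a generic illustration, not the project's implementation: the name `EarlyStopping`, the `min_delta` parameter, and the choice of validation loss as the monitored quantity are assumptions, since the source does not state which metric was tracked.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive checks without improvement."""

    def __init__(self, patience=7, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation result; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Toy loss curve: improves for 3 epochs, then plateaus
losses = [0.9, 0.7, 0.6] + [0.65] * 10
stopper = EarlyStopping(patience=7)
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best loss {stopper.best}")
```

In the federated pipelines the same counter would advance once per round rather than once per epoch.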

Federated Learning Specifications

  • Pipeline 1 & 2: 1 epoch per site per round, up to 100 rounds
  • Pipeline 3: Up to 20 local epochs per site per round, 10 FL rounds
  • Pipeline 4: Up to 20 fine-tuning epochs per site with local early stopping
