FiLMUnet2D

This model is a 2D U-Net with FiLM (Feature-wise Linear Modulation) conditioning for multi-organ ultrasound segmentation.

Installation

Make sure you have transformers, torch, and Pillow installed. Pillow provides the PIL module used in the usage example.

pip install transformers torch pillow
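Before running the example, you can confirm that the packages are importable with `importlib.util.find_spec`. The mapping from import names to pip names below is an assumption that covers only the three packages this card uses, and `missing_packages` is a hypothetical helper.

```python
import importlib.util

# Import name -> pip package name (assumed; only the packages used in this card).
REQUIRED = {"transformers": "transformers", "torch": "torch", "PIL": "pillow"}

def missing_packages(required=REQUIRED):
    """Return the pip names of packages whose import module cannot be found."""
    return [
        pip_name
        for module, pip_name in required.items()
        if importlib.util.find_spec(module) is None
    ]

if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages, run: pip install " + " ".join(missing))
    else:
        print("All required packages are installed.")
```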

Usage

You can load the model and processor using the Auto classes from transformers. Since this repository contains custom code, make sure to pass trust_remote_code=True.

import torch
from transformers import AutoModel, AutoImageProcessor
from PIL import Image

# 1. Load model and processor
repo_id = "AImageLab-Zip/US_FiLMUNet" 

processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# 2. Load and preprocess your image
#    The processor handles resizing, letterboxing, and normalization.
image = Image.open("path/to/your/image.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# 3. Prepare conditioning input
#    An integer tensor holding the organ ID (see the Organ IDs mapping below).
#    `4` selects "fetal"; replace it with the ID of the organ in your image.
organ_id = torch.tensor([4])

# 4. Run inference
with torch.no_grad():
    outputs = model(**inputs, organ_id=organ_id)

# 5. Post-process the output to get the final segmentation mask
#    The processor can convert the logits to a binary mask, automatically handling
#    the removal of letterbox padding and resizing to the original image dimensions.
mask = processor.post_process_semantic_segmentation(
    outputs, 
    inputs, 
    threshold=0.7, 
    return_as_pil=True
)[0]

# 6. Save the result
mask.save("output_mask.png")

print("Segmentation mask saved to output_mask.png")
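The card says that post-processing converts logits into a binary mask, removes letterbox padding, and resizes the result to the original image. The framework-free sketch below shows those three steps on a tiny grid. It rests on several assumptions that the card does not document: a square target canvas, padding centred on both sides, a sigmoid applied before the threshold, and a strict `>` comparison. The helper names are hypothetical, and the real logic lives in the repository's processor.

```python
import math

def letterbox_params(orig_w, orig_h, target):
    """Hypothetical letterbox geometry: scale the longer side to `target`
    and centre the resized image on a target x target canvas."""
    scale = target / max(orig_w, orig_h)
    new_w, new_h = round(orig_w * scale), round(orig_h * scale)
    pad_left = (target - new_w) // 2
    pad_top = (target - new_h) // 2
    return new_w, new_h, pad_left, pad_top

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def postprocess(logits, orig_w, orig_h, threshold=0.7):
    """Turn a target x target grid of single-channel logits into an
    orig_h x orig_w binary mask (assumed: sigmoid, then strict threshold)."""
    target = len(logits)
    new_w, new_h, pad_left, pad_top = letterbox_params(orig_w, orig_h, target)
    # 1. Logits -> probabilities -> {0, 1}.
    binary = [[1 if sigmoid(v) > threshold else 0 for v in row] for row in logits]
    # 2. Crop away the letterbox padding.
    content = [row[pad_left:pad_left + new_w] for row in binary[pad_top:pad_top + new_h]]
    # 3. Nearest-neighbour resize back to the original resolution.
    return [
        [content[y * new_h // orig_h][x * new_w // orig_w] for x in range(orig_w)]
        for y in range(orig_h)
    ]

if __name__ == "__main__":
    # A 2-row x 4-column image letterboxed onto a 4x4 canvas: rows 0 and 3 are padding.
    logits = [
        [9.0, 9.0, 9.0, 9.0],     # padding, cropped away
        [5.0, 5.0, -5.0, -5.0],
        [-5.0, -5.0, 5.0, 5.0],
        [9.0, 9.0, 9.0, 9.0],     # padding, cropped away
    ]
    print(postprocess(logits, orig_w=4, orig_h=2))  # [[1, 1, 0, 0], [0, 0, 1, 1]]
```

A higher `threshold` keeps only pixels the model is more confident about. Under the sigmoid assumption, 0.7 corresponds to a logit of about 0.85.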

Model Details

  • Architecture: U-Net with FiLM layers for conditional segmentation.
  • Conditioning: The model's output is conditioned on an organ_id input.
  • Input: RGB images.
  • Output: Single-channel segmentation logits, which post_process_semantic_segmentation thresholds into a binary mask.
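FiLM modulates each feature channel with a scale γ and a shift β that are predicted from the conditioning input, computing out_c = γ_c · x_c + β_c. This lets the same convolutional weights respond differently for each organ. The framework-free sketch below shows only that arithmetic. The card does not document how this model maps organ_id to γ and β, or where the FiLM layers sit in the U-Net, so `FILM_PARAMS` is a hypothetical stand-in for learned parameters.

```python
def film(features, gamma, beta):
    """Feature-wise linear modulation: out[c] = gamma[c] * features[c] + beta[c],
    applied at every spatial position of channel c.

    features: C x H x W nested lists; gamma, beta: length-C lists.
    """
    return [
        [[g * v + b for v in row] for row in channel]
        for channel, g, b in zip(features, gamma, beta)
    ]

# Hypothetical per-organ (gamma, beta) table for a 2-channel feature map.
# In a trained model these values come from learned layers driven by organ_id.
FILM_PARAMS = {
    0: ([1.0, 1.0], [0.0, 0.0]),    # identity modulation
    4: ([2.0, 0.5], [1.0, -1.0]),
}

def conditioned_block(features, organ_id):
    """Apply the organ-specific modulation to a feature map."""
    gamma, beta = FILM_PARAMS[organ_id]
    return film(features, gamma, beta)

if __name__ == "__main__":
    features = [[[1.0, 2.0]], [[4.0, -2.0]]]  # C=2, H=1, W=2
    print(conditioned_block(features, 4))     # [[[3.0, 5.0]], [[1.0, -2.0]]]
```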

Configuration

The model configuration can be accessed via model.config. Key parameters include:

  • in_channels: Number of input channels (default: 3).
  • num_classes: Number of output classes (default: 1).
  • n_organs: The number of different organs the model was trained to condition on.
  • depth: The depth of the U-Net, i.e. its number of stages (the unet_4_stages version has depth 4).
  • size: The base number of filters in the first layer.
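You can collect the parameters listed above from a config object with `getattr`. `summarize_config` and `stage_widths` are hypothetical helpers. `stage_widths` assumes the common U-Net convention of doubling the filter count at each stage, which this card does not confirm. The `SimpleNamespace` values are illustrative and are not the released configuration.

```python
from types import SimpleNamespace

# Parameter names documented in this model card.
KEY_PARAMS = ("in_channels", "num_classes", "n_organs", "depth", "size")

def summarize_config(config):
    """Return the documented parameters of a config object (e.g. model.config).

    Missing attributes are reported as None instead of raising.
    """
    return {name: getattr(config, name, None) for name in KEY_PARAMS}

def stage_widths(size, depth):
    """Filter count per encoder stage, ASSUMING the count doubles at every stage
    (a common U-Net convention, not confirmed for this model)."""
    return [size * 2 ** i for i in range(depth)]

if __name__ == "__main__":
    # Illustrative stand-in for model.config; the values are made up.
    config = SimpleNamespace(in_channels=3, num_classes=1, n_organs=8, depth=4, size=32)
    print(summarize_config(config))
    print(stage_widths(config.size, config.depth))  # [32, 64, 128, 256]
```

With the real model, call `summarize_config(model.config)` instead of passing the stand-in.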

Organ IDs

The organ_id passed to the model follows the mapping below. Note that "breast" and "breast_luminal" share ID 1.

organ_to_class_dict = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}
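A small lookup helper can turn organ names into the integer IDs above and catch typos early. `organ_id_for` and `organs_for_id` are hypothetical names that are not part of the repository's code. The dictionary is copied from this card.

```python
# Mapping copied from this model card.
organ_to_class_dict = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}

def organ_id_for(name):
    """Return the integer organ ID for a (case-insensitive) organ name."""
    key = name.strip().lower()
    if key not in organ_to_class_dict:
        raise ValueError(f"Unknown organ {name!r}; expected one of {sorted(organ_to_class_dict)}")
    return organ_to_class_dict[key]

def organs_for_id(class_id):
    """Return every organ name mapped to an ID (several names can share one)."""
    return sorted(name for name, i in organ_to_class_dict.items() if i == class_id)

if __name__ == "__main__":
    print(organ_id_for("Fetal"))  # 4
    print(organs_for_id(1))       # ['breast', 'breast_luminal']
```

In the usage example you could then write `organ_id = torch.tensor([organ_id_for("fetal")])`.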

Alternative Versions

This repository contains multiple versions of the model located in subfolders. You can load a specific version by using the subfolder parameter.

4-Stage U-Net

This version has a U-Net depth of 4.

from transformers import AutoModel

model_4_stages = AutoModel.from_pretrained(
    "AImageLab-Zip/US_FiLMUNet", 
    subfolder="unet_4_stages",
    trust_remote_code=True
)
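If your code switches between versions, it can help to build the `from_pretrained` keyword arguments in one place. `VARIANTS` and `pretrained_kwargs` are hypothetical. The card documents only the repository root (used in Usage) and the `unet_4_stages` subfolder.

```python
REPO_ID = "AImageLab-Zip/US_FiLMUNet"

# Variant name -> subfolder (None means the repository root). Only these two
# locations are documented in this card.
VARIANTS = {"default": None, "unet_4_stages": "unet_4_stages"}

def pretrained_kwargs(variant="default"):
    """Keyword arguments for AutoModel.from_pretrained(REPO_ID, **kwargs)."""
    if variant not in VARIANTS:
        raise ValueError(f"Unknown variant {variant!r}; expected one of {sorted(VARIANTS)}")
    kwargs = {"trust_remote_code": True}  # required: the repo ships custom code
    if VARIANTS[variant] is not None:
        kwargs["subfolder"] = VARIANTS[variant]
    return kwargs

if __name__ == "__main__":
    print(pretrained_kwargs("unet_4_stages"))
    # {'trust_remote_code': True, 'subfolder': 'unet_4_stages'}
```

For example, `AutoModel.from_pretrained(REPO_ID, **pretrained_kwargs("unet_4_stages"))`. After loading, `model.config.depth` should report 4 for this version.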

Model size: 0.2B params (Safetensors weights, F32 tensors).