The model can be loaded as follows:

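import torch
import torchvision.transforms.v2 as T
import segmentation_models_pytorch as smp

checkpoint = "unet-resnet18.pt"
# U-Net with a ResNet-18 encoder: 4 input channels, 2 output classes.
# encoder_weights=None because all weights come from the checkpoint.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights=None,
    in_channels=4,
    classes=2,
)
# load_state_dict expects a dict of tensors, not a file path, so read the
# checkpoint with torch.load first (map_location is a torch.load argument).
state_dict = torch.load(checkpoint, map_location="cpu")
model.load_state_dict(state_dict)
# Scale raw 0-255 pixel values to [0, 1]. The single mean/std value is
# broadcast across all 4 channels; inputs must be float tensors.
transforms = torch.nn.Sequential(T.Normalize(mean=[0.0], std=[255.0]))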