Commit
·
9253cc6
1
Parent(s):
fca4013
Include model config
Browse files- README.md +4 -1
- config.yaml +202 -0
README.md
CHANGED
|
@@ -10,7 +10,7 @@ tags:
|
|
| 10 |
|
| 11 |
# granite-geospatial-ocean
|
| 12 |
|
| 13 |
-
The granite-geospatial-ocean foundation model was jointly developed by IBM
|
| 14 |
|
| 15 |
## Architecture Overview
|
| 16 |
|
|
@@ -42,6 +42,9 @@ Your feedback is invaluable to us. If you have any feedback about the model, ple
|
|
| 42 |
### Model Card Authors
|
| 43 |
Geoffrey Dawson, Remy Vandaele, Andrew Taylor, David Moffat, Helen Tamura-Wicks, Sarah Jackson, Chunbo Luo, Paolo Fraccaro, Hywel Williams, Rosie Lickorish and Anne Jones
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
### IBM Public Repository Disclosure:
|
| 46 |
All content in this repository including code has been provided by IBM under the associated open source software license and IBM is under no obligation to provide enhancements, updates, or support. IBM developers produced this code as an open source project (not as an IBM product), and IBM makes no assertions as to the level of quality nor security, and will not be maintaining this code going forward.
|
| 47 |
|
|
|
|
| 10 |
|
| 11 |
# granite-geospatial-ocean
|
| 12 |
|
| 13 |
+
The granite-geospatial-ocean foundation model was jointly developed by IBM and STFC as part of a collaboration with the University of Exeter and Plymouth Marine Lab under the UK HNCDI programme. This pre-trained model supports a range of potential use cases in ocean ecosystem health, fisheries management, pollution and other ocean processes that can be monitored using ocean colour observations. We provide an example to fine tune the model to quantify primary production by phytoplankton (carbon sequestration which determine's the ocean's role in climate change).
|
| 14 |
|
| 15 |
## Architecture Overview
|
| 16 |
|
|
|
|
| 42 |
### Model Card Authors
|
| 43 |
Geoffrey Dawson, Remy Vandaele, Andrew Taylor, David Moffat, Helen Tamura-Wicks, Sarah Jackson, Chunbo Luo, Paolo Fraccaro, Hywel Williams, Rosie Lickorish and Anne Jones
|
| 44 |
|
| 45 |
+
### Acknowledgments
|
| 46 |
+
This work was supported by the Hartree National Centre for Digital Innovation, a collaboration between STFC and IBM.
|
| 47 |
+
|
| 48 |
### IBM Public Repository Disclosure:
|
| 49 |
All content in this repository including code has been provided by IBM under the associated open source software license and IBM is under no obligation to provide enhancements, updates, or support. IBM developers produced this code as an open source project (not as an IBM product), and IBM makes no assertions as to the level of quality nor security, and will not be maintaining this code going forward.
|
| 50 |
|
config.yaml
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# lightning.pytorch==2.1.1
|
| 2 |
+
seed_everything: 42
|
| 3 |
+
out_dtype: float32
|
| 4 |
+
custom_modules_path: ./../custom_modules/
|
| 5 |
+
### Trainer configuration
|
| 6 |
+
trainer:
|
| 7 |
+
accelerator: auto
|
| 8 |
+
strategy: auto
|
| 9 |
+
devices: auto
|
| 10 |
+
num_nodes: 1
|
| 11 |
+
# precision: 16-mixed
|
| 12 |
+
logger:
|
| 13 |
+
class_path: TensorBoardLogger
|
| 14 |
+
init_args:
|
| 15 |
+
save_dir: ./../data/
|
| 16 |
+
name: model_runs
|
| 17 |
+
callbacks:
|
| 18 |
+
- class_path: LearningRateMonitor
|
| 19 |
+
init_args:
|
| 20 |
+
logging_interval: epoch
|
| 21 |
+
- class_path: EarlyStopping
|
| 22 |
+
init_args:
|
| 23 |
+
monitor: val/loss
|
| 24 |
+
patience: 100
|
| 25 |
+
max_epochs: 1
|
| 26 |
+
check_val_every_n_epoch: 1
|
| 27 |
+
log_every_n_steps: 5
|
| 28 |
+
enable_checkpointing: true
|
| 29 |
+
default_root_dir: ./../data/
|
| 30 |
+
|
| 31 |
+
### Data configuration
|
| 32 |
+
data:
|
| 33 |
+
class_path: terratorch.datamodules.GenericNonGeoPixelwiseRegressionDataModule
|
| 34 |
+
init_args:
|
| 35 |
+
batch_size: 8
|
| 36 |
+
num_workers: 2
|
| 37 |
+
train_transform:
|
| 38 |
+
- class_path: albumentations.HorizontalFlip
|
| 39 |
+
init_args:
|
| 40 |
+
p: 0.5
|
| 41 |
+
- class_path: albumentations.RandomCrop
|
| 42 |
+
init_args:
|
| 43 |
+
height: 42
|
| 44 |
+
width: 42
|
| 45 |
+
- class_path: albumentations.Rotate
|
| 46 |
+
init_args:
|
| 47 |
+
limit: 30
|
| 48 |
+
border_mode: 0 # cv2.BORDER_CONSTANT
|
| 49 |
+
value: 0
|
| 50 |
+
# mask_value: 1
|
| 51 |
+
p: 0.5
|
| 52 |
+
- class_path: ToTensorV2
|
| 53 |
+
# Specify all bands which are in the input data.
|
| 54 |
+
# -1 are placeholders for bands that are in the data but that we will discard
|
| 55 |
+
dataset_bands:
|
| 56 |
+
- Oa01_reflectance
|
| 57 |
+
- Oa02_reflectance
|
| 58 |
+
- Oa03_reflectance
|
| 59 |
+
- Oa04_reflectance
|
| 60 |
+
- Oa05_reflectance
|
| 61 |
+
- Oa06_reflectance
|
| 62 |
+
- Oa07_reflectance
|
| 63 |
+
- Oa08_reflectance
|
| 64 |
+
- Oa09_reflectance
|
| 65 |
+
- Oa10_reflectance
|
| 66 |
+
- Oa11_reflectance
|
| 67 |
+
- Oa12_reflectance
|
| 68 |
+
- Oa16_reflectance
|
| 69 |
+
- Oa17_reflectance
|
| 70 |
+
- Oa18_reflectance
|
| 71 |
+
- Oa21_reflectance
|
| 72 |
+
- SST
|
| 73 |
+
output_bands: #Specify the bands which are used from the input data.
|
| 74 |
+
- Oa01_reflectance
|
| 75 |
+
- Oa02_reflectance
|
| 76 |
+
- Oa03_reflectance
|
| 77 |
+
- Oa04_reflectance
|
| 78 |
+
- Oa05_reflectance
|
| 79 |
+
- Oa06_reflectance
|
| 80 |
+
- Oa07_reflectance
|
| 81 |
+
- Oa08_reflectance
|
| 82 |
+
- Oa09_reflectance
|
| 83 |
+
- Oa10_reflectance
|
| 84 |
+
- Oa11_reflectance
|
| 85 |
+
- Oa12_reflectance
|
| 86 |
+
- Oa16_reflectance
|
| 87 |
+
- Oa17_reflectance
|
| 88 |
+
- Oa18_reflectance
|
| 89 |
+
- Oa21_reflectance
|
| 90 |
+
rgb_indices:
|
| 91 |
+
- 2
|
| 92 |
+
- 1
|
| 93 |
+
- 0
|
| 94 |
+
# Directory roots to training, validation and test datasplits:
|
| 95 |
+
test_data_root: ./../data/fine-tuning
|
| 96 |
+
test_label_data_root: ./../data/fine-tuning
|
| 97 |
+
test_split: ./../data/fine-tuning/test_data.txt
|
| 98 |
+
train_data_root: ./../data/fine-tuning
|
| 99 |
+
train_label_data_root: ./../data/fine-tuning
|
| 100 |
+
train_split: ./../data/fine-tuning/train_data.txt
|
| 101 |
+
val_data_root: ./../data/fine-tuning
|
| 102 |
+
val_label_data_root: ./../data/fine-tuning
|
| 103 |
+
val_split: ./../data/fine-tuning/val_data.txt
|
| 104 |
+
img_grep: "*_img.tif"
|
| 105 |
+
label_grep: "*_lab.tif"
|
| 106 |
+
means: # Mean value of the training dataset per band
|
| 107 |
+
- 11378.33724842
|
| 108 |
+
- 11379.51141294
|
| 109 |
+
- 11291.99698672
|
| 110 |
+
- 11116.38807044
|
| 111 |
+
- 10898.95680699
|
| 112 |
+
- 10686.41604621
|
| 113 |
+
- 10466.67864162
|
| 114 |
+
- 10456.52999209
|
| 115 |
+
- 10462.41327758
|
| 116 |
+
- 10464.24100298
|
| 117 |
+
- 10443.59591923
|
| 118 |
+
- 10448.53157824
|
| 119 |
+
- 10470.36129347
|
| 120 |
+
- 10454.74328843
|
| 121 |
+
- 10453.79858959
|
| 122 |
+
- 10452.88001737
|
| 123 |
+
stds: # Standard deviation of the training dataset per band
|
| 124 |
+
- 3125.36214152
|
| 125 |
+
- 3118.65965249
|
| 126 |
+
- 3088.88720386
|
| 127 |
+
- 3055.0881767
|
| 128 |
+
- 3026.73186213
|
| 129 |
+
- 2997.72812315
|
| 130 |
+
- 2968.12838628
|
| 131 |
+
- 2968.75857855
|
| 132 |
+
- 2969.94390514
|
| 133 |
+
- 2970.39202078
|
| 134 |
+
- 2964.1543642
|
| 135 |
+
- 2973.0155451
|
| 136 |
+
- 2985.89318262
|
| 137 |
+
- 2975.50852528
|
| 138 |
+
- 2973.00652761
|
| 139 |
+
- 2973.00330406
|
| 140 |
+
# Nodata value in label data
|
| 141 |
+
no_label_replace: -1
|
| 142 |
+
# Nodata value in the input data
|
| 143 |
+
no_data_replace: 0
|
| 144 |
+
### Model configuration
|
| 145 |
+
model:
|
| 146 |
+
class_path: terratorch.tasks.PixelwiseRegressionTask
|
| 147 |
+
init_args:
|
| 148 |
+
model_args:
|
| 149 |
+
backbone_pretrained: true
|
| 150 |
+
backbone: prithvi_s3_v1
|
| 151 |
+
backbone_pretrained_cfg_overlay:
|
| 152 |
+
file: ./../data/checkpoints/checkpoint.pt
|
| 153 |
+
backbone_pretrain_img_size: 42
|
| 154 |
+
backbone_drop_path: 0.1
|
| 155 |
+
backbone_bands:
|
| 156 |
+
- Oa01_reflectance
|
| 157 |
+
- Oa02_reflectance
|
| 158 |
+
- Oa03_reflectance
|
| 159 |
+
- Oa04_reflectance
|
| 160 |
+
- Oa05_reflectance
|
| 161 |
+
- Oa06_reflectance
|
| 162 |
+
- Oa07_reflectance
|
| 163 |
+
- Oa08_reflectance
|
| 164 |
+
- Oa09_reflectance
|
| 165 |
+
- Oa10_reflectance
|
| 166 |
+
- Oa11_reflectance
|
| 167 |
+
- Oa12_reflectance
|
| 168 |
+
- Oa16_reflectance
|
| 169 |
+
- Oa17_reflectance
|
| 170 |
+
- Oa18_reflectance
|
| 171 |
+
- Oa21_reflectance
|
| 172 |
+
head_dropout: 0.16194593880230534
|
| 173 |
+
head_channel_list: [64]
|
| 174 |
+
necks:
|
| 175 |
+
- name: SelectIndices
|
| 176 |
+
indices: [2, 5, 8, 11]
|
| 177 |
+
- name: ReshapeTokensToImage
|
| 178 |
+
- name: LearnedInterpolateToPyramidal
|
| 179 |
+
decoder: UNetDecoder
|
| 180 |
+
decoder_channels: [256, 128, 64, 32]
|
| 181 |
+
head_dropout: 0.1
|
| 182 |
+
loss: rmse
|
| 183 |
+
ignore_index: -1
|
| 184 |
+
freeze_backbone: false
|
| 185 |
+
freeze_decoder: false
|
| 186 |
+
model_factory: EncoderDecoderFactory
|
| 187 |
+
tiled_inference_parameters:
|
| 188 |
+
h_crop: 64
|
| 189 |
+
h_stride: 4
|
| 190 |
+
w_crop: 64
|
| 191 |
+
w_stride: 4
|
| 192 |
+
delta: 8
|
| 193 |
+
average_patches: true
|
| 194 |
+
optimizer:
|
| 195 |
+
class_path: torch.optim.AdamW
|
| 196 |
+
init_args:
|
| 197 |
+
lr: 0.00012
|
| 198 |
+
weight_decay: 0.3
|
| 199 |
+
lr_scheduler:
|
| 200 |
+
class_path: ReduceLROnPlateau
|
| 201 |
+
init_args:
|
| 202 |
+
monitor: val/loss
|