rosielickorish commited on
Commit
24f0bf0
·
verified ·
1 Parent(s): ccaa233

Include model weights (#3)

Browse files

- Include model weights and update config (87030d3fa3256db70b36402760fef70c1d6095dc)

Files changed (3) hide show
  1. checkpoint.pt +3 -0
  2. config.yaml +64 -201
  3. model_architecture.png +2 -2
checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9555cad115952cbaa0a00864fd6d3467c37390efaa17d0cf45de7c1861153d3
3
+ size 166136110
config.yaml CHANGED
@@ -1,202 +1,65 @@
1
- # lightning.pytorch==2.1.1
2
- seed_everything: 42
3
- out_dtype: float32
4
- custom_modules_path: ./../custom_modules/
5
- ### Trainer configuration
6
- trainer:
7
- accelerator: auto
8
- strategy: auto
9
- devices: auto
10
- num_nodes: 1
11
- # precision: 16-mixed
12
- logger:
13
- class_path: TensorBoardLogger
14
- init_args:
15
- save_dir: ./../data/
16
- name: model_runs
17
- callbacks:
18
- - class_path: LearningRateMonitor
19
- init_args:
20
- logging_interval: epoch
21
- - class_path: EarlyStopping
22
- init_args:
23
- monitor: val/loss
24
- patience: 100
25
- max_epochs: 1
26
- check_val_every_n_epoch: 1
27
- log_every_n_steps: 5
28
- enable_checkpointing: true
29
- default_root_dir: ./../data/
30
 
31
- ### Data configuration
32
- data:
33
- class_path: terratorch.datamodules.GenericNonGeoPixelwiseRegressionDataModule
34
- init_args:
35
- batch_size: 8
36
- num_workers: 2
37
- train_transform:
38
- - class_path: albumentations.HorizontalFlip
39
- init_args:
40
- p: 0.5
41
- - class_path: albumentations.RandomCrop
42
- init_args:
43
- height: 42
44
- width: 42
45
- - class_path: albumentations.Rotate
46
- init_args:
47
- limit: 30
48
- border_mode: 0 # cv2.BORDER_CONSTANT
49
- value: 0
50
- # mask_value: 1
51
- p: 0.5
52
- - class_path: ToTensorV2
53
- # Specify all bands which are in the input data.
54
- # -1 are placeholders for bands that are in the data but that we will discard
55
- dataset_bands:
56
- - Oa01_reflectance
57
- - Oa02_reflectance
58
- - Oa03_reflectance
59
- - Oa04_reflectance
60
- - Oa05_reflectance
61
- - Oa06_reflectance
62
- - Oa07_reflectance
63
- - Oa08_reflectance
64
- - Oa09_reflectance
65
- - Oa10_reflectance
66
- - Oa11_reflectance
67
- - Oa12_reflectance
68
- - Oa16_reflectance
69
- - Oa17_reflectance
70
- - Oa18_reflectance
71
- - Oa21_reflectance
72
- - SST
73
- output_bands: #Specify the bands which are used from the input data.
74
- - Oa01_reflectance
75
- - Oa02_reflectance
76
- - Oa03_reflectance
77
- - Oa04_reflectance
78
- - Oa05_reflectance
79
- - Oa06_reflectance
80
- - Oa07_reflectance
81
- - Oa08_reflectance
82
- - Oa09_reflectance
83
- - Oa10_reflectance
84
- - Oa11_reflectance
85
- - Oa12_reflectance
86
- - Oa16_reflectance
87
- - Oa17_reflectance
88
- - Oa18_reflectance
89
- - Oa21_reflectance
90
- rgb_indices:
91
- - 2
92
- - 1
93
- - 0
94
- # Directory roots to training, validation and test datasplits:
95
- test_data_root: ./../data/fine-tuning
96
- test_label_data_root: ./../data/fine-tuning
97
- test_split: ./../data/fine-tuning/test_data.txt
98
- train_data_root: ./../data/fine-tuning
99
- train_label_data_root: ./../data/fine-tuning
100
- train_split: ./../data/fine-tuning/train_data.txt
101
- val_data_root: ./../data/fine-tuning
102
- val_label_data_root: ./../data/fine-tuning
103
- val_split: ./../data/fine-tuning/val_data.txt
104
- img_grep: "*_img.tif"
105
- label_grep: "*_lab.tif"
106
- means: # Mean value of the training dataset per band
107
- - 11378.33724842
108
- - 11379.51141294
109
- - 11291.99698672
110
- - 11116.38807044
111
- - 10898.95680699
112
- - 10686.41604621
113
- - 10466.67864162
114
- - 10456.52999209
115
- - 10462.41327758
116
- - 10464.24100298
117
- - 10443.59591923
118
- - 10448.53157824
119
- - 10470.36129347
120
- - 10454.74328843
121
- - 10453.79858959
122
- - 10452.88001737
123
- stds: # Standard deviation of the training dataset per band
124
- - 3125.36214152
125
- - 3118.65965249
126
- - 3088.88720386
127
- - 3055.0881767
128
- - 3026.73186213
129
- - 2997.72812315
130
- - 2968.12838628
131
- - 2968.75857855
132
- - 2969.94390514
133
- - 2970.39202078
134
- - 2964.1543642
135
- - 2973.0155451
136
- - 2985.89318262
137
- - 2975.50852528
138
- - 2973.00652761
139
- - 2973.00330406
140
- # Nodata value in label data
141
- no_label_replace: -1
142
- # Nodata value in the input data
143
- no_data_replace: 0
144
- ### Model configuration
145
- model:
146
- class_path: terratorch.tasks.PixelwiseRegressionTask
147
- init_args:
148
- model_args:
149
- backbone_pretrained: true
150
- backbone: prithvi_s3_v1
151
- backbone_pretrained_cfg_overlay:
152
- file: ./../data/checkpoints/checkpoint.pt
153
- backbone_pretrain_img_size: 42
154
- backbone_drop_path: 0.1
155
- backbone_bands:
156
- - Oa01_reflectance
157
- - Oa02_reflectance
158
- - Oa03_reflectance
159
- - Oa04_reflectance
160
- - Oa05_reflectance
161
- - Oa06_reflectance
162
- - Oa07_reflectance
163
- - Oa08_reflectance
164
- - Oa09_reflectance
165
- - Oa10_reflectance
166
- - Oa11_reflectance
167
- - Oa12_reflectance
168
- - Oa16_reflectance
169
- - Oa17_reflectance
170
- - Oa18_reflectance
171
- - Oa21_reflectance
172
- head_dropout: 0.16194593880230534
173
- head_channel_list: [64]
174
- necks:
175
- - name: SelectIndices
176
- indices: [2, 5, 8, 11]
177
- - name: ReshapeTokensToImage
178
- - name: LearnedInterpolateToPyramidal
179
- decoder: UNetDecoder
180
- decoder_channels: [256, 128, 64, 32]
181
- head_dropout: 0.1
182
- loss: rmse
183
- ignore_index: -1
184
- freeze_backbone: false
185
- freeze_decoder: false
186
- model_factory: EncoderDecoderFactory
187
- tiled_inference_parameters:
188
- h_crop: 64
189
- h_stride: 4
190
- w_crop: 64
191
- w_stride: 4
192
- delta: 8
193
- average_patches: true
194
- optimizer:
195
- class_path: torch.optim.AdamW
196
- init_args:
197
- lr: 0.00012
198
- weight_decay: 0.3
199
- lr_scheduler:
200
- class_path: ReduceLROnPlateau
201
- init_args:
202
- monitor: val/loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ {
3
+ "architecture": "PRITHVI_EO",
4
+ "num_features": 768,
5
+ "pretrained_cfg": {
6
+ "img_size": [1, 42, 42],
7
+ "patch_size": [1, 2, 2],
8
+ "num_frames": 1,
9
+ "in_chans": 17,
10
+ "embed_dim": 512,
11
+ "depth": 12,
12
+ "num_heads": 8,
13
+ "decoder_embed_dim": 256,
14
+ "decoder_depth": 4,
15
+ "decoder_num_heads": 8,
16
+ "mlp_ratio": 4,
17
+ "mask_ratio": 0.75,
18
+ "bands": ["Oa01_reflectance", "Oa02_reflectance", "Oa03_reflectance",
19
+ "Oa04_reflectance", "Oa05_reflectance", "Oa06_reflectance",
20
+ "Oa07_reflectance", "Oa08_reflectance", "Oa09_reflectance",
21
+ "Oa10_reflectance", "Oa11_reflectance", "Oa12_reflectance",
22
+ "Oa16_reflectance", "Oa17_reflectance", "Oa18_reflectance",
23
+ "Oa21_reflectance", "SST"],
24
+ "mean": [
25
+ 0.0235427398,
26
+ 0.0226303495,
27
+ 0.0199877248,
28
+ 0.0166938124,
29
+ 0.0119924026,
30
+ 0.00767917988,
31
+ 0.00251636861,
32
+ 0.00189688827,
33
+ 0.0019271833,
34
+ 0.0019056457,
35
+ 0.00103529217,
36
+ 0.00056689044,
37
+ 0.000595696267,
38
+ 0.000402757423,
39
+ 0.000423631744,
40
+ 0.000105166233,
41
+ 293.908469
42
+ ],
43
+ "std": [
44
+ 0.00776708,
45
+ 0.00733259,
46
+ 0.00633057,
47
+ 0.00615707,
48
+ 0.00610327,
49
+ 0.0066378,
50
+ 0.00539699,
51
+ 0.00511585,
52
+ 0.0050785,
53
+ 0.00507704,
54
+ 0.00484563,
55
+ 0.00415998,
56
+ 0.00441236,
57
+ 0.00408463,
58
+ 0.00400387,
59
+ 0.00370793,
60
+ 2.67577808
61
+ ],
62
+ "origin_url": "<<ADD URL>>",
63
+ "paper_ids": "<<ADD>>"
64
+ }
65
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model_architecture.png CHANGED

Git LFS Details

  • SHA256: f2ab39c3e45ccb75a0b6f50498f92c140b1a169d176c53bbce8ffb49a50cdf15
  • Pointer size: 132 Bytes
  • Size of remote file: 4.01 MB

Git LFS Details

  • SHA256: 39f18f8ec736aa8e14043f3903cd06c1d656afcaccf726c0a0ca1b067d68816b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.88 MB