eacortes committed · verified
Commit bad1add · 1 Parent(s): 3016959

Upload 14 files
README.md CHANGED
@@ -47,14 +47,14 @@ library_name: sentence-transformers
  metrics:
  - spearman
  co2_eq_emissions:
- emissions: 4039.5232961852894
- energy_consumed: 19.679154905865374
+ emissions: 6350.153020081601
+ energy_consumed: 30.935740629629628
  source: codecarbon
  training_type: fine-tuning
  on_cloud: false
  cpu_model: AMD Ryzen 7 3700X 8-Core Processor
  ram_total_size: 62.69887161254883
- hours_used: 74.966
+ hours_used: 116.388
  hardware_used: 2 x NVIDIA GeForce RTX 3090
  model-index:
  - name: 'ChemMRL: SMILES Matryoshka Representation Learning Embedding Transformer'
@@ -67,7 +67,7 @@ model-index:
  type: pubchem_10m_genmol_similarity_validation
  metrics:
  - type: spearman
- value: 0.9881056976837288
+ value: 0.989142152637452
  name: Spearman
  - task:
  type: semantic-similarity
@@ -77,9 +77,8 @@ model-index:
  type: pubchem_10m_genmol_similarity_test
  metrics:
  - type: spearman
- value: 0.988127555600757
+ value: 0.9891625268496924
  name: Spearman
- new_version: Derify/ChemMRL
  ---

  # ChemMRL: SMILES Matryoshka Representation Learning Embedding Transformer
@@ -146,9 +145,9 @@ print(embeddings.shape)
  # Get the similarity scores for the embeddings
  similarities = model.backbone.similarity(embeddings, embeddings)
  print(similarities)
- # tensor([[1.0000, 0.4184, 0.0166],
- # [0.4158, 1.0000, 0.0136],
- # [0.0167, 0.0137, 1.0000]])
+ # tensor([[1.0000, 0.4179, 0.0165],
+ # [0.4179, 1.0000, 0.0140],
+ # [0.0165, 0.0140, 1.0000]])
  ```

  ### Direct Usage (Sentence Transformers)
@@ -186,9 +185,9 @@ print(embeddings.shape)
  # Get the similarity scores for the embeddings
  similarities = model.similarity(embeddings, embeddings)
  print(similarities)
- # tensor([[1.0000, 0.5887, 0.0327],
- # [0.5887, 1.0000, 0.0269],
- # [0.0327, 0.0269, 1.0000]])
+ # tensor([[1.0000, 0.5894, 0.0326],
+ # [0.5894, 1.0000, 0.0275],
+ # [0.0326, 0.0275, 1.0000]])
  ```

  </details>
@@ -209,8 +208,8 @@ print(similarities)

  | Split | Metric | Value |
  | :------------- | :----------- | :---------- |
- | **validation** | **spearman** | **0.98811** |
- | **test** | **spearman** | **0.98813** |
+ | **validation** | **spearman** | **0.98914** |
+ | **test** | **spearman** | **0.98916** |

  ## Training Details

@@ -236,11 +235,11 @@ print(similarities)
  ```json
  {
  "loss": "TanimotoSentLoss",
- "n_layers_per_step": 11,
- "last_layer_weight": 1.0,
- "prior_layers_weight": 1.5,
- "kl_div_weight": 0.5,
- "kl_temperature": 0.3,
+ "n_layers_per_step": -1,
+ "last_layer_weight": 2.0,
+ "prior_layers_weight": 1.0,
+ "kl_div_weight": 0.0,
+ "kl_temperature": 0.0,
  "matryoshka_dims": [
  1024,
  512,
@@ -261,7 +260,7 @@ print(similarities)
  1,
  1
  ],
- "n_dims_per_step": 4
+ "n_dims_per_step": -1
  }
  ```

@@ -287,11 +286,11 @@ print(similarities)
  ```json
  {
  "loss": "TanimotoSentLoss",
- "n_layers_per_step": 11,
- "last_layer_weight": 1.0,
- "prior_layers_weight": 1.5,
- "kl_div_weight": 0.5,
- "kl_temperature": 0.3,
+ "n_layers_per_step": -1,
+ "last_layer_weight": 2.0,
+ "prior_layers_weight": 1.0,
+ "kl_div_weight": 0.0,
+ "kl_temperature": 0.0,
  "matryoshka_dims": [
  1024,
  512,
@@ -312,7 +311,7 @@ print(similarities)
  1,
  1
  ],
- "n_dims_per_step": 4
+ "n_dims_per_step": -1
  }
  ```

@@ -334,7 +333,7 @@ print(similarities)
  - `tf32`: True
  - `optim`: stable_adamw
  - `optim_args`: decouple_lr=True,max_lr=8.0e-6
- - `dataloader_pin_memory`: False
+ - `gradient_checkpointing`: True
  - `eval_on_start`: True

  #### All Hyperparameters
@@ -416,7 +415,7 @@ print(similarities)
  - `ddp_find_unused_parameters`: None
  - `ddp_bucket_cap_mb`: None
  - `ddp_broadcast_buffers`: False
- - `dataloader_pin_memory`: False
+ - `dataloader_pin_memory`: True
  - `dataloader_persistent_workers`: False
  - `skip_memory_metrics`: True
  - `use_legacy_prediction_loop`: False
@@ -427,7 +426,7 @@ print(similarities)
  - `hub_private_repo`: None
  - `hub_always_push`: False
  - `hub_revision`: None
- - `gradient_checkpointing`: False
+ - `gradient_checkpointing`: True
  - `gradient_checkpointing_kwargs`: None
  - `include_inputs_for_metrics`: False
  - `include_for_metrics`: []
@@ -467,41 +466,41 @@ print(similarities)

  | Epoch | Step | Training Loss | pubchem 10m genmol similarity loss | pubchem_10m_genmol_similarity_spearman |
  | :----: | :----: | :-----------: | :--------------------------------: | :------------------------------------: |
- | 0 | 0 | - | 85.7997 | 0.7261 |
- | 0.0000 | 1 | 69.0605 | - | - |
- | 0.2477 | 25000 | 47.1696 | - | - |
- | 0.2500 | 25235 | - | 56.9634 | 0.8997 |
- | 0.4978 | 50250 | 45.6212 | - | - |
- | 0.5000 | 50470 | - | 55.4366 | 0.9599 |
- | 0.7479 | 75500 | 45.1404 | - | - |
- | 0.7500 | 75705 | - | 54.5667 | 0.9755 |
- | 0.9981 | 100750 | 44.5023 | - | - |
- | 1.0000 | 100940 | - | 54.1244 | 0.9810 |
- | 1.2482 | 126000 | 43.7545 | - | - |
- | 1.2500 | 126175 | - | 53.6974 | 0.9838 |
- | 1.4984 | 151250 | 43.7865 | - | - |
- | 1.5000 | 151410 | - | 53.4775 | 0.9855 |
- | 1.7485 | 176500 | 43.3512 | - | - |
- | 1.7499 | 176645 | - | 53.3775 | 0.9866 |
- | 1.9987 | 201750 | 43.5808 | - | - |
- | 1.9999 | 201880 | - | 53.3119 | 0.9874 |
- | 2.2488 | 227000 | 43.281 | - | - |
- | 2.2499 | 227115 | - | 53.1854 | 0.9879 |
- | 2.4989 | 252250 | 43.3097 | - | - |
- | 2.4999 | 252350 | - | 53.1972 | 0.9880 |
- | 2.7491 | 277500 | 43.2376 | - | - |
- | 2.7499 | 277585 | - | 53.1833 | 0.9881 |
- | 2.9992 | 302750 | 43.2006 | - | - |
- | 2.9999 | 302820 | - | 53.1241 | 0.9881 |
- | 3.0000 | 302829 | - | - | 0.98811 |
+ | 0 | 0 | - | 297.6136 | 0.7261 |
+ | 0.0000 | 1 | 244.6862 | - | - |
+ | 0.2477 | 25000 | 161.5037 | - | - |
+ | 0.2500 | 25235 | - | 195.4624 | 0.9067 |
+ | 0.4978 | 50250 | 155.7822 | - | - |
+ | 0.5000 | 50470 | - | 189.4068 | 0.9655 |
+ | 0.7479 | 75500 | 152.7915 | - | - |
+ | 0.7500 | 75705 | - | 186.3661 | 0.9780 |
+ | 0.9981 | 100750 | 151.0411 | - | - |
+ | 1.0000 | 100940 | - | 184.6362 | 0.9829 |
+ | 1.2482 | 126000 | 149.8544 | - | - |
+ | 1.2500 | 126175 | - | 183.5648 | 0.9855 |
+ | 1.4984 | 151250 | 149.2916 | - | - |
+ | 1.5000 | 151410 | - | 182.8947 | 0.9868 |
+ | 1.7485 | 176500 | 148.7942 | - | - |
+ | 1.7499 | 176645 | - | 182.3662 | 0.9879 |
+ | 1.9987 | 201750 | 148.3459 | - | - |
+ | 1.9999 | 201880 | - | 181.9855 | 0.9885 |
+ | 2.2488 | 227000 | 148.0316 | - | - |
+ | 2.2499 | 227115 | - | 181.7683 | 0.9889 |
+ | 2.4989 | 252250 | 147.8658 | - | - |
+ | 2.4999 | 252350 | - | 181.6711 | 0.9890 |
+ | 2.7491 | 277500 | 147.9642 | - | - |
+ | 2.7499 | 277585 | - | 181.6077 | 0.9891 |
+ | 2.9992 | 302750 | 147.8874 | - | - |
+ | 2.9999 | 302820 | - | 181.6066 | 0.9891 |
+ | 3.0000 | 302829 | - | - | 0.98914 |

  </details>

  ### Environmental Impact
  Carbon emissions were measured using [CodeCarbon](https://github.com/mlco2/codecarbon).
- - **Energy Consumed**: 19.679 kWh
- - **Carbon Emitted**: 4.040 kg of CO2
- - **Hours Used**: 74.966 hours
+ - **Energy Consumed**: 30.936 kWh
+ - **Carbon Emitted**: 6.350 kg of CO2
+ - **Hours Used**: 116.388 hours

  ### Training Hardware
  - **On Cloud**: No
@@ -511,11 +510,11 @@ Carbon emissions were measured using [CodeCarbon](https://github.com/mlco2/codec

  ### Framework Versions
  - Python: 3.13.7
- - Sentence Transformers: 5.1.1
+ - Sentence Transformers: 5.1.2
  - Transformers: 4.57.1
  - PyTorch: 2.8.0+cu128
  - Accelerate: 1.10.1
- - Datasets: 3.6.0
+ - Datasets: 4.3.0
  - Tokenizers: 0.22.1

  ## Citation
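The card's usage snippets above compare full 1024-dimensional embeddings, while the training config also lists smaller `matryoshka_dims`. A minimal sketch of using a truncated embedding through the same sentence-transformers path the card demonstrates; the repository id and SMILES strings below are illustrative placeholders, not taken from this commit.

```python
from sentence_transformers import SentenceTransformer

# Placeholder repo id and SMILES; substitute the actual repository this card belongs to.
smiles = ["CCO", "CC(=O)Oc1ccccc1C(=O)O", "c1ccccc1"]

model = SentenceTransformer("Derify/ChemMRL")  # full 1024-dim embeddings
full = model.encode(smiles)

# Truncate at load time to one of the trained matryoshka_dims (e.g. 256)
small_model = SentenceTransformer("Derify/ChemMRL", truncate_dim=256)
small = small_model.encode(smiles)

print(full.shape, small.shape)               # (3, 1024) (3, 256)
print(model.similarity(full, full))          # as in the card's Direct Usage section
print(small_model.similarity(small, small))  # rankings should stay close after truncation
```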
config_chem_mrl.json CHANGED
@@ -1,11 +1,11 @@
  {
- "__version__": "0.7.4",
+ "__version__": "0.8.0",
  "embedding_pooling": "mean",
  "eval_metric": "spearman",
  "eval_similarity_fct": "tanimoto",
- "kl_div_weight": 0.5,
- "kl_temperature": 0.3,
- "last_layer_weight": 1.0,
+ "kl_div_weight": 0.0,
+ "kl_temperature": 0.0,
+ "last_layer_weight": 2.0,
  "loss_func": "tanimotosentloss",
  "model_name": "Derify/ModChemBERT-IR-BASE",
  "mrl_dimension_weights": [
@@ -28,10 +28,9 @@
  16,
  8
  ],
- "n_dims_per_step": 4,
- "n_layers_per_step": 11,
- "prior_layers_weight": 1.5,
+ "n_dims_per_step": -1,
+ "n_layers_per_step": -1,
+ "prior_layers_weight": 1.0,
  "tanimoto_similarity_loss_func": null,
- "use_2d_matryoshka": true,
- "use_query_tokenizer": false
+ "use_2d_matryoshka": true
  }
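The retrained config flips `n_dims_per_step` and `n_layers_per_step` from 4 and 11 to -1. In Matryoshka-style losses such as sentence-transformers' `MatryoshkaLoss`, -1 conventionally means every dimension (or layer) contributes on every step instead of a random subset. A small sketch of that sampling rule under this assumption; it is an illustration, not the chem-mrl implementation.

```python
import random

# Interpretation of n_dims_per_step in the spirit of MatryoshkaLoss:
# -1 -> train on every Matryoshka dimension each step; k > 0 -> sample k dimensions per step.
def dims_for_step(matryoshka_dims: list[int], n_dims_per_step: int) -> list[int]:
    if n_dims_per_step <= 0 or n_dims_per_step >= len(matryoshka_dims):
        return list(matryoshka_dims)                         # new config: -1
    return random.sample(matryoshka_dims, n_dims_per_step)   # old config: 4

dims = [1024, 512, 256, 128, 64, 32, 16, 8]
print(dims_for_step(dims, -1))  # all eight dims, every step
print(dims_for_step(dims, 4))   # a random subset of four
```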
config_sentence_transformers.json CHANGED
@@ -1,7 +1,7 @@
  {
  "model_type": "SentenceTransformer",
  "__version__": {
- "sentence_transformers": "5.1.1",
+ "sentence_transformers": "5.1.2",
  "transformers": "4.57.1",
  "pytorch": "2.8.0+cu128"
  },
configuration_modchembert.py CHANGED
@@ -37,14 +37,15 @@ class ModChemBertConfig(ModernBertConfig):
  - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
  - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
  - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
+ - "mean_seq_mha": Mean pooling over last k states + multi-head attention with CLS as query
  - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
  Defaults to "sum_mean".
  classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
- pooling strategies (cls_mha, max_seq_mha). Defaults to 4.
+ pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 4.
  classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
- pooling strategies (cls_mha, max_seq_mha). Defaults to 0.0.
- classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max pooling
- strategies (max_cls, max_seq_mha, max_seq_mean). Defaults to 8.
+ pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 0.0.
+ classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max/mean pooling
+ strategies (max_cls, max_seq_mha, mean_seq_mha, max_seq_mean). Defaults to 8.
  *args: Variable length argument list passed to ModernBertConfig.
  **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.

@@ -68,6 +69,7 @@ class ModChemBertConfig(ModernBertConfig):
  "max_cls",
  "cls_mha",
  "max_seq_mha",
+ "mean_seq_mha",
  "max_seq_mean",
  ] = "max_seq_mha",
  classifier_pooling_num_attention_heads: int = 4,
@@ -75,6 +77,25 @@ class ModChemBertConfig(ModernBertConfig):
  classifier_pooling_last_k: int = 8,
  **kwargs,
  ):
+ valid_classifier_pooling_options = [
+ "cls",
+ "mean",
+ "sum_mean",
+ "sum_sum",
+ "mean_mean",
+ "mean_sum",
+ "max_cls",
+ "cls_mha",
+ "max_seq_mha",
+ "mean_seq_mha",
+ "max_seq_mean",
+ ]
+ if classifier_pooling not in valid_classifier_pooling_options:
+ raise ValueError(
+ f"Invalid value for `classifier_pooling`, should be one of {valid_classifier_pooling_options}, "
+ f"but is {classifier_pooling}."
+ )
+
  # Pass classifier_pooling="cls" to circumvent ValueError in ModernBertConfig init
  super().__init__(*args, classifier_pooling="cls", **kwargs)
  # Override with custom value
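The new constructor guard makes an invalid `classifier_pooling` fail fast instead of surfacing later at runtime. A minimal sketch of its behavior, assuming `configuration_modchembert.py` is importable from a local checkout of this repository (the import path is an assumption, not a published package).

```python
# Assumes configuration_modchembert.py from this commit sits next to this script.
from configuration_modchembert import ModChemBertConfig

cfg = ModChemBertConfig(classifier_pooling="mean_seq_mha")  # newly accepted option
print(cfg.classifier_pooling)

try:
    ModChemBertConfig(classifier_pooling="not_a_real_option")
except ValueError as err:
    print(err)  # the added guard lists the valid choices
```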
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:4bdfc920fcc3c65314ef0cf0f5129884443d23748f09e23467015d54d5338ce4
+ oid sha256:ed6105dfe64c12207b1e1155a0d85c30cdadebcb42a6f0ea216dc36c3c28cf0c
  size 397110232
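The pointer above only records the new weight file's SHA-256 and byte size; the ~397 MB payload itself lives in Git LFS. A small check, assuming the file has been downloaded next to the script, that a local copy matches this commit.

```python
import hashlib

# Expected digest taken from the LFS pointer in this commit.
EXPECTED = "ed6105dfe64c12207b1e1155a0d85c30cdadebcb42a6f0ea216dc36c3c28cf0c"

sha = hashlib.sha256()
with open("model.safetensors", "rb") as f:  # assumed local download path
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

print(sha.hexdigest() == EXPECTED)  # True if the local file matches this commit
```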
modeling_modchembert.py CHANGED
@@ -19,9 +19,9 @@
  # Modifications include:
  # - Additional classifier_pooling options for ModChemBertForSequenceClassification
  # - sum_mean, sum_sum, mean_sum, mean_mean: from ChemLM (utilizes all hidden states)
- # - max_cls, cls_mha, max_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
+ # - max_cls, cls_mha, max_seq_mha, mean_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
  # - max_seq_mean: a merge between sum_mean and max_cls (utilizes last k hidden states)
- # - Addition of ModChemBertPoolingAttention for cls_mha and max_seq_mha pooling options
+ # - Addition of ModChemBertPoolingAttention for cls_mha, max_seq_mha, and mean_seq_mha pooling options

  import copy
  import math
@@ -122,11 +122,7 @@ class ModChemBertPoolingAttention(nn.Module):
  self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)

  self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
- self.out_drop = (
- nn.Dropout(config.attention_dropout)
- if config.attention_dropout > 0.0
- else nn.Identity()
- )
+ self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
  self.pruned_heads = set()

  def forward(
@@ -179,14 +175,9 @@ class ModChemBertModel(ModernBertPreTrainedModel):
  self.config = config
  self.embeddings = ModernBertEmbeddings(config)
  self.layers = nn.ModuleList(
- [
- ModernBertEncoderLayer(config, layer_id)
- for layer_id in range(config.num_hidden_layers)
- ]
- )
- self.final_norm = nn.LayerNorm(
- config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
+ [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
  )
+ self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  self.gradient_checkpointing = False
  self.post_init()

@@ -228,13 +219,9 @@ class ModChemBertModel(ModernBertPreTrainedModel):
  seq_len (`int`, *optional*):
  Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  """ # noqa: E501
- output_attentions = (
- output_attentions if output_attentions is not None else self.config.output_attentions
- )
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  )
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -316,25 +303,19 @@ class ModChemBertModel(ModernBertPreTrainedModel):
  )
  if all_hidden_states is not None:
  all_hidden_states = tuple(
- _pad_modernbert_output(
- inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len
- ) # type: ignore
+ _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) # type: ignore
  for hs in all_hidden_states
  )

  if not return_dict:
- return tuple(
- v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None
- )
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  return BaseModelOutput(
  last_hidden_state=hidden_states, # type: ignore
  hidden_states=all_hidden_states, # type: ignore
  attentions=all_self_attentions,
  )

- def _update_attention_mask(
- self, attention_mask: torch.Tensor, output_attentions: bool
- ) -> torch.Tensor:
+ def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
  if output_attentions:
  if self.config._attn_implementation == "sdpa":
  logger.warning_once( # type: ignore
@@ -357,16 +338,9 @@ class ModChemBertModel(ModernBertPreTrainedModel):
  distance = torch.abs(rows - rows.T)

  # Create sliding window mask (1 for positions within window, 0 outside)
- window_mask = (
- (distance <= self.config.local_attention // 2)
- .unsqueeze(0)
- .unsqueeze(0)
- .to(attention_mask.device)
- )
+ window_mask = (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
  # Combine with existing mask
- sliding_window_mask = global_attention_mask.masked_fill(
- window_mask.logical_not(), torch.finfo(self.dtype).min
- )
+ sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)

  return global_attention_mask, sliding_window_mask # type: ignore

@@ -445,28 +419,22 @@ class ModChemBertForMaskedLM(InitWeightsMixin, ModernBertPreTrainedModel):
  device = input_ids.device if input_ids is not None else inputs_embeds.device # type: ignore

  if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_len), device=device, dtype=torch.bool
- ) # type: ignore
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) # type: ignore

  if inputs_embeds is None:
  with torch.no_grad():
- input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = (
- _unpad_modernbert_input(
- inputs=input_ids, # type: ignore
- attention_mask=attention_mask, # type: ignore
- position_ids=position_ids,
- labels=labels,
- )
- )
- else:
- inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = (
- _unpad_modernbert_input(
- inputs=inputs_embeds,
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+ inputs=input_ids, # type: ignore
  attention_mask=attention_mask, # type: ignore
  position_ids=position_ids,
  labels=labels,
  )
+ else:
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+ inputs=inputs_embeds,
+ attention_mask=attention_mask, # type: ignore
+ position_ids=position_ids,
+ labels=labels,
  )

  outputs = self.model(
@@ -507,14 +475,8 @@ class ModChemBertForMaskedLM(InitWeightsMixin, ModernBertPreTrainedModel):
  loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

  if self.config._attn_implementation == "flash_attention_2":
- with (
- nullcontext()
- if self.config.repad_logits_with_grad or labels is None
- else torch.no_grad()
- ):
- logits = _pad_modernbert_output(
- inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len
- ) # type: ignore
+ with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+ logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) # type: ignore

  if not return_dict:
  output = (logits,)
@@ -537,7 +499,7 @@ class ModChemBertForSequenceClassification(InitWeightsMixin, ModernBertPreTraine
  self.config = config

  self.model = ModernBertModel(config)
- if self.config.classifier_pooling in {"cls_mha", "max_seq_mha"}:
+ if self.config.classifier_pooling in {"cls_mha", "max_seq_mha", "mean_seq_mha"}:
  self.pooling_attn = ModChemBertPoolingAttention(config=self.config)
  else:
  self.pooling_attn = None
@@ -638,9 +600,7 @@ class ModChemBertForSequenceClassification(InitWeightsMixin, ModernBertPreTraine
  if self.config.problem_type is None:
  if self.num_labels == 1:
  self.config.problem_type = "regression"
- elif self.num_labels > 1 and (
- labels.dtype == torch.long or labels.dtype == torch.int
- ):
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  self.config.problem_type = "single_label_classification"
  else:
  self.config.problem_type = "multi_label_classification"
@@ -689,6 +649,7 @@ def _pool_modchembert_output(
  - max_cls: Element-wise max pooling over the last k hidden states, then take CLS token
  - cls_mha: Multi-head attention with CLS token as query and full sequence as keys/values
  - max_seq_mha: Max pooling over last k states + multi-head attention with CLS as query
+ - mean_seq_mha: Mean pooling over last k states + multi-head attention with CLS as query
  - max_seq_mean: Max pooling over last k hidden states, then mean pooling over sequence
  - sum_mean: Sum all hidden states across layers, then mean pool over sequence
  - sum_sum: Sum all hidden states across layers, then sum pool over sequence
@@ -705,22 +666,20 @@ def _pool_modchembert_output(
  torch.Tensor: Pooled representation of shape (batch_size, hidden_size)

  Note:
- Some pooling strategies (cls_mha, max_seq_mha) require the module to have a pooling_attn
+ Some pooling strategies (cls_mha, max_seq_mha, mean_seq_mha) require the module to have a pooling_attn
  attribute containing a ModChemBertPoolingAttention instance.
  """
  config = typing.cast(ModChemBertConfig, module.config)
  if config.classifier_pooling == "cls":
  last_hidden_state = last_hidden_state[:, 0]
  elif config.classifier_pooling == "mean":
- last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(
- dim=1
- ) / attention_mask.sum(dim=1, keepdim=True)
+ last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+ dim=1, keepdim=True
+ )
  elif config.classifier_pooling == "max_cls":
  k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
  theta = torch.stack(k_hidden_states, dim=1) # (batch, k, seq_len, hidden)
- pooled_seq = torch.max(
- theta, dim=1
- ).values # Element-wise max over k -> (batch, seq_len, hidden)
+ pooled_seq = torch.max(theta, dim=1).values # Element-wise max over k -> (batch, seq_len, hidden)
  last_hidden_state = pooled_seq[:, 0, :] # (batch, hidden)
  elif config.classifier_pooling == "cls_mha":
  # Similar to max_seq_mha but without the max pooling step
@@ -731,12 +690,13 @@ def _pool_modchembert_output(
  q=q, kv=last_hidden_state, attention_mask=attention_mask
  ) # (batch, seq_len, hidden)
  last_hidden_state = torch.mean(attn_out, dim=1)
- elif config.classifier_pooling == "max_seq_mha":
+ elif config.classifier_pooling in {"max_seq_mha", "mean_seq_mha"}:
  k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
  theta = torch.stack(k_hidden_states, dim=1) # (batch, k, seq_len, hidden)
- pooled_seq = torch.max(
- theta, dim=1
- ).values # Element-wise max over k -> (batch, seq_len, hidden)
+ if config.classifier_pooling == "max_seq_mha":
+ pooled_seq = torch.max(theta, dim=1).values # Element-wise max over k -> (batch, seq_len, hidden)
+ else:
+ pooled_seq = torch.mean(theta, dim=1) # Element-wise mean over k -> (batch, seq_len, hidden)
  # Query is pooled CLS token (position 0); Keys/Values are pooled sequence
  q = pooled_seq[:, 0, :].unsqueeze(1) # (batch, 1, hidden)
  q = q.expand(-1, pooled_seq.shape[1], -1) # (batch, seq_len, hidden)
@@ -747,9 +707,7 @@
  elif config.classifier_pooling == "max_seq_mean":
  k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
  theta = torch.stack(k_hidden_states, dim=1) # (batch, k, seq_len, hidden)
- pooled_seq = torch.max(
- theta, dim=1
- ).values # Element-wise max over k -> (batch, seq_len, hidden)
+ pooled_seq = torch.max(theta, dim=1).values # Element-wise max over k -> (batch, seq_len, hidden)
  last_hidden_state = torch.mean(pooled_seq, dim=1) # Mean over sequence length
  elif config.classifier_pooling == "sum_mean":
  # ChemLM uses the mean of all hidden states
@@ -775,6 +733,7 @@


  __all__ = [
+ "ModChemBertModel",
  "ModChemBertForMaskedLM",
  "ModChemBertForSequenceClassification",
  ]
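The new `mean_seq_mha` branch mean-pools the last k hidden states and then runs the same CLS-as-query attention that `max_seq_mha` uses. A toy-tensor sketch of the shapes involved, with `torch.nn.MultiheadAttention` standing in for `ModChemBertPoolingAttention` and made-up dimensions for illustration.

```python
import torch

# Toy dimensions; the real model's hidden size and layer count differ.
batch, seq_len, hidden, num_layers, k = 2, 16, 32, 12, 8
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(num_layers + 1)]

theta = torch.stack(hidden_states[-k:], dim=1)  # (batch, k, seq_len, hidden)
pooled_seq = torch.mean(theta, dim=1)           # mean_seq_mha: mean over k -> (batch, seq_len, hidden)

q = pooled_seq[:, 0, :].unsqueeze(1)            # CLS position as the query (batch, 1, hidden)
q = q.expand(-1, pooled_seq.shape[1], -1)       # broadcast to (batch, seq_len, hidden), as in the diff

attn = torch.nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
attn_out, _ = attn(q, pooled_seq, pooled_seq)   # stand-in for ModChemBertPoolingAttention
pooled = attn_out.mean(dim=1)                   # (batch, hidden) pooled representation
print(pooled.shape)
```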
similarity_evaluation_pubchem_10m_genmol_similarity_float32_results.csv CHANGED
@@ -1,14 +1,14 @@
  epoch,steps,spearman
  0,0,0.7261446896400275
- 0.2499925700642937,25235,0.899727524994741
- 0.4999851401285874,50470,0.9599428082697957
- 0.7499777101928812,75705,0.9755030703217896
- 0.9999702802571748,100940,0.9809624466313892
- 1.2499628503214686,126175,0.9838128954121899
- 1.4999554203857621,151410,0.9854756886661312
- 1.749947990450056,176645,0.9865980464822579
- 1.9999405605143497,201880,0.9873943693937194
- 2.2499331305786434,227115,0.9878659546563734
- 2.499925700642937,252350,0.9879865870047979
- 2.749918270707231,277585,0.9881075350289332
- 2.9999108407715243,302820,0.9881056976837288
+ 0.2499925700642937,25235,0.906718265018918
+ 0.4999851401285874,50470,0.9655444741087182
+ 0.7499777101928812,75705,0.9779964615343857
+ 0.9999702802571748,100940,0.9828579834801283
+ 1.2499628503214686,126175,0.9855222540861318
+ 1.4999554203857621,151410,0.986820997047069
+ 1.749947990450056,176645,0.9879349539641308
+ 1.9999405605143497,201880,0.9885304751015874
+ 2.2499331305786434,227115,0.9889206748932795
+ 2.499925700642937,252350,0.989034117619882
+ 2.749918270707231,277585,0.9891366381020936
+ 2.9999108407715243,302820,0.9891427187036199
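The refreshed evaluation trace ends at a Spearman of roughly 0.9891 at step 302820. A quick way to inspect it, assuming the CSV has been downloaded locally under its committed filename.

```python
import pandas as pd

# Assumes the CSV sits in the working directory under its committed filename.
df = pd.read_csv("similarity_evaluation_pubchem_10m_genmol_similarity_float32_results.csv")
print(df.tail(3))              # Spearman climbs to ~0.9891 by the end of epoch 3
print(df["spearman"].max())
```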