Upload 14 files

Changed files:

- README.md (+60 -61)
- config_chem_mrl.json (+8 -9)
- config_sentence_transformers.json (+1 -1)
- configuration_modchembert.py (+25 -4)
- model.safetensors (+1 -1)
- modeling_modchembert.py (+38 -79)
- similarity_evaluation_pubchem_10m_genmol_similarity_float32_results.csv (+12 -12)
README.md CHANGED

@@ -47,14 +47,14 @@ library_name: sentence-transformers
 metrics:
 - spearman
 co2_eq_emissions:
-  emissions:
-  energy_consumed:
+  emissions: 6350.153020081601
+  energy_consumed: 30.935740629629628
   source: codecarbon
   training_type: fine-tuning
   on_cloud: false
   cpu_model: AMD Ryzen 7 3700X 8-Core Processor
   ram_total_size: 62.69887161254883
-  hours_used:
+  hours_used: 116.388
   hardware_used: 2 x NVIDIA GeForce RTX 3090
 model-index:
 - name: 'ChemMRL: SMILES Matryoshka Representation Learning Embedding Transformer'
@@ -67,7 +67,7 @@ model-index:
       type: pubchem_10m_genmol_similarity_validation
     metrics:
     - type: spearman
-      value: 0.
+      value: 0.989142152637452
       name: Spearman
   - task:
       type: semantic-similarity
@@ -77,9 +77,8 @@ model-index:
       type: pubchem_10m_genmol_similarity_test
     metrics:
     - type: spearman
-      value: 0.
+      value: 0.9891625268496924
       name: Spearman
-new_version: Derify/ChemMRL
 ---

 # ChemMRL: SMILES Matryoshka Representation Learning Embedding Transformer
@@ -146,9 +145,9 @@ print(embeddings.shape)
 # Get the similarity scores for the embeddings
 similarities = model.backbone.similarity(embeddings, embeddings)
 print(similarities)
-# tensor([[1.0000, 0.
-# [0.
-# [0.
+# tensor([[1.0000, 0.4179, 0.0165],
+#         [0.4179, 1.0000, 0.0140],
+#         [0.0165, 0.0140, 1.0000]])
 ```

 ### Direct Usage (Sentence Transformers)
@@ -186,9 +185,9 @@ print(embeddings.shape)
 # Get the similarity scores for the embeddings
 similarities = model.similarity(embeddings, embeddings)
 print(similarities)
-# tensor([[1.0000, 0.
-# [0.
-# [0.
+# tensor([[1.0000, 0.5894, 0.0326],
+#         [0.5894, 1.0000, 0.0275],
+#         [0.0326, 0.0275, 1.0000]])
 ```

 </details>
@@ -209,8 +208,8 @@ print(similarities)

 | Split | Metric | Value |
 | :------------- | :----------- | :---------- |
-| **validation** | **spearman** | **0.
-| **test** | **spearman** | **0.
+| **validation** | **spearman** | **0.98914** |
+| **test** | **spearman** | **0.98916** |

 ## Training Details

@@ -236,11 +235,11 @@ print(similarities)
 ```json
 {
     "loss": "TanimotoSentLoss",
-    "n_layers_per_step":
-    "last_layer_weight":
-    "prior_layers_weight": 1.
-    "kl_div_weight": 0.
-    "kl_temperature": 0.
+    "n_layers_per_step": -1,
+    "last_layer_weight": 2.0,
+    "prior_layers_weight": 1.0,
+    "kl_div_weight": 0.0,
+    "kl_temperature": 0.0,
     "matryoshka_dims": [
         1024,
         512,
@@ -261,7 +260,7 @@ print(similarities)
         1,
         1
     ],
-    "n_dims_per_step":
+    "n_dims_per_step": -1
 }
 ```

@@ -287,11 +286,11 @@ print(similarities)
 ```json
 {
     "loss": "TanimotoSentLoss",
-    "n_layers_per_step":
-    "last_layer_weight":
-    "prior_layers_weight": 1.
-    "kl_div_weight": 0.
-    "kl_temperature": 0.
+    "n_layers_per_step": -1,
+    "last_layer_weight": 2.0,
+    "prior_layers_weight": 1.0,
+    "kl_div_weight": 0.0,
+    "kl_temperature": 0.0,
     "matryoshka_dims": [
         1024,
         512,
@@ -312,7 +311,7 @@ print(similarities)
         1,
         1
     ],
-    "n_dims_per_step":
+    "n_dims_per_step": -1
 }
 ```

@@ -334,7 +333,7 @@ print(similarities)
 - `tf32`: True
 - `optim`: stable_adamw
 - `optim_args`: decouple_lr=True,max_lr=8.0e-6
-- `
+- `gradient_checkpointing`: True
 - `eval_on_start`: True

 #### All Hyperparameters
@@ -416,7 +415,7 @@ print(similarities)
 - `ddp_find_unused_parameters`: None
 - `ddp_bucket_cap_mb`: None
 - `ddp_broadcast_buffers`: False
-- `dataloader_pin_memory`:
+- `dataloader_pin_memory`: True
 - `dataloader_persistent_workers`: False
 - `skip_memory_metrics`: True
 - `use_legacy_prediction_loop`: False
@@ -427,7 +426,7 @@ print(similarities)
 - `hub_private_repo`: None
 - `hub_always_push`: False
 - `hub_revision`: None
-- `gradient_checkpointing`:
+- `gradient_checkpointing`: True
 - `gradient_checkpointing_kwargs`: None
 - `include_inputs_for_metrics`: False
 - `include_for_metrics`: []
@@ -467,41 +466,41 @@ print(similarities)

 | Epoch | Step | Training Loss | pubchem 10m genmol similarity loss | pubchem_10m_genmol_similarity_spearman |
 | :----: | :----: | :-----------: | :--------------------------------: | :------------------------------------: |
-| 0 | 0 | - |
-| 0.0000 | 1 |
-| 0.2477 | 25000 |
-| 0.2500 | 25235 | - |
-| 0.4978 | 50250 |
-| 0.5000 | 50470 | - |
-| 0.7479 | 75500 |
-| 0.7500 | 75705 | - |
-| 0.9981 | 100750 |
-| 1.0000 | 100940 | - |
-| 1.2482 | 126000 |
-| 1.2500 | 126175 | - |
-| 1.4984 | 151250 |
-| 1.5000 | 151410 | - |
-| 1.7485 | 176500 |
-| 1.7499 | 176645 | - |
-| 1.9987 | 201750 |
-| 1.9999 | 201880 | - |
-| 2.2488 | 227000 |
-| 2.2499 | 227115 | - |
-| 2.4989 | 252250 |
-| 2.4999 | 252350 | - |
-| 2.7491 | 277500 |
-| 2.7499 | 277585 | - |
-| 2.9992 | 302750 |
-| 2.9999 | 302820 | - |
-| 3.0000 | 302829 | - | - | 0.
+| 0 | 0 | - | 297.6136 | 0.7261 |
+| 0.0000 | 1 | 244.6862 | - | - |
+| 0.2477 | 25000 | 161.5037 | - | - |
+| 0.2500 | 25235 | - | 195.4624 | 0.9067 |
+| 0.4978 | 50250 | 155.7822 | - | - |
+| 0.5000 | 50470 | - | 189.4068 | 0.9655 |
+| 0.7479 | 75500 | 152.7915 | - | - |
+| 0.7500 | 75705 | - | 186.3661 | 0.9780 |
+| 0.9981 | 100750 | 151.0411 | - | - |
+| 1.0000 | 100940 | - | 184.6362 | 0.9829 |
+| 1.2482 | 126000 | 149.8544 | - | - |
+| 1.2500 | 126175 | - | 183.5648 | 0.9855 |
+| 1.4984 | 151250 | 149.2916 | - | - |
+| 1.5000 | 151410 | - | 182.8947 | 0.9868 |
+| 1.7485 | 176500 | 148.7942 | - | - |
+| 1.7499 | 176645 | - | 182.3662 | 0.9879 |
+| 1.9987 | 201750 | 148.3459 | - | - |
+| 1.9999 | 201880 | - | 181.9855 | 0.9885 |
+| 2.2488 | 227000 | 148.0316 | - | - |
+| 2.2499 | 227115 | - | 181.7683 | 0.9889 |
+| 2.4989 | 252250 | 147.8658 | - | - |
+| 2.4999 | 252350 | - | 181.6711 | 0.9890 |
+| 2.7491 | 277500 | 147.9642 | - | - |
+| 2.7499 | 277585 | - | 181.6077 | 0.9891 |
+| 2.9992 | 302750 | 147.8874 | - | - |
+| 2.9999 | 302820 | - | 181.6066 | 0.9891 |
+| 3.0000 | 302829 | - | - | 0.98914 |

 </details>

 ### Environmental Impact
 Carbon emissions were measured using [CodeCarbon](https://github.com/mlco2/codecarbon).
-- **Energy Consumed**:
-- **Carbon Emitted**:
-- **Hours Used**:
+- **Energy Consumed**: 30.936 kWh
+- **Carbon Emitted**: 6.350 kg of CO2
+- **Hours Used**: 116.388 hours

 ### Training Hardware
 - **On Cloud**: No
@@ -511,11 +510,11 @@ Carbon emissions were measured using [CodeCarbon](https://github.com/mlco2/codecarbon).

 ### Framework Versions
 - Python: 3.13.7
-- Sentence Transformers: 5.1.
+- Sentence Transformers: 5.1.2
 - Transformers: 4.57.1
 - PyTorch: 2.8.0+cu128
 - Accelerate: 1.10.1
-- Datasets: 3.
+- Datasets: 4.3.0
 - Tokenizers: 0.22.1

 ## Citation
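The updated card advertises Matryoshka dimensions from 1024 down to 8, so the embeddings can be truncated at encode time. Below is a minimal usage sketch with the sentence-transformers API shown in the card; the repository id is a placeholder (this commit page does not name the repo), and 256 is assumed to be one of the smaller Matryoshka dimensions.

```python
from sentence_transformers import SentenceTransformer

# Placeholder repo id; substitute the actual ChemMRL model repository.
model = SentenceTransformer("your-org/chem-mrl-model", truncate_dim=256)

smiles = [
    "CCO",                    # ethanol
    "CC(=O)Oc1ccccc1C(=O)O",  # aspirin
]
embeddings = model.encode(smiles)
print(embeddings.shape)  # (2, 256) after Matryoshka truncation

# Similarity matrix using the model's configured similarity function
print(model.similarity(embeddings, embeddings))
```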
config_chem_mrl.json CHANGED

@@ -1,11 +1,11 @@
 {
-    "__version__": "0.
+    "__version__": "0.8.0",
     "embedding_pooling": "mean",
     "eval_metric": "spearman",
     "eval_similarity_fct": "tanimoto",
-    "kl_div_weight": 0.
-    "kl_temperature": 0.
-    "last_layer_weight":
+    "kl_div_weight": 0.0,
+    "kl_temperature": 0.0,
+    "last_layer_weight": 2.0,
     "loss_func": "tanimotosentloss",
     "model_name": "Derify/ModChemBERT-IR-BASE",
     "mrl_dimension_weights": [
@@ -28,10 +28,9 @@
         16,
         8
     ],
-    "n_dims_per_step":
-    "n_layers_per_step":
-    "prior_layers_weight": 1.
+    "n_dims_per_step": -1,
+    "n_layers_per_step": -1,
+    "prior_layers_weight": 1.0,
     "tanimoto_similarity_loss_func": null,
-    "use_2d_matryoshka": true
-    "use_query_tokenizer": false
+    "use_2d_matryoshka": true
 }
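These fields line up with the arguments of a 2D Matryoshka loss in sentence-transformers (use_2d_matryoshka, per-layer and per-dimension weights, n_layers_per_step / n_dims_per_step of -1 meaning "use all"). A sketch of that mapping is shown below under stated assumptions: the Tanimoto base loss ("tanimotosentloss") comes from the chem-mrl project, so CoSENTLoss is used here purely as a stand-in, and loading the base model named in the config directly with SentenceTransformer is illustrative rather than the project's exact training setup.

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, Matryoshka2dLoss

model = SentenceTransformer("Derify/ModChemBERT-IR-BASE")  # base model named in the config

# Stand-in for the chem-mrl TanimotoSentLoss referenced by "loss_func".
base_loss = CoSENTLoss(model)

loss = Matryoshka2dLoss(
    model,
    base_loss,
    matryoshka_dims=[1024, 512, 256, 128, 64, 32, 16, 8],
    matryoshka_weights=[1, 1, 1, 1, 1, 1, 1, 1],
    n_layers_per_step=-1,     # -1: train on every layer, as in the config
    n_dims_per_step=-1,       # -1: train on every Matryoshka dimension
    last_layer_weight=2.0,
    prior_layers_weight=1.0,
    kl_div_weight=0.0,
    kl_temperature=0.0,
)
```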
config_sentence_transformers.json CHANGED

@@ -1,7 +1,7 @@
 {
   "model_type": "SentenceTransformer",
   "__version__": {
-    "sentence_transformers": "5.1.
+    "sentence_transformers": "5.1.2",
     "transformers": "4.57.1",
     "pytorch": "2.8.0+cu128"
   },
configuration_modchembert.py CHANGED

@@ -37,14 +37,15 @@ class ModChemBertConfig(ModernBertConfig):
             - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
             - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
             - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
+            - "mean_seq_mha": Mean pooling over last k states + multi-head attention with CLS as query
             - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
             Defaults to "sum_mean".
         classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
-            pooling strategies (cls_mha, max_seq_mha). Defaults to 4.
+            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 4.
         classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
-            pooling strategies (cls_mha, max_seq_mha). Defaults to 0.0.
-        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max pooling
-            strategies (max_cls, max_seq_mha, max_seq_mean). Defaults to 8.
+            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 0.0.
+        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max/mean pooling
+            strategies (max_cls, max_seq_mha, mean_seq_mha, max_seq_mean). Defaults to 8.
         *args: Variable length argument list passed to ModernBertConfig.
         **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.

@@ -68,6 +69,7 @@ class ModChemBertConfig(ModernBertConfig):
             "max_cls",
             "cls_mha",
             "max_seq_mha",
+            "mean_seq_mha",
             "max_seq_mean",
         ] = "max_seq_mha",
         classifier_pooling_num_attention_heads: int = 4,
@@ -75,6 +77,25 @@ class ModChemBertConfig(ModernBertConfig):
         classifier_pooling_last_k: int = 8,
         **kwargs,
     ):
+        valid_classifier_pooling_options = [
+            "cls",
+            "mean",
+            "sum_mean",
+            "sum_sum",
+            "mean_mean",
+            "mean_sum",
+            "max_cls",
+            "cls_mha",
+            "max_seq_mha",
+            "mean_seq_mha",
+            "max_seq_mean",
+        ]
+        if classifier_pooling not in valid_classifier_pooling_options:
+            raise ValueError(
+                f"Invalid value for `classifier_pooling`, should be one of {valid_classifier_pooling_options}, "
+                f"but is {classifier_pooling}."
+            )
+
         # Pass classifier_pooling="cls" to circumvent ValueError in ModernBertConfig init
         super().__init__(*args, classifier_pooling="cls", **kwargs)
         # Override with custom value
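A short sketch of how the new validation behaves, assuming the module is importable directly from this repository as configuration_modchembert (when loaded via AutoConfig with trust_remote_code, the same check runs inside __init__):

```python
from configuration_modchembert import ModChemBertConfig

# The newly added "mean_seq_mha" strategy is accepted and stored on the config.
config = ModChemBertConfig(
    classifier_pooling="mean_seq_mha",
    classifier_pooling_num_attention_heads=4,
    classifier_pooling_last_k=8,
)
print(config.classifier_pooling)  # mean_seq_mha

# An unknown strategy now fails fast instead of being silently carried along.
try:
    ModChemBertConfig(classifier_pooling="median_pool")
except ValueError as err:
    print(err)
```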
model.safetensors CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ed6105dfe64c12207b1e1155a0d85c30cdadebcb42a6f0ea216dc36c3c28cf0c
 size 397110232
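The pointer file records only the Git LFS object id, which is the SHA-256 of the file contents. A standard-library sketch for confirming that a downloaded model.safetensors matches the new pointer:

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file in 1 MiB chunks and return its hex SHA-256 digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "ed6105dfe64c12207b1e1155a0d85c30cdadebcb42a6f0ea216dc36c3c28cf0c"
assert sha256_of("model.safetensors") == expected, "checksum mismatch"
```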
modeling_modchembert.py CHANGED

@@ -19,9 +19,9 @@
 # Modifications include:
 # - Additional classifier_pooling options for ModChemBertForSequenceClassification
 #   - sum_mean, sum_sum, mean_sum, mean_mean: from ChemLM (utilizes all hidden states)
-#   - max_cls, cls_mha, max_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
+#   - max_cls, cls_mha, max_seq_mha, mean_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
 #   - max_seq_mean: a merge between sum_mean and max_cls (utilizes last k hidden states)
-# - Addition of ModChemBertPoolingAttention for cls_mha and
+# - Addition of ModChemBertPoolingAttention for cls_mha, max_seq_mha, and mean_seq_mha pooling options

 import copy
 import math
@@ -122,11 +122,7 @@ class ModChemBertPoolingAttention(nn.Module):
         self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)

         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
-        self.out_drop = (
-            nn.Dropout(config.attention_dropout)
-            if config.attention_dropout > 0.0
-            else nn.Identity()
-        )
+        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
         self.pruned_heads = set()

     def forward(
@@ -179,14 +175,9 @@ class ModChemBertModel(ModernBertPreTrainedModel):
         self.config = config
         self.embeddings = ModernBertEmbeddings(config)
         self.layers = nn.ModuleList(
-            [
-                ModernBertEncoderLayer(config, layer_id)
-                for layer_id in range(config.num_hidden_layers)
-            ]
-        )
-        self.final_norm = nn.LayerNorm(
-            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
+            [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
         )
+        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
         self.gradient_checkpointing = False
         self.post_init()

@@ -228,13 +219,9 @@ class ModChemBertModel(ModernBertPreTrainedModel):
             seq_len (`int`, *optional*):
                 Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
         """  # noqa: E501
-        output_attentions =
-            output_attentions if output_attentions is not None else self.config.output_attentions
-        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -316,25 +303,19 @@ class ModChemBertModel(ModernBertPreTrainedModel):
             )
         if all_hidden_states is not None:
             all_hidden_states = tuple(
-                _pad_modernbert_output(
-                    inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len
-                )  # type: ignore
+                _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)  # type: ignore
                 for hs in all_hidden_states
             )

         if not return_dict:
-            return tuple(
-                v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None
-            )
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
         return BaseModelOutput(
             last_hidden_state=hidden_states,  # type: ignore
             hidden_states=all_hidden_states,  # type: ignore
             attentions=all_self_attentions,
         )

-    def _update_attention_mask(
-        self, attention_mask: torch.Tensor, output_attentions: bool
-    ) -> torch.Tensor:
+    def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
         if output_attentions:
             if self.config._attn_implementation == "sdpa":
                 logger.warning_once(  # type: ignore
@@ -357,16 +338,9 @@ class ModChemBertModel(ModernBertPreTrainedModel):
         distance = torch.abs(rows - rows.T)

         # Create sliding window mask (1 for positions within window, 0 outside)
-        window_mask = (
-            (distance <= self.config.local_attention // 2)
-            .unsqueeze(0)
-            .unsqueeze(0)
-            .to(attention_mask.device)
-        )
+        window_mask = (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
         # Combine with existing mask
-        sliding_window_mask = global_attention_mask.masked_fill(
-            window_mask.logical_not(), torch.finfo(self.dtype).min
-        )
+        sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)

         return global_attention_mask, sliding_window_mask  # type: ignore

@@ -445,28 +419,22 @@ class ModChemBertForMaskedLM(InitWeightsMixin, ModernBertPreTrainedModel):
         device = input_ids.device if input_ids is not None else inputs_embeds.device  # type: ignore

         if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_len), device=device, dtype=torch.bool
-            )  # type: ignore
+            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)  # type: ignore

         if inputs_embeds is None:
             with torch.no_grad():
-                input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = (
-
-                        inputs=input_ids,  # type: ignore
-                        attention_mask=attention_mask,  # type: ignore
-                        position_ids=position_ids,
-                        labels=labels,
-                    )
-                )
-        else:
-            inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = (
-                _unpad_modernbert_input(
-                    inputs=inputs_embeds,
+                input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                    inputs=input_ids,  # type: ignore
                     attention_mask=attention_mask,  # type: ignore
                     position_ids=position_ids,
                     labels=labels,
                 )
+        else:
+            inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
+                inputs=inputs_embeds,
+                attention_mask=attention_mask,  # type: ignore
+                position_ids=position_ids,
+                labels=labels,
             )

         outputs = self.model(
@@ -507,14 +475,8 @@ class ModChemBertForMaskedLM(InitWeightsMixin, ModernBertPreTrainedModel):
         loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

         if self.config._attn_implementation == "flash_attention_2":
-            with (
-
-                if self.config.repad_logits_with_grad or labels is None
-                else torch.no_grad()
-            ):
-                logits = _pad_modernbert_output(
-                    inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len
-                )  # type: ignore
+            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)  # type: ignore

         if not return_dict:
             output = (logits,)
@@ -537,7 +499,7 @@ class ModChemBertForSequenceClassification(InitWeightsMixin, ModernBertPreTraine
         self.config = config

         self.model = ModernBertModel(config)
-        if self.config.classifier_pooling in {"cls_mha", "max_seq_mha"}:
+        if self.config.classifier_pooling in {"cls_mha", "max_seq_mha", "mean_seq_mha"}:
             self.pooling_attn = ModChemBertPoolingAttention(config=self.config)
         else:
             self.pooling_attn = None
@@ -638,9 +600,7 @@ class ModChemBertForSequenceClassification(InitWeightsMixin, ModernBertPreTraine
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (
-                    labels.dtype == torch.long or labels.dtype == torch.int
-                ):
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                     self.config.problem_type = "single_label_classification"
                 else:
                     self.config.problem_type = "multi_label_classification"
@@ -689,6 +649,7 @@ def _pool_modchembert_output(
     - max_cls: Element-wise max pooling over the last k hidden states, then take CLS token
     - cls_mha: Multi-head attention with CLS token as query and full sequence as keys/values
     - max_seq_mha: Max pooling over last k states + multi-head attention with CLS as query
+    - mean_seq_mha: Mean pooling over last k states + multi-head attention with CLS as query
    - max_seq_mean: Max pooling over last k hidden states, then mean pooling over sequence
    - sum_mean: Sum all hidden states across layers, then mean pool over sequence
    - sum_sum: Sum all hidden states across layers, then sum pool over sequence
@@ -705,22 +666,20 @@ def _pool_modchembert_output(
         torch.Tensor: Pooled representation of shape (batch_size, hidden_size)

     Note:
-        Some pooling strategies (cls_mha, max_seq_mha) require the module to have a pooling_attn
+        Some pooling strategies (cls_mha, max_seq_mha, mean_seq_mha) require the module to have a pooling_attn
         attribute containing a ModChemBertPoolingAttention instance.
     """
     config = typing.cast(ModChemBertConfig, module.config)
     if config.classifier_pooling == "cls":
         last_hidden_state = last_hidden_state[:, 0]
     elif config.classifier_pooling == "mean":
-        last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(
-            dim=1
-        )
+        last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+            dim=1, keepdim=True
+        )
     elif config.classifier_pooling == "max_cls":
         k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
         theta = torch.stack(k_hidden_states, dim=1)  # (batch, k, seq_len, hidden)
-        pooled_seq = torch.max(
-            theta, dim=1
-        ).values  # Element-wise max over k -> (batch, seq_len, hidden)
+        pooled_seq = torch.max(theta, dim=1).values  # Element-wise max over k -> (batch, seq_len, hidden)
         last_hidden_state = pooled_seq[:, 0, :]  # (batch, hidden)
     elif config.classifier_pooling == "cls_mha":
         # Similar to max_seq_mha but without the max pooling step
@@ -731,12 +690,13 @@ def _pool_modchembert_output(
             q=q, kv=last_hidden_state, attention_mask=attention_mask
         )  # (batch, seq_len, hidden)
         last_hidden_state = torch.mean(attn_out, dim=1)
-    elif config.classifier_pooling
+    elif config.classifier_pooling in {"max_seq_mha", "mean_seq_mha"}:
         k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
         theta = torch.stack(k_hidden_states, dim=1)  # (batch, k, seq_len, hidden)
-
-            theta, dim=1
-
+        if config.classifier_pooling == "max_seq_mha":
+            pooled_seq = torch.max(theta, dim=1).values  # Element-wise max over k -> (batch, seq_len, hidden)
+        else:
+            pooled_seq = torch.mean(theta, dim=1)  # Element-wise mean over k -> (batch, seq_len, hidden)
         # Query is pooled CLS token (position 0); Keys/Values are pooled sequence
         q = pooled_seq[:, 0, :].unsqueeze(1)  # (batch, 1, hidden)
         q = q.expand(-1, pooled_seq.shape[1], -1)  # (batch, seq_len, hidden)
@@ -747,9 +707,7 @@ def _pool_modchembert_output(
     elif config.classifier_pooling == "max_seq_mean":
         k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
         theta = torch.stack(k_hidden_states, dim=1)  # (batch, k, seq_len, hidden)
-        pooled_seq = torch.max(
-            theta, dim=1
-        ).values  # Element-wise max over k -> (batch, seq_len, hidden)
+        pooled_seq = torch.max(theta, dim=1).values  # Element-wise max over k -> (batch, seq_len, hidden)
         last_hidden_state = torch.mean(pooled_seq, dim=1)  # Mean over sequence length
     elif config.classifier_pooling == "sum_mean":
         # ChemLM uses the mean of all hidden states
@@ -775,6 +733,7 @@ def _pool_modchembert_output(


 __all__ = [
+    "ModChemBertModel",
     "ModChemBertForMaskedLM",
     "ModChemBertForSequenceClassification",
 ]
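Conceptually, the new mean_seq_mha path reduces to: stack the last k hidden states, average them element-wise, then attend over the pooled sequence with the pooled CLS token as the query, and mean-pool the attention output. The self-contained sketch below illustrates that flow with a plain torch.nn.MultiheadAttention standing in for ModChemBertPoolingAttention; names and shapes are illustrative, not the repository's exact implementation.

```python
import torch
import torch.nn as nn

def mean_seq_mha_pool(
    hidden_states: list[torch.Tensor],  # per-layer tensors of shape (batch, seq_len, hidden)
    mha: nn.MultiheadAttention,
    last_k: int = 8,
) -> torch.Tensor:
    theta = torch.stack(hidden_states[-last_k:], dim=1)            # (batch, k, seq_len, hidden)
    pooled_seq = theta.mean(dim=1)                                  # element-wise mean over k
    q = pooled_seq[:, :1, :].expand(-1, pooled_seq.shape[1], -1)    # pooled CLS token broadcast as query
    attn_out, _ = mha(q, pooled_seq, pooled_seq)                    # (batch, seq_len, hidden)
    return attn_out.mean(dim=1)                                     # (batch, hidden)

# Usage with random tensors
batch, seq_len, hidden, layers = 2, 16, 64, 12
states = [torch.randn(batch, seq_len, hidden) for _ in range(layers)]
mha = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
print(mean_seq_mha_pool(states, mha).shape)  # torch.Size([2, 64])
```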
similarity_evaluation_pubchem_10m_genmol_similarity_float32_results.csv CHANGED

@@ -1,14 +1,14 @@
 epoch,steps,spearman
 0,0,0.7261446896400275
-0.2499925700642937,25235,0.
-0.4999851401285874,50470,0.
-0.7499777101928812,75705,0.
-0.9999702802571748,100940,0.
-1.2499628503214686,126175,0.
-1.4999554203857621,151410,0.
-1.749947990450056,176645,0.
-1.9999405605143497,201880,0.
-2.2499331305786434,227115,0.
-2.499925700642937,252350,0.
-2.749918270707231,277585,0.
-2.9999108407715243,302820,0.
+0.2499925700642937,25235,0.906718265018918
+0.4999851401285874,50470,0.9655444741087182
+0.7499777101928812,75705,0.9779964615343857
+0.9999702802571748,100940,0.9828579834801283
+1.2499628503214686,126175,0.9855222540861318
+1.4999554203857621,151410,0.986820997047069
+1.749947990450056,176645,0.9879349539641308
+1.9999405605143497,201880,0.9885304751015874
+2.2499331305786434,227115,0.9889206748932795
+2.499925700642937,252350,0.989034117619882
+2.749918270707231,277585,0.9891366381020936
+2.9999108407715243,302820,0.9891427187036199