import copy |
|
|
import math |
|
|
import typing |
|
|
from contextlib import nullcontext |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
|
|
from transformers.modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput |
|
|
from transformers.models.modernbert.modeling_modernbert import ( |
|
|
MODERNBERT_ATTENTION_FUNCTION, |
|
|
ModernBertEmbeddings, |
|
|
ModernBertEncoderLayer, |
|
|
ModernBertModel, |
|
|
ModernBertPredictionHead, |
|
|
ModernBertPreTrainedModel, |
|
|
ModernBertRotaryEmbedding, |
|
|
_pad_modernbert_output, |
|
|
_unpad_modernbert_input, |
|
|
) |
|
|
from transformers.utils import logging |
|
|
|
|
|
from .configuration_modchembert import ModChemBertConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class InitWeightsMixin: |
|
|
def _init_weights(self, module: nn.Module): |
|
|
super()._init_weights(module) |
|
|
|
|
|
cutoff_factor = self.config.initializer_cutoff_factor |
|
|
if cutoff_factor is None: |
|
|
cutoff_factor = 3 |
|
|
|
|
|
def init_weight(module: nn.Module, std: float): |
|
|
if isinstance(module, nn.Linear): |
|
|
nn.init.trunc_normal_( |
|
|
module.weight, |
|
|
mean=0.0, |
|
|
std=std, |
|
|
a=-cutoff_factor * std, |
|
|
b=cutoff_factor * std, |
|
|
) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
stds = { |
|
|
"in": self.config.initializer_range, |
|
|
"out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers), |
|
|
"final_out": self.config.hidden_size**-0.5, |
|
|
} |
|
|
|
|
|
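        # Only handle the modules introduced by ModChemBert; the parent class initializes everything else.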
if isinstance(module, ModChemBertForMaskedLM): |
|
|
init_weight(module.decoder, stds["out"]) |
|
|
elif isinstance(module, ModChemBertForSequenceClassification): |
|
|
init_weight(module.classifier, stds["final_out"]) |
|
|
elif isinstance(module, ModChemBertPoolingAttention): |
|
|
init_weight(module.Wq, stds["in"]) |
|
|
init_weight(module.Wk, stds["in"]) |
|
|
init_weight(module.Wv, stds["in"]) |
|
|
init_weight(module.Wo, stds["out"]) |
|
|
|
|
|
|
|
|
class ModChemBertPoolingAttention(nn.Module):
    """Multi-headed attention used for classifier pooling; queries and keys/values may come from different tensors."""
|
|
|
|
|
def __init__(self, config: ModChemBertConfig): |
|
|
super().__init__() |
|
|
self.config = copy.deepcopy(config) |
|
|
|
|
|
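        # The copied config carries the pooling-specific head count and dropout used by this module.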
self.config.num_attention_heads = config.classifier_pooling_num_attention_heads |
|
|
|
|
|
self.config.attention_dropout = config.classifier_pooling_attention_dropout |
|
|
|
|
|
        if config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of classifier pooling "
                f"attention heads ({self.config.num_attention_heads})"
            )
|
|
|
|
|
        self.attention_dropout = self.config.attention_dropout
        self.num_heads = self.config.num_attention_heads
        self.head_dim = config.hidden_size // self.config.num_attention_heads
|
|
self.all_head_size = self.head_dim * self.num_heads |
|
|
self.Wq = nn.Linear(config.hidden_size, self.all_head_size, bias=config.attention_bias) |
|
|
self.Wk = nn.Linear(config.hidden_size, self.all_head_size, bias=config.attention_bias) |
|
|
self.Wv = nn.Linear(config.hidden_size, self.all_head_size, bias=config.attention_bias) |
|
|
|
|
|
|
|
|
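        # Classifier pooling attention always attends globally; (-1, -1) disables ModernBert's sliding window.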
self.local_attention = (-1, -1) |
|
|
rope_theta = config.global_rope_theta |
|
|
|
|
|
        config_copy = copy.deepcopy(self.config)
        config_copy.rope_theta = rope_theta
        self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
|
|
|
|
|
self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) |
|
|
        self.out_drop = nn.Dropout(self.config.attention_dropout) if self.config.attention_dropout > 0.0 else nn.Identity()
|
|
self.pruned_heads = set() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
q: torch.Tensor, |
|
|
kv: torch.Tensor, |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
**kwargs, |
|
|
) -> torch.Tensor: |
|
|
bs, seq_len = kv.shape[:2] |
|
|
q_proj: torch.Tensor = self.Wq(q) |
|
|
k_proj: torch.Tensor = self.Wk(kv) |
|
|
v_proj: torch.Tensor = self.Wv(kv) |
|
|
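        # Stack into the (batch_size, seq_len, 3, num_heads, head_dim) layout expected by the ModernBert attention functions.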
qkv = torch.stack( |
|
|
( |
|
|
q_proj.reshape(bs, seq_len, self.num_heads, self.head_dim), |
|
|
k_proj.reshape(bs, seq_len, self.num_heads, self.head_dim), |
|
|
v_proj.reshape(bs, seq_len, self.num_heads, self.head_dim), |
|
|
), |
|
|
dim=2, |
|
|
) |
|
|
|
|
|
device = kv.device |
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones((bs, seq_len), device=device, dtype=torch.bool) |
|
|
position_ids = torch.arange(seq_len, device=device).unsqueeze(0).long() |
|
|
|
|
|
attn_outputs = MODERNBERT_ATTENTION_FUNCTION["sdpa"]( |
|
|
self, |
|
|
qkv=qkv, |
|
|
attention_mask=_prepare_4d_attention_mask(attention_mask, kv.dtype), |
|
|
sliding_window_mask=None, |
|
|
position_ids=position_ids, |
|
|
local_attention=self.local_attention, |
|
|
bs=bs, |
|
|
dim=self.all_head_size, |
|
|
**kwargs, |
|
|
) |
|
|
hidden_states = attn_outputs[0] |
|
|
hidden_states = self.out_drop(self.Wo(hidden_states)) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class ModChemBertModel(ModernBertPreTrainedModel): |
|
|
config_class = ModChemBertConfig |
|
|
|
|
|
def __init__(self, config: ModChemBertConfig): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self.embeddings = ModernBertEmbeddings(config) |
|
|
self.layers = nn.ModuleList( |
|
|
[ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] |
|
|
) |
|
|
self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) |
|
|
self.gradient_checkpointing = False |
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.embeddings.tok_embeddings |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.embeddings.tok_embeddings = value |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor | None = None, |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
sliding_window_mask: torch.Tensor | None = None, |
|
|
position_ids: torch.LongTensor | None = None, |
|
|
inputs_embeds: torch.Tensor | None = None, |
|
|
indices: torch.Tensor | None = None, |
|
|
cu_seqlens: torch.Tensor | None = None, |
|
|
max_seqlen: int | None = None, |
|
|
batch_size: int | None = None, |
|
|
seq_len: int | None = None, |
|
|
output_attentions: bool | None = None, |
|
|
output_hidden_states: bool | None = None, |
|
|
return_dict: bool | None = None, |
|
|
) -> tuple[torch.Tensor, ...] | BaseModelOutput: |
|
|
r""" |
|
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers |
|
|
perform global attention, while the rest perform local attention. This mask is used to avoid attending to |
|
|
far-away tokens in the local attention layers when not using Flash Attention. |
|
|
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): |
|
|
Indices of the non-padding tokens in the input sequence. Used for unpadding the output. |
|
|
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): |
|
|
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. |
|
|
max_seqlen (`int`, *optional*): |
|
|
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors. |
|
|
batch_size (`int`, *optional*): |
|
|
Batch size of the input sequences. Used to pad the output tensors. |
|
|
seq_len (`int`, *optional*): |
|
|
Sequence length of the input sequences including padding tokens. Used to pad the output tensors. |
|
|
""" |
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
|
output_hidden_states = ( |
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
|
) |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
|
all_self_attentions = () if output_attentions else None |
|
|
|
|
|
self._maybe_set_compile() |
|
|
|
|
|
if input_ids is not None: |
|
|
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) |
|
|
|
|
|
if batch_size is None and seq_len is None: |
|
|
if inputs_embeds is not None: |
|
|
batch_size, seq_len = inputs_embeds.shape[:2] |
|
|
else: |
|
|
batch_size, seq_len = input_ids.shape[:2] |
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
|
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) |
|
|
|
|
|
repad = False |
|
|
if self.config._attn_implementation == "flash_attention_2": |
|
|
if indices is None and cu_seqlens is None and max_seqlen is None: |
|
|
repad = True |
|
|
if inputs_embeds is None: |
|
|
with torch.no_grad(): |
|
|
input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( |
|
|
inputs=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
) |
|
|
else: |
|
|
inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input( |
|
|
inputs=inputs_embeds, |
|
|
attention_mask=attention_mask, |
|
|
) |
|
|
else: |
|
|
if position_ids is None: |
|
|
position_ids = torch.arange(seq_len, device=device).unsqueeze(0) |
|
|
|
|
|
attention_mask, sliding_window_mask = self._update_attention_mask( |
|
|
attention_mask, |
|
|
output_attentions=output_attentions, |
|
|
) |
|
|
|
|
|
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) |
|
|
|
|
|
for encoder_layer in self.layers: |
|
|
if output_hidden_states: |
|
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
|
|
layer_outputs = encoder_layer( |
|
|
hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
sliding_window_mask=sliding_window_mask, |
|
|
position_ids=position_ids, |
|
|
cu_seqlens=cu_seqlens, |
|
|
max_seqlen=max_seqlen, |
|
|
output_attentions=output_attentions, |
|
|
) |
|
|
hidden_states = layer_outputs[0] |
|
|
if output_attentions and len(layer_outputs) > 1: |
|
|
all_self_attentions = all_self_attentions + (layer_outputs[1],) |
|
|
|
|
|
if output_hidden_states: |
|
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
|
|
hidden_states = self.final_norm(hidden_states) |
|
|
|
|
|
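        # Re-pad the unpadded (flash attention) outputs back to (batch_size, seq_len, hidden_size).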
if repad: |
|
|
hidden_states = _pad_modernbert_output( |
|
|
inputs=hidden_states, |
|
|
indices=indices, |
|
|
batch=batch_size, |
|
|
seqlen=seq_len, |
|
|
) |
|
|
if all_hidden_states is not None: |
|
|
all_hidden_states = tuple( |
|
|
_pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len) |
|
|
for hs in all_hidden_states |
|
|
) |
|
|
|
|
|
if not return_dict: |
|
|
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) |
|
|
return BaseModelOutput( |
|
|
last_hidden_state=hidden_states, |
|
|
hidden_states=all_hidden_states, |
|
|
attentions=all_self_attentions, |
|
|
) |
|
|
|
|
|
    def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
|
if output_attentions: |
|
|
if self.config._attn_implementation == "sdpa": |
|
|
logger.warning_once( |
|
|
"Outputting attentions is only supported with the 'eager' attention implementation, " |
|
|
'not with "sdpa". Falling back to `attn_implementation="eager"`.' |
|
|
) |
|
|
self.config._attn_implementation = "eager" |
|
|
elif self.config._attn_implementation != "eager": |
|
|
logger.warning_once( |
|
|
"Outputting attentions is only supported with the eager attention implementation, " |
|
|
f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.' |
|
|
" Setting `output_attentions=False`." |
|
|
) |
|
|
|
|
|
global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) |
|
|
|
|
|
|
|
|
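        # Create position indices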
rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0) |
|
|
|
|
|
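        # Calculate distance between positions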
distance = torch.abs(rows - rows.T) |
|
|
|
|
|
|
|
|
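        # Create the sliding window mask (True for positions within the local attention window)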
window_mask = (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device) |
|
|
|
|
|
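        # Combine with the padding mask: positions outside the window are set to the minimum dtype value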
sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min) |
|
|
|
|
|
return global_attention_mask, sliding_window_mask |
|
|
|
|
|
|
|
|
class ModChemBertForMaskedLM(InitWeightsMixin, ModernBertPreTrainedModel): |
|
|
config_class = ModChemBertConfig |
|
|
_tied_weights_keys = ["decoder.weight"] |
|
|
|
|
|
def __init__(self, config: ModChemBertConfig): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self.model = ModChemBertModel(config) |
|
|
self.head = ModernBertPredictionHead(config) |
|
|
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias) |
|
|
|
|
|
self.sparse_prediction = self.config.sparse_prediction |
|
|
self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index |
|
|
|
|
|
|
|
|
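        # Initialize weights and apply final processing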
self.post_init() |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.decoder |
|
|
|
|
|
def set_output_embeddings(self, new_embeddings: nn.Linear): |
|
|
self.decoder = new_embeddings |
|
|
|
|
|
@torch.compile(dynamic=True) |
|
|
def compiled_head(self, output: torch.Tensor) -> torch.Tensor: |
|
|
return self.decoder(self.head(output)) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor | None = None, |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
sliding_window_mask: torch.Tensor | None = None, |
|
|
position_ids: torch.Tensor | None = None, |
|
|
inputs_embeds: torch.Tensor | None = None, |
|
|
labels: torch.Tensor | None = None, |
|
|
indices: torch.Tensor | None = None, |
|
|
cu_seqlens: torch.Tensor | None = None, |
|
|
max_seqlen: int | None = None, |
|
|
batch_size: int | None = None, |
|
|
seq_len: int | None = None, |
|
|
output_attentions: bool | None = None, |
|
|
output_hidden_states: bool | None = None, |
|
|
return_dict: bool | None = None, |
|
|
**kwargs, |
|
|
) -> tuple[torch.Tensor] | tuple[torch.Tensor, typing.Any] | MaskedLMOutput: |
|
|
r""" |
|
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers |
|
|
perform global attention, while the rest perform local attention. This mask is used to avoid attending to |
|
|
far-away tokens in the local attention layers when not using Flash Attention. |
|
|
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): |
|
|
Indices of the non-padding tokens in the input sequence. Used for unpadding the output. |
|
|
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): |
|
|
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. |
|
|
max_seqlen (`int`, *optional*): |
|
|
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids & pad output tensors. |
|
|
batch_size (`int`, *optional*): |
|
|
Batch size of the input sequences. Used to pad the output tensors. |
|
|
seq_len (`int`, *optional*): |
|
|
Sequence length of the input sequences including padding tokens. Used to pad the output tensors. |
|
|
""" |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
self._maybe_set_compile() |
|
|
|
|
|
if self.config._attn_implementation == "flash_attention_2": |
|
|
if indices is None and cu_seqlens is None and max_seqlen is None: |
|
|
if batch_size is None and seq_len is None: |
|
|
if inputs_embeds is not None: |
|
|
batch_size, seq_len = inputs_embeds.shape[:2] |
|
|
else: |
|
|
batch_size, seq_len = input_ids.shape[:2] |
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
|
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) |
|
|
|
|
|
if inputs_embeds is None: |
|
|
with torch.no_grad(): |
|
|
input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( |
|
|
inputs=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
labels=labels, |
|
|
) |
|
|
else: |
|
|
inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input( |
|
|
inputs=inputs_embeds, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
labels=labels, |
|
|
) |
|
|
|
|
|
outputs = self.model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
sliding_window_mask=sliding_window_mask, |
|
|
position_ids=position_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
indices=indices, |
|
|
cu_seqlens=cu_seqlens, |
|
|
max_seqlen=max_seqlen, |
|
|
batch_size=batch_size, |
|
|
seq_len=seq_len, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
last_hidden_state = outputs[0] |
|
|
|
|
|
if self.sparse_prediction and labels is not None: |
|
|
|
|
|
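            # flatten labels and output first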
labels = labels.view(-1) |
|
|
last_hidden_state = last_hidden_state.view(labels.shape[0], -1) |
|
|
|
|
|
|
|
|
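            # then filter out the non-masked tokens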
mask_tokens = labels != self.sparse_pred_ignore_index |
|
|
last_hidden_state = last_hidden_state[mask_tokens] |
|
|
labels = labels[mask_tokens] |
|
|
|
|
|
logits = ( |
|
|
self.compiled_head(last_hidden_state) |
|
|
if self.config.reference_compile |
|
|
else self.decoder(self.head(last_hidden_state)) |
|
|
) |
|
|
|
|
|
loss = None |
|
|
if labels is not None: |
|
|
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs) |
|
|
|
|
|
if self.config._attn_implementation == "flash_attention_2": |
|
|
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): |
|
|
logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) |
|
|
|
|
|
if not return_dict: |
|
|
output = (logits,) |
|
|
return ((loss,) + output) if loss is not None else output |
|
|
|
|
|
return MaskedLMOutput( |
|
|
loss=loss, |
|
|
logits=typing.cast(torch.FloatTensor, logits), |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
|
|
|
class ModChemBertForSequenceClassification(InitWeightsMixin, ModernBertPreTrainedModel): |
|
|
config_class = ModChemBertConfig |
|
|
|
|
|
def __init__(self, config: ModChemBertConfig): |
|
|
super().__init__(config) |
|
|
self.num_labels = config.num_labels |
|
|
self.config = config |
|
|
|
|
|
self.model = ModernBertModel(config) |
|
|
if self.config.classifier_pooling in {"cls_mha", "max_seq_mha", "mean_seq_mha"}: |
|
|
self.pooling_attn = ModChemBertPoolingAttention(config=self.config) |
|
|
else: |
|
|
self.pooling_attn = None |
|
|
self.head = ModernBertPredictionHead(config) |
|
|
self.drop = torch.nn.Dropout(config.classifier_dropout) |
|
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
|
|
|
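        # Initialize weights and apply final processing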
self.post_init() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor | None = None, |
|
|
attention_mask: torch.Tensor | None = None, |
|
|
sliding_window_mask: torch.Tensor | None = None, |
|
|
position_ids: torch.Tensor | None = None, |
|
|
inputs_embeds: torch.Tensor | None = None, |
|
|
labels: torch.Tensor | None = None, |
|
|
indices: torch.Tensor | None = None, |
|
|
cu_seqlens: torch.Tensor | None = None, |
|
|
max_seqlen: int | None = None, |
|
|
batch_size: int | None = None, |
|
|
seq_len: int | None = None, |
|
|
output_attentions: bool | None = None, |
|
|
output_hidden_states: bool | None = None, |
|
|
return_dict: bool | None = None, |
|
|
**kwargs, |
|
|
) -> tuple[torch.Tensor] | tuple[torch.Tensor, typing.Any] | SequenceClassifierOutput: |
|
|
r""" |
|
|
sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers |
|
|
perform global attention, while the rest perform local attention. This mask is used to avoid attending to |
|
|
far-away tokens in the local attention layers when not using Flash Attention. |
|
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
|
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., |
|
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If |
|
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). |
|
|
indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*): |
|
|
Indices of the non-padding tokens in the input sequence. Used for unpadding the output. |
|
|
cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*): |
|
|
Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors. |
|
|
max_seqlen (`int`, *optional*): |
|
|
Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids & pad output tensors. |
|
|
batch_size (`int`, *optional*): |
|
|
Batch size of the input sequences. Used to pad the output tensors. |
|
|
seq_len (`int`, *optional*): |
|
|
Sequence length of the input sequences including padding tokens. Used to pad the output tensors. |
|
|
""" |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
self._maybe_set_compile() |
|
|
|
|
|
if input_ids is not None: |
|
|
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) |
|
|
|
|
|
if batch_size is None and seq_len is None: |
|
|
if inputs_embeds is not None: |
|
|
batch_size, seq_len = inputs_embeds.shape[:2] |
|
|
else: |
|
|
batch_size, seq_len = input_ids.shape[:2] |
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
|
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) |
|
|
|
|
|
|
|
|
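        # Hidden states from every layer are needed for pooling strategies that aggregate over the last k layers.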
output_hidden_states = True |
|
|
|
|
|
outputs = self.model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
sliding_window_mask=sliding_window_mask, |
|
|
position_ids=position_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
indices=indices, |
|
|
cu_seqlens=cu_seqlens, |
|
|
max_seqlen=max_seqlen, |
|
|
batch_size=batch_size, |
|
|
seq_len=seq_len, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
last_hidden_state = outputs[0] |
|
|
hidden_states = outputs[1] |
|
|
|
|
|
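        # Pool token-level representations into a single sequence-level vector (strategy set by config.classifier_pooling).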
last_hidden_state = _pool_modchembert_output( |
|
|
self, |
|
|
last_hidden_state, |
|
|
hidden_states, |
|
|
typing.cast(torch.Tensor, attention_mask), |
|
|
) |
|
|
pooled_output = self.head(last_hidden_state) |
|
|
pooled_output = self.drop(pooled_output) |
|
|
logits = self.classifier(pooled_output) |
|
|
|
|
|
loss = None |
|
|
if labels is not None: |
|
|
if self.config.problem_type is None: |
|
|
if self.num_labels == 1: |
|
|
self.config.problem_type = "regression" |
|
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): |
|
|
self.config.problem_type = "single_label_classification" |
|
|
else: |
|
|
self.config.problem_type = "multi_label_classification" |
|
|
|
|
|
if self.config.problem_type == "regression": |
|
|
loss_fct = MSELoss() |
|
|
if self.num_labels == 1: |
|
|
loss = loss_fct(logits.squeeze(), labels.squeeze()) |
|
|
else: |
|
|
loss = loss_fct(logits, labels) |
|
|
elif self.config.problem_type == "single_label_classification": |
|
|
loss_fct = CrossEntropyLoss() |
|
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
|
elif self.config.problem_type == "multi_label_classification": |
|
|
loss_fct = BCEWithLogitsLoss() |
|
|
loss = loss_fct(logits, labels) |
|
|
|
|
|
if not return_dict: |
|
|
output = (logits,) |
|
|
return ((loss,) + output) if loss is not None else output |
|
|
|
|
|
return SequenceClassifierOutput( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
|
|
|
def _pool_modchembert_output( |
|
|
module: ModChemBertForSequenceClassification, |
|
|
last_hidden_state: torch.Tensor, |
|
|
hidden_states: list[torch.Tensor], |
|
|
attention_mask: torch.Tensor, |
|
|
) -> torch.Tensor:
|
|
""" |
|
|
Apply pooling strategy to hidden states for sequence-level classification/regression tasks. |
|
|
|
|
|
This function implements various pooling strategies to aggregate sequence representations |
|
|
into a single vector for downstream classification or regression tasks. The pooling method |
|
|
is determined by the `classifier_pooling` configuration parameter. |
|
|
|
|
|
Available pooling strategies: |
|
|
- cls: Use the CLS token ([CLS]) representation from the last hidden state |
|
|
- mean: Average pooling over all tokens in the sequence (attention-weighted) |
|
|
- max_cls: Element-wise max pooling over the last k hidden states, then take CLS token |
|
|
- cls_mha: Multi-head attention with CLS token as query and full sequence as keys/values |
|
|
- max_seq_mha: Max pooling over last k states + multi-head attention with CLS as query |
|
|
- mean_seq_mha: Mean pooling over last k states + multi-head attention with CLS as query |
|
|
- max_seq_mean: Max pooling over last k hidden states, then mean pooling over sequence |
|
|
- sum_mean: Sum all hidden states across layers, then mean pool over sequence |
|
|
- sum_sum: Sum all hidden states across layers, then sum pool over sequence |
|
|
- mean_sum: Mean all hidden states across layers, then sum pool over sequence |
|
|
- mean_mean: Mean all hidden states across layers, then mean pool over sequence |
|
|
|
|
|
Args: |
|
|
module: The model instance containing configuration and pooling attention if needed |
|
|
last_hidden_state: Final layer hidden states of shape (batch_size, seq_len, hidden_size) |
|
|
hidden_states: List of hidden states from all layers, each of shape (batch_size, seq_len, hidden_size) |
|
|
attention_mask: Attention mask of shape (batch_size, seq_len) indicating valid tokens |
|
|
|
|
|
Returns: |
|
|
torch.Tensor: Pooled representation of shape (batch_size, hidden_size) |
|
|
|
|
|
Note: |
|
|
Some pooling strategies (cls_mha, max_seq_mha, mean_seq_mha) require the module to have a pooling_attn |
|
|
attribute containing a ModChemBertPoolingAttention instance. |
|
|
""" |
|
|
config = typing.cast(ModChemBertConfig, module.config) |
|
|
if config.classifier_pooling == "cls": |
|
|
last_hidden_state = last_hidden_state[:, 0] |
|
|
elif config.classifier_pooling == "mean": |
|
|
last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum( |
|
|
dim=1, keepdim=True |
|
|
) |
|
|
elif config.classifier_pooling == "max_cls": |
|
|
k_hidden_states = hidden_states[-config.classifier_pooling_last_k :] |
|
|
theta = torch.stack(k_hidden_states, dim=1) |
|
|
pooled_seq = torch.max(theta, dim=1).values |
|
|
last_hidden_state = pooled_seq[:, 0, :] |
|
|
elif config.classifier_pooling == "cls_mha": |
|
|
|
|
|
|
|
|
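        # Use the CLS token as the query (broadcast across the sequence) with the full sequence as keys/values.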
q = last_hidden_state[:, 0, :].unsqueeze(1) |
|
|
q = q.expand(-1, last_hidden_state.shape[1], -1) |
|
|
attn_out: torch.Tensor = module.pooling_attn( |
|
|
q=q, kv=last_hidden_state, attention_mask=attention_mask |
|
|
) |
|
|
last_hidden_state = torch.mean(attn_out, dim=1) |
|
|
elif config.classifier_pooling in {"max_seq_mha", "mean_seq_mha"}: |
|
|
k_hidden_states = hidden_states[-config.classifier_pooling_last_k :] |
|
|
theta = torch.stack(k_hidden_states, dim=1) |
|
|
if config.classifier_pooling == "max_seq_mha": |
|
|
pooled_seq = torch.max(theta, dim=1).values |
|
|
else: |
|
|
pooled_seq = torch.mean(theta, dim=1) |
|
|
|
|
|
q = pooled_seq[:, 0, :].unsqueeze(1) |
|
|
q = q.expand(-1, pooled_seq.shape[1], -1) |
|
|
attn_out: torch.Tensor = module.pooling_attn( |
|
|
q=q, kv=pooled_seq, attention_mask=attention_mask |
|
|
) |
|
|
last_hidden_state = torch.mean(attn_out, dim=1) |
|
|
elif config.classifier_pooling == "max_seq_mean": |
|
|
k_hidden_states = hidden_states[-config.classifier_pooling_last_k :] |
|
|
theta = torch.stack(k_hidden_states, dim=1) |
|
|
pooled_seq = torch.max(theta, dim=1).values |
|
|
last_hidden_state = torch.mean(pooled_seq, dim=1) |
|
|
elif config.classifier_pooling == "sum_mean": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
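        # Sum hidden states across layers, then mean-pool over the sequence dimension.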
all_hidden_states = torch.stack(hidden_states) |
|
|
w = torch.sum(all_hidden_states, dim=0) |
|
|
last_hidden_state = torch.mean(w, dim=1) |
|
|
elif config.classifier_pooling == "sum_sum": |
|
|
all_hidden_states = torch.stack(hidden_states) |
|
|
w = torch.sum(all_hidden_states, dim=0) |
|
|
last_hidden_state = torch.sum(w, dim=1) |
|
|
elif config.classifier_pooling == "mean_sum": |
|
|
all_hidden_states = torch.stack(hidden_states) |
|
|
w = torch.mean(all_hidden_states, dim=0) |
|
|
last_hidden_state = torch.sum(w, dim=1) |
|
|
elif config.classifier_pooling == "mean_mean": |
|
|
all_hidden_states = torch.stack(hidden_states) |
|
|
w = torch.mean(all_hidden_states, dim=0) |
|
|
last_hidden_state = torch.mean(w, dim=1) |
|
|
return last_hidden_state |
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
"ModChemBertModel", |
|
|
"ModChemBertForMaskedLM", |
|
|
"ModChemBertForSequenceClassification", |
|
|
] |