from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput


class CustomQAModelConfig(PretrainedConfig):
    model_type = "modernbert"

    def __init__(self, base_model_name_or_path="answerdotai/ModernBERT-base", **kwargs):
        self.base_model_name_or_path = base_model_name_or_path
        super().__init__(**kwargs)


class CustomQAModel(PreTrainedModel):
    config_class = CustomQAModelConfig

    def __init__(self, config):
        super().__init__(config)
        # Load the pretrained encoder and add a 2-way span-prediction head
        # (one logit each for the answer's start and end token positions).
        self.base = AutoModel.from_pretrained(config.base_model_name_or_path)
        hidden_size = self.base.config.hidden_size
        self.qa_outputs = nn.Linear(hidden_size, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        start_positions=None,
        end_positions=None,
    ):
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = outputs.last_hidden_state

        # Project each token's hidden state to two logits, then split them
        # into per-token start and end scores of shape (batch, seq_len).
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        loss = None
        if start_positions is not None and end_positions is not None:
            # Clamp labels that fall outside the sequence length, then average
            # the cross-entropy losses for the start and end positions.
            start_positions = start_positions.clamp(0, start_logits.size(1) - 1)
            end_positions = end_positions.clamp(0, end_logits.size(1) - 1)
            start_loss = self.loss_fn(start_logits, start_positions)
            end_loss = self.loss_fn(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
        )
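

if __name__ == "__main__":
    # Minimal usage sketch, added as an illustration and not part of the
    # original code: build the config and model, tokenize a made-up
    # question/context pair, and run a forward pass to get span logits.
    from transformers import AutoTokenizer

    config = CustomQAModelConfig()
    model = CustomQAModel(config)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

    inputs = tokenizer(
        "Who wrote the play?",
        "The play was written by William Shakespeare.",
        return_tensors="pt",
    )
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    # Each logits tensor has shape (batch_size, sequence_length); without
    # start_positions/end_positions, outputs.loss is None.
    print(outputs.start_logits.shape, outputs.end_logits.shape)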