import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel

from .configuration_viconbert import ViConBERTConfig


class MLPBlock(nn.Module):
    """Projection head: input/output linear layers around a stack of
    Linear -> LayerNorm -> Dropout -> GELU blocks with optional residuals."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2,
                 dropout=0.3, activation=nn.GELU, use_residual=True):
        super().__init__()
        self.use_residual = use_residual
        self.activation_fn = activation()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(num_layers):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.norms.append(nn.LayerNorm(hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.input_layer(x)
        for layer, norm, dropout in zip(self.hidden_layers, self.norms, self.dropouts):
            residual = x
            x = layer(x)
            x = norm(x)
            x = dropout(x)
            x = self.activation_fn(x)
            if self.use_residual:
                x = x + residual
        x = self.output_layer(x)
        return x


class ViConBERT(PreTrainedModel):
    config_class = ViConBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.context_encoder = AutoModel.from_pretrained(
            config.base_model, cache_dir=config.base_model_cache_dir
        )
        self.context_projection = MLPBlock(
            self.context_encoder.config.hidden_size,
            config.hidden_dim,
            config.out_dim,
            dropout=config.dropout,
            num_layers=config.num_layers,
        )
        self.context_attention = nn.MultiheadAttention(
            self.context_encoder.config.hidden_size,
            num_heads=config.num_head,
            dropout=config.dropout,
        )
        # Stored from the config; not used in the forward pass below.
        self.context_window_size = config.context_window_size
        self.context_layer_weights = nn.Parameter(
            torch.zeros(self.context_encoder.config.num_hidden_layers)
        )
        self.post_init()

    def _encode_context_attentive(self, text, target_span):
        """Mean-pool the encoder states over the target span, then use the
        pooled vector as a query attending over the full sequence."""
        outputs = self.context_encoder(**text)
        hidden_states = outputs[0]  # (batch, seq_len, hidden)

        # Boolean mask selecting the target span [start, end] (inclusive).
        start_pos, end_pos = target_span[:, 0], target_span[:, 1]
        positions = torch.arange(hidden_states.size(1), device=hidden_states.device)
        mask = (positions >= start_pos.unsqueeze(1)) & (positions <= end_pos.unsqueeze(1))

        # Mean-pool hidden states over the span (clamp avoids division by zero).
        masked_states = hidden_states * mask.unsqueeze(-1)
        span_lengths = mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled_embeddings = masked_states.sum(dim=1) / span_lengths

        # nn.MultiheadAttention defaults to a (seq, batch, hidden) layout:
        # the pooled span vector is the query, the full sequence is keys/values.
        Q_value = pooled_embeddings.unsqueeze(0)   # (1, batch, hidden)
        KV_value = hidden_states.permute(1, 0, 2)  # (seq, batch, hidden)
        context_emb, _ = self.context_attention(Q_value, KV_value, KV_value)
        return context_emb

    def forward(self, context, target_span):
        context_emb = self._encode_context_attentive(context, target_span)
        return self.context_projection(context_emb.squeeze(0))
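

# Minimal usage sketch (illustrative, not part of the module): load a trained
# checkpoint and embed a target word in its context. The checkpoint path, the
# example sentence, and the [start, end] token indices are assumptions made up
# for this example; spans are inclusive token positions produced by the same
# tokenizer as `config.base_model`.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = ViConBERT.from_pretrained("path/to/viconbert-checkpoint")  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(model.config.base_model)
    model.eval()

    encoded = tokenizer("Con ngựa đá con ngựa đá", return_tensors="pt")
    target_span = torch.tensor([[2, 2]])  # assumed [start, end] indices of the target word

    with torch.no_grad():
        embedding = model(encoded, target_span)  # shape: (batch, config.out_dim)
    print(embedding.shape)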