import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel

from .configuration_viconbert import ViConBERTConfig

class MLPBlock(nn.Module):
    """Projection head: each hidden block applies Linear -> LayerNorm -> Dropout -> activation,
    with an optional residual connection, followed by a final output projection."""

    def __init__(self, input_dim, hidden_dim, output_dim,
                 num_layers=2, dropout=0.3, activation=nn.GELU, use_residual=True):
        super().__init__()
        self.use_residual = use_residual
        self.activation_fn = activation()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(num_layers):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.norms.append(nn.LayerNorm(hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.input_layer(x)
        for layer, norm, dropout in zip(self.hidden_layers, self.norms, self.dropouts):
            residual = x
            x = layer(x)
            x = norm(x)
            x = dropout(x)
            x = self.activation_fn(x)
            if self.use_residual:
                x = x + residual
        x = self.output_layer(x)
        return x

class ViConBERT(PreTrainedModel):
    config_class = ViConBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.context_encoder = AutoModel.from_pretrained(
            config.base_model, cache_dir=config.base_model_cache_dir
        )
        self.context_projection = MLPBlock(
            self.context_encoder.config.hidden_size,
            config.hidden_dim,
            config.out_dim,
            dropout=config.dropout,
            num_layers=config.num_layers,
        )
        self.context_attention = nn.MultiheadAttention(
            self.context_encoder.config.hidden_size,
            num_heads=config.num_head,
            dropout=config.dropout,
        )
        self.context_window_size = config.context_window_size
        # Learnable per-encoder-layer weights; declared here but not used in the forward pass below.
        self.context_layer_weights = nn.Parameter(
            torch.zeros(self.context_encoder.config.num_hidden_layers)
        )
        self.post_init()

    def _encode_context_attentive(self, text, target_span):
        outputs = self.context_encoder(**text)
        hidden_states = outputs[0]  # (batch, seq_len, hidden)
        # Boolean mask selecting the inclusive token span [start, end] of the target word.
        start_pos, end_pos = target_span[:, 0], target_span[:, 1]
        positions = torch.arange(hidden_states.size(1), device=hidden_states.device)
        mask = (positions >= start_pos.unsqueeze(1)) & (positions <= end_pos.unsqueeze(1))
        # Mean-pool the hidden states inside the span (clamp avoids division by zero).
        masked_states = hidden_states * mask.unsqueeze(-1)
        span_lengths = mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled_embeddings = masked_states.sum(dim=1) / span_lengths
        # The pooled span embedding attends over the full sequence; nn.MultiheadAttention
        # expects (seq, batch, hidden) when batch_first=False (the default).
        Q_value = pooled_embeddings.unsqueeze(0)   # (1, batch, hidden)
        KV_value = hidden_states.permute(1, 0, 2)  # (seq, batch, hidden)
        context_emb, _ = self.context_attention(Q_value, KV_value, KV_value)
        return context_emb

    def forward(self, context, target_span):
        context_emb = self._encode_context_attentive(context, target_span)
        return self.context_projection(context_emb.squeeze(0))
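

# Minimal usage sketch (illustrative only). It assumes ViConBERTConfig accepts the fields
# referenced above (base_model, base_model_cache_dir, hidden_dim, out_dim, dropout,
# num_layers, num_head, context_window_size) as constructor kwargs; the backbone name,
# example sentence, and span indices are placeholder values, not part of the released model.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = ViConBERTConfig(
        base_model="vinai/phobert-base",  # assumed backbone, for illustration
        base_model_cache_dir=None,
        hidden_dim=768,
        out_dim=256,
        dropout=0.1,
        num_layers=2,
        num_head=8,
        context_window_size=5,
    )
    model = ViConBERT(config).eval()

    tokenizer = AutoTokenizer.from_pretrained(config.base_model)
    batch = tokenizer(["an example sentence containing the target word"], return_tensors="pt")
    # Inclusive [start, end] token indices of the target word in the encoded sequence
    # (placeholder values; compute them from the tokenizer's offsets in real use).
    target_span = torch.tensor([[1, 1]])

    with torch.no_grad():
        embedding = model(batch, target_span)
    print(embedding.shape)  # torch.Size([1, config.out_dim])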