# viconbert-base / modeling_viconbert.py
# Author: tkhangg0910
# Commit df0983e (verified): "Upload folder using huggingface_hub"
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from .configuration_viconbert import ViConBERTConfig
class MLPBlock(nn.Module):
    """Feed-forward head: input projection, N hidden layers, output projection.

    Each hidden layer computes ``activation(Dropout(LayerNorm(Linear(h))))``.
    When ``use_residual`` is true, the layer's input is added back onto that
    result.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the input projection and of every hidden layer.
        output_dim: Size of the last dimension of the returned tensor.
        num_layers: Number of hidden layers; ``0`` maps input -> output only.
        dropout: Dropout probability used inside each hidden layer.
        activation: Activation *class* (not an instance). It is instantiated
            once and shared by all hidden layers.
        use_residual: Add a skip connection around each hidden layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim,
                 num_layers=2, dropout=0.3, activation=nn.GELU, use_residual=True):
        super().__init__()
        self.use_residual = use_residual
        self.activation_fn = activation()
        # Attribute names and registration order define the state_dict keys
        # that saved checkpoints rely on, so they must stay as they are.
        # Only the Linear layers draw from the RNG when constructed (LayerNorm
        # starts at ones/zeros, Dropout has no parameters). Building all hidden
        # Linears before the norms/dropouts therefore initializes identically
        # to building them interleaved.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )
        self.norms = nn.ModuleList(
            [nn.LayerNorm(hidden_dim) for _ in range(num_layers)]
        )
        self.dropouts = nn.ModuleList(
            [nn.Dropout(dropout) for _ in range(num_layers)]
        )
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project ``x`` through the block; only the last dimension changes."""
        hidden = self.input_layer(x)
        for linear, layer_norm, drop in zip(self.hidden_layers, self.norms, self.dropouts):
            update = self.activation_fn(drop(layer_norm(linear(hidden))))
            hidden = update + hidden if self.use_residual else update
        return self.output_layer(hidden)
class ViConBERT(PreTrainedModel):
    """Contextual embedding model for a target span inside a piece of text.

    The text is encoded by the pretrained transformer named by
    ``config.base_model``. The encoder states inside the target span are
    mean-pooled into a single query. That query attends over all encoder
    states, and the attended vector is projected to ``config.out_dim`` by an
    ``MLPBlock``.
    """

    config_class = ViConBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.context_encoder = AutoModel.from_pretrained(
            config.base_model, cache_dir=config.base_model_cache_dir
        )
        encoder_hidden = self.context_encoder.config.hidden_size
        self.context_projection = MLPBlock(
            encoder_hidden,
            config.hidden_dim,
            config.out_dim,
            dropout=config.dropout,
            num_layers=config.num_layers,
        )
        # Default batch_first=False: inputs are (seq, batch, embed), hence the
        # unsqueeze/permute in _encode_context_attentive.
        self.context_attention = nn.MultiheadAttention(
            encoder_hidden,
            num_heads=config.num_head,
            dropout=config.dropout,
        )
        # NOTE(review): neither context_window_size nor context_layer_weights
        # is read anywhere in this file's forward pass. They are kept because
        # the nn.Parameter is part of the state_dict (presumably present in
        # saved checkpoints) and callers may read the attribute.
        self.context_window_size = config.context_window_size
        self.context_layer_weights = nn.Parameter(
            torch.zeros(self.context_encoder.config.num_hidden_layers)
        )
        self.post_init()

    def _encode_context_attentive(self, text, target_span, key_padding_mask=None):
        """Encode ``text``; attend from the pooled target span to every token.

        Args:
            text: Mapping unpacked as keyword arguments into
                ``self.context_encoder`` (e.g. tokenizer output).
            target_span: Integer tensor whose column 0 is the start and column 1
                the *inclusive* end token index of the target in each row.
                Presumably shape (batch, 2) — TODO confirm against callers.
            key_padding_mask: Optional bool tensor, True at positions the
                attention must ignore. Presumably shape (batch, seq_len).
                ``None`` lets the query attend to every position.

        Returns:
            Tensor of shape (1, batch, hidden_size): one attended query per row.
        """
        outputs = self.context_encoder(**text)
        # assumed (batch, seq_len, hidden) last hidden state — TODO confirm
        hidden_states = outputs[0]
        start_pos, end_pos = target_span[:, 0], target_span[:, 1]
        positions = torch.arange(hidden_states.size(1), device=hidden_states.device)
        span_mask = (positions >= start_pos.unsqueeze(1)) & (positions <= end_pos.unsqueeze(1))
        masked_states = hidden_states * span_mask.unsqueeze(-1)
        # clamp avoids division by zero when a span selects no tokens
        span_lengths = span_mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled_embeddings = masked_states.sum(dim=1) / span_lengths
        # Sequence-first layout for nn.MultiheadAttention: a single query step.
        query = pooled_embeddings.unsqueeze(0)
        key_value = hidden_states.permute(1, 0, 2)
        context_emb, _ = self.context_attention(
            query, key_value, key_value, key_padding_mask=key_padding_mask
        )
        return context_emb

    def forward(self, context, target_span, mask_padding=False):
        """Return the projected contextual embedding of each target span.

        Args:
            context: Mapping passed as ``**context`` to the encoder (e.g.
                tokenizer output with ``input_ids`` / ``attention_mask``).
            target_span: Start / inclusive-end token indices of each target;
                see ``_encode_context_attentive``.
            mask_padding: When True and ``context`` has an ``attention_mask``,
                positions where the mask is 0 are excluded from the
                span-to-context attention. Padding then no longer contributes
                to an example's embedding. Defaults to False, which keeps the
                original behavior: every position, padding included, is
                attended.

        Returns:
            Tensor of shape (batch, config.out_dim).
        """
        key_padding_mask = None
        if mask_padding:
            attention_mask = context.get("attention_mask")
            if attention_mask is not None:
                # MultiheadAttention ignores positions where the mask is True.
                key_padding_mask = attention_mask == 0
        context_emb = self._encode_context_attentive(context, target_span, key_padding_mask)
        return self.context_projection(context_emb.squeeze(0))