File size: 3,185 Bytes
df0983e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from .configuration_viconbert import ViConBERTConfig


class MLPBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim,
                 num_layers=2, dropout=0.3, activation=nn.GELU, use_residual=True):
        super().__init__()
        self.use_residual = use_residual
        self.activation_fn = activation()

        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.hidden_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(num_layers):
            self.hidden_layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.norms.append(nn.LayerNorm(hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))
        self.output_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.input_layer(x)
        for layer, norm, dropout in zip(self.hidden_layers, self.norms, self.dropouts):
            residual = x
            x = layer(x)
            x = norm(x)
            x = dropout(x)
            x = self.activation_fn(x)
            if self.use_residual:
                x = x + residual
        x = self.output_layer(x)
        return x


class ViConBERT(PreTrainedModel):
    """Embeds a target token span using attention over its surrounding context.

    A backbone encoder (``config.base_model``) encodes the tokenized context.
    The token states inside the target span are mean-pooled into one query
    vector per example. That query attends over every token state of the same
    sequence through ``nn.MultiheadAttention``. The attended vector is then
    projected by an ``MLPBlock`` to ``config.out_dim``.
    """

    config_class = ViConBERTConfig

    def __init__(self, config):
        super().__init__(config)
        # AutoModel.from_pretrained runs on every construction of this module.
        self.context_encoder = AutoModel.from_pretrained(
            config.base_model, cache_dir=config.base_model_cache_dir
        )
        encoder_hidden = self.context_encoder.config.hidden_size
        self.context_projection = MLPBlock(
            encoder_hidden,
            config.hidden_dim,
            config.out_dim,
            dropout=config.dropout,
            num_layers=config.num_layers
        )
        # batch_first is left at its default (False), so inputs are sequence-first.
        # nn.MultiheadAttention requires encoder_hidden to be divisible by num_head.
        self.context_attention = nn.MultiheadAttention(
            encoder_hidden,
            num_heads=config.num_head,
            dropout=config.dropout
        )
        self.context_window_size = config.context_window_size
        # NOTE(review): neither context_window_size nor context_layer_weights is
        # read anywhere in this class. They are presumably kept so existing
        # configs/checkpoints load without missing-key issues; confirm before
        # removing.
        self.context_layer_weights = nn.Parameter(
            torch.zeros(self.context_encoder.config.num_hidden_layers)
        )
        self.post_init()

    def _encode_context_attentive(self, text, target_span, mask_padding=False):
        """Attend from the mean-pooled target span over the full context.

        Args:
            text: Mapping of keyword arguments for ``self.context_encoder``
                (presumably tokenizer output; TODO confirm against callers).
            target_span: Integer tensor indexed as ``[:, 0]`` / ``[:, 1]``. These
                columns hold each example's inclusive start and end token positions.
            mask_padding: If True and ``text`` has an ``attention_mask``,
                positions where it equals 0 are excluded from the attention keys.
                The default False reproduces the original behaviour, in which
                padding tokens are attended to as well.

        Returns:
            Tensor of shape (1, batch, hidden_size): one attended vector per
            example, in sequence-first layout.
        """
        outputs = self.context_encoder(**text)
        hidden_states = outputs[0]  # (batch, seq_len, hidden_size)
        start_pos, end_pos = target_span[:, 0], target_span[:, 1]

        # Boolean (batch, seq_len) mask selecting the tokens in [start, end].
        positions = torch.arange(hidden_states.size(1), device=hidden_states.device)
        mask = (positions >= start_pos.unsqueeze(1)) & (positions <= end_pos.unsqueeze(1))
        masked_states = hidden_states * mask.unsqueeze(-1)
        # The clamp prevents division by zero: an empty span pools to a zero vector.
        span_lengths = mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled_embeddings = masked_states.sum(dim=1) / span_lengths

        # Sequence-first layout: query (1, batch, H), keys/values (seq_len, batch, H).
        Q_value = pooled_embeddings.unsqueeze(0)
        KV_value = hidden_states.permute(1, 0, 2)

        key_padding_mask = None
        attention_mask = text.get("attention_mask") if mask_padding else None
        if attention_mask is not None:
            # nn.MultiheadAttention ignores keys where key_padding_mask is True.
            # Hugging Face attention masks use 0 for padding.
            key_padding_mask = attention_mask.to(hidden_states.device) == 0
        context_emb, _ = self.context_attention(
            Q_value, KV_value, KV_value, key_padding_mask=key_padding_mask
        )
        return context_emb

    def forward(self, context, target_span, mask_padding=False):
        """Compute the contextual embedding of each example's target span.

        Args:
            context: Mapping of keyword arguments for the backbone encoder.
            target_span: Integer tensor whose columns 0 and 1 are the inclusive
                start and end token positions of the target span.
            mask_padding: When True, tokens with ``attention_mask == 0`` are not
                attended to. Defaults to False for compatibility with existing
                checkpoints.

        Returns:
            Tensor of shape (batch, config.out_dim).
        """
        context_emb = self._encode_context_attentive(
            context, target_span, mask_padding=mask_padding
        )
        return self.context_projection(context_emb.squeeze(0))