gradient checkpoint + cleanup
modeling_lsg_bart.py  CHANGED  +16 -453
@@ -81,51 +81,6 @@ class LSGBartConfig(BartConfig):
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
 
-def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
-    """
-    Shift input ids one token to the right.
-    """
-    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
-    shifted_input_ids[:, 0] = decoder_start_token_id
-
-    if pad_token_id is None:
-        raise ValueError("self.model.config.pad_token_id has to be defined.")
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-
-    return shifted_input_ids
-
-
-def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), float("-inf"))
-    mask_cond = torch.arange(mask.size(-1))
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-def _expand_mask(mask, dtype, tgt_len=None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
-
-
 class BaseSelfAttention(nn.Module):
 
     def __init__(
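The three deleted helpers are duplicates of functions that `transformers.models.bart.modeling_bart` already provides, so the file presumably relies on the library versions from here on. As a reminder of what `shift_tokens_right` computes, here is a minimal standalone sketch of the same steps (the token ids are illustrative, not taken from this repository):

import torch

labels = torch.tensor([[42, 7, -100]])        # -100 marks positions ignored by the loss
decoder_start_token_id, pad_token_id = 2, 1   # illustrative ids

shifted = labels.new_zeros(labels.shape)
shifted[:, 1:] = labels[:, :-1].clone()       # shift everything one step to the right
shifted[:, 0] = decoder_start_token_id        # prepend the decoder start token
shifted.masked_fill_(shifted == -100, pad_token_id)
# shifted -> tensor([[2, 42, 7]])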
@@ -663,364 +618,27 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         return x.reshape(n, h, -1, chunk_size, d)
 
 
-class LSGBartDecoderAttention(nn.Module):
-
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(
-        self,
-        embed_dim,
-        num_heads,
-        dropout=0.0,
-        is_decoder=False,
-        bias=True,
-    ):
-
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
-
-        if (self.head_dim * num_heads) != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
-                f" and `num_heads`: {num_heads})."
-            )
-        self.scaling = self.head_dim ** -0.5
-        self.is_decoder = is_decoder
-
-        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
-
-    def _shape(self, tensor, seq_len, bsz):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def forward(
-        self,
-        hidden_states,
-        key_value_states=None,
-        past_key_value=None,
-        attention_mask=None,
-        layer_head_mask=None,
-        output_attentions=False,
-    ):
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, tgt_len, _ = hidden_states.size()
-
-        # get query proj
-        query_states = self.q_proj(hidden_states) * self.scaling
-        # get key, value proj
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
-
-        src_len = key_states.size(1)
-        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
-        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        if layer_head_mask is not None:
-            if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(
-                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
-                )
-            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
-        if output_attentions:
-            # this operation is a bit awkward, but it's required to
-            # make sure that attn_weights keeps its gradient.
-            # In order to do so, attn_weights have to be reshaped
-            # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
-        else:
-            attn_weights_reshaped = None
-
-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
-        attn_output = torch.bmm(attn_probs, value_states)
-
-        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
-            )
-
-        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
-        attn_output = attn_output.transpose(1, 2)
-
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned aross GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, attn_weights_reshaped, past_key_value
-
-
-class LSGBartLearnedPositionalEmbedding(nn.Embedding):
-    """
-    This module learns positional embeddings up to a fixed maximum size.
-    """
-
-    def __init__(self, num_embeddings, embedding_dim):
-        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
-        # and adjust num_embeddings appropriately. Other models don't have this hack
-        self.offset = 2
-        super().__init__(num_embeddings + self.offset, embedding_dim)
-
-    def forward(self, input_ids_shape, past_key_values_length=0):
-
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
-        positions = torch.arange(
-            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
-        return super().forward(positions + self.offset)
-
-
-class LSGBartEncoderLayer(nn.Module):
+class LSGBartEncoderLayer(BartEncoderLayer):
 
     def __init__(self, config):
 
-        super().__init__()
-        self.embed_dim = config.d_model
+        super().__init__(config)
         self.self_attn = LSGBartEncoderAttention(
             config=config,
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
         )
-        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout
-        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
-        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
-
-    def forward(
-        self,
-        hidden_states,
-        attention_mask,
-        layer_head_mask,
-        output_attentions=False,
-    ):
-        """
-        Args:
-            hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
-            attention_mask (:obj:`torch.FloatTensor`): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(encoder_attention_heads,)`.
-            output_attentions (:obj:`bool`, `optional`):
-                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
-                returned tensors for more detail.
-        """
-        residual = hidden_states
-        hidden_states, attn_weights, _ = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            layer_head_mask=layer_head_mask,
-            output_attentions=output_attentions,
-        )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-        hidden_states = residual + hidden_states
-        hidden_states = self.self_attn_layer_norm(hidden_states)
-
-        residual = hidden_states
-        hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-        hidden_states = residual + hidden_states
-        hidden_states = self.final_layer_norm(hidden_states)
-
-        if hidden_states.dtype == torch.float16 and (
-            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
-        ):
-            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (attn_weights,)
-
-        return outputs
 
 
-class LSGBartDecoderLayer(nn.Module):
+class LSGBartDecoderLayer(BartDecoderLayer):
 
     def __init__(self, config):
 
-        super().__init__()
-
-
-        self.self_attn = LSGBartDecoderAttention(
-            embed_dim=self.embed_dim,
-            num_heads=config.decoder_attention_heads,
-            dropout=config.attention_dropout,
-            is_decoder=True,
-        )
-        self.dropout = config.dropout
-        self.activation_fn = ACT2FN[config.activation_function]
-        self.activation_dropout = config.activation_dropout
-
-        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.encoder_attn = LSGBartDecoderAttention(
-            self.embed_dim,
-            config.decoder_attention_heads,
-            dropout=config.attention_dropout,
-            is_decoder=True,
-        )
-        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
-        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
-
-    def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        layer_head_mask=None,
-        cross_attn_layer_head_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-        use_cache=True,
-    ):
-        """
-        Args:
-            hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (:obj:`torch.FloatTensor`): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
-            encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(encoder_attention_heads,)`.
-            cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
-                size `(decoder_attention_heads,)`.
-            past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
-            output_attentions (:obj:`bool`, `optional`):
-                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
-                returned tensors for more detail.
-        """
-        residual = hidden_states
-
-        # Self Attention
-        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
-        # add present self-attn cache to positions 1,2 of present_key_value tuple
-
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            past_key_value=self_attn_past_key_value,
-            attention_mask=attention_mask,
-            layer_head_mask=layer_head_mask,
-            output_attentions=output_attentions,
-        )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-        hidden_states = residual + hidden_states
-        hidden_states = self.self_attn_layer_norm(hidden_states)
-
-        # Cross-Attention Block
-        cross_attn_present_key_value = None
-        cross_attn_weights = None
-        if encoder_hidden_states is not None:
-            residual = hidden_states
-
-            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
-                hidden_states=hidden_states,
-                key_value_states=encoder_hidden_states,
-                attention_mask=encoder_attention_mask,
-                layer_head_mask=cross_attn_layer_head_mask,
-                past_key_value=cross_attn_past_key_value,
-                output_attentions=output_attentions,
-            )
-            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-            hidden_states = residual + hidden_states
-            hidden_states = self.encoder_attn_layer_norm(hidden_states)
-
-            # add cross-attn to positions 3,4 of present_key_value tuple
-            present_key_value = present_key_value + cross_attn_present_key_value
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
-        hidden_states = self.fc2(hidden_states)
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-        hidden_states = residual + hidden_states
-        hidden_states = self.final_layer_norm(hidden_states)
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights, cross_attn_weights)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
-
+        super().__init__(config)
+
 
-class LSGBartClassificationHead(nn.Module):
+class LSGBartClassificationHead(BartClassificationHead):
     """Head for sentence-level classification tasks."""
 
     def __init__(
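The deleted decoder-side attention class was a straight copy of BartAttention; the LSG-specific behaviour lives in the encoder's LSGBartEncoderAttention, so the decoder layers can simply inherit the stock implementation. As a reference, here is a tensor-level sketch of the scaled dot-product attention the deleted forward() performed (shapes are illustrative):

import torch
from torch import nn

bsz_times_heads, tgt_len, head_dim = 2, 5, 8
q = torch.randn(bsz_times_heads, tgt_len, head_dim)
k = torch.randn(bsz_times_heads, tgt_len, head_dim)
v = torch.randn(bsz_times_heads, tgt_len, head_dim)

scaling = head_dim ** -0.5
attn_weights = torch.bmm(q * scaling, k.transpose(1, 2))   # (2, 5, 5) scores
attn_probs = nn.functional.softmax(attn_weights, dim=-1)   # normalise over source positions
attn_output = torch.bmm(attn_probs, v)                     # (2, 5, 8) weighted values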
@@ -1031,55 +649,18 @@ class LSGBartClassificationHead(nn.Module):
         pooler_dropout,
     ):
 
-        super().__init__()
-        self.dense = nn.Linear(input_dim, inner_dim)
-        self.dropout = nn.Dropout(p=pooler_dropout)
-        self.out_proj = nn.Linear(inner_dim, num_classes)
-
-    def forward(self, hidden_states):
-
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.dense(hidden_states)
-        hidden_states = torch.tanh(hidden_states)
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.out_proj(hidden_states)
-        return hidden_states
+        super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
 
 
-class LSGBartPretrainedModel(PreTrainedModel):
+class LSGBartPretrainedModel(BartPretrainedModel):
 
     config_class = LSGBartConfig
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
-
-    def _init_weights(self, module):
-
-        std = self.config.init_std
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
 
     def _set_gradient_checkpointing(self, module, value=False):
 
-        if isinstance(module, (LSGBartDecoder, LSGBartEncoder)):
+        if isinstance(module, (BartDecoder, BartEncoder, LSGBartDecoder, LSGBartEncoder)):
             module.gradient_checkpointing = value
 
-    @property
-    def dummy_inputs(self):
-        pad_token = self.config.pad_token_id
-        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
-        dummy_inputs = {
-            "attention_mask": input_ids.ne(pad_token),
-            "input_ids": input_ids,
-        }
-        return dummy_inputs
-
 
 class PretrainedLSGBartModel(LSGBartPretrainedModel):
 
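With the isinstance check now matching both the LSG and the stock BART encoder/decoder classes, gradient checkpointing can be toggled through the usual PreTrainedModel API. A hedged usage sketch (the checkpoint name is illustrative and not confirmed by this diff):

from transformers import AutoModelForSeq2SeqLM

# trust_remote_code is needed because modeling_lsg_bart.py ships with the checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-4096", trust_remote_code=True)

model.gradient_checkpointing_enable()    # ends up applying _set_gradient_checkpointing(..., value=True)
model.gradient_checkpointing_disable()   # ...and value=False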
@@ -1090,7 +671,7 @@ class PretrainedLSGBartModel(LSGBartPretrainedModel):
     )
 
 
-class LSGBartEncoder(LSGBartPretrainedModel):
+class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
     """
     Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
     :class:`BartEncoderLayer`.
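Because Python resolves attributes left to right along the MRO, listing LSGBartPretrainedModel before BartEncoder means the LSG overrides win while everything else falls back to the stock BART implementation. A toy sketch of that lookup order (class names here are illustrative, not from the file):

class StockEncoder:              # stands in for BartEncoder
    def describe(self):
        return "full attention"

class LSGOverrides:              # stands in for the LSG-specific base
    def describe(self):
        return "local + sparse + global attention"

class Encoder(LSGOverrides, StockEncoder):
    pass

print(Encoder.__mro__)           # Encoder -> LSGOverrides -> StockEncoder -> object
print(Encoder().describe())      # "local + sparse + global attention"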
@@ -1115,7 +696,7 @@ class LSGBartEncoder(LSGBartPretrainedModel):
         else:
             self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
 
-        self.embed_positions = LSGBartLearnedPositionalEmbedding(
+        self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
             embed_dim,
         )
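The deleted LSGBartLearnedPositionalEmbedding and the BartLearnedPositionalEmbedding that replaces it both apply BART's offset of 2 before the table lookup, so the swap is behaviour-preserving. A minimal re-derivation of the deleted forward() (sizes are illustrative):

import torch
from torch import nn

offset = 2
max_position_embeddings, embed_dim = 1024, 16
table = nn.Embedding(max_position_embeddings + offset, embed_dim)   # table is extended by the offset

past_key_values_length, seq_len = 0, 10
positions = torch.arange(past_key_values_length, past_key_values_length + seq_len, dtype=torch.long)
position_embeddings = table(positions + offset)                     # shape (10, 16)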
@@ -1140,12 +721,6 @@ class LSGBartEncoder(LSGBartPretrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.embed_tokens = value
-
     def forward(self,
         input_ids=None,
         attention_mask=None,
@@ -1335,7 +910,7 @@ class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
         else:
             self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
 
-        self.embed_positions = LSGBartLearnedPositionalEmbedding(
+        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )
@@ -1348,36 +923,24 @@ class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
         self.post_init()
 
 
-class LSGBartModel(LSGBartPretrainedModel):
+class LSGBartModel(LSGBartPretrainedModel, BartModel):
 
     def __init__(self, config):
 
-
+        LSGBartPretrainedModel.__init__(self, config)
 
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
         self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
         self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
         self.num_global_tokens = config.num_global_tokens
+
         self.encoder = LSGBartEncoder(config, self.shared)
         self.decoder = LSGBartDecoder(config, self.shared)
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self):
-        return self.shared
-
-    def set_input_embeddings(self, value):
-        self.shared = value
-        self.encoder.embed_tokens = self.shared
-        self.decoder.embed_tokens = self.shared
-
-    def get_encoder(self):
-        return self.encoder
-
-    def get_decoder(self):
-        return self.decoder
-
     def forward(
         self,
         input_ids=None,
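After this change LSGBartModel relies on BartModel for the embedding and encoder/decoder accessors deleted above; BartModel's set_input_embeddings re-points both sub-modules exactly as the removed code did. A hedged usage sketch of the inherited behaviour (checkpoint name is illustrative):

from transformers import AutoModel

model = AutoModel.from_pretrained("ccdv/lsg-bart-base-4096", trust_remote_code=True)

shared = model.get_input_embeddings()      # inherited from BartModel, returns model.shared
model.set_input_embeddings(shared)         # also updates encoder.embed_tokens and decoder.embed_tokens
encoder, decoder = model.get_encoder(), model.get_decoder()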