Commit 77af1c7 · 1 parent: 1c61b96

add stochastic_depth

Files changed:
- block.py +26 -14
- modeling_xlm_roberta.py +121 -61
- stochastic_depth.py +97 -0
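The commit replaces the `from torchvision.ops import StochasticDepth` dependency in block.py with a vendored copy (stochastic_depth.py, added below), so the blocks can build their `drop_path1` / `drop_path2` modules without torchvision installed. Stochastic depth randomly skips a residual branch per sample during training. The following toy block is only a minimal sketch of that usage pattern, not code from the repository (MiniBlock and its sizes are invented, and it assumes stochastic_depth.py is importable from the working directory):

import torch
import torch.nn as nn

from stochastic_depth import StochasticDepth  # vendored module added by this commit


class MiniBlock(nn.Module):
    """Toy pre-norm residual block with a drop-path on its MLP branch."""

    def __init__(self, dim: int, drop_path_rate: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        # mode="row": each sample in the batch keeps or drops the branch independently.
        self.drop_path = StochasticDepth(drop_path_rate, mode="row")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.drop_path(self.mlp(self.norm(x)))


block = MiniBlock(dim=16)
out = block(torch.randn(2, 5, 16))  # (batch, seqlen, dim)
print(out.shape)  # torch.Size([2, 5, 16])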
block.py CHANGED

@@ -10,8 +10,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torchvision.ops import StochasticDepth
 
+from .stochastic_depth import StochasticDepth
 from .mha import MHA
 from .mlp import Mlp
 
@@ -106,7 +106,9 @@ class Block(nn.Module):
                 p._shared_params = True
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
 
     def forward(
         self,
@@ -152,7 +154,7 @@ class Block(nn.Module):
                     rowscale=rowscale1,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
-                    is_rms_norm=isinstance(self.norm1, RMSNorm)
+                    is_rms_norm=isinstance(self.norm1, RMSNorm),
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
@@ -165,7 +167,9 @@ class Block(nn.Module):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
                     residual = (dropped + residual) if residual is not None else dropped
-                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                    hidden_states = self.norm2(
+                        residual.to(dtype=self.norm2.weight.dtype)
+                    )
                     if self.residual_in_fp32:
                         residual = residual.to(torch.float32)
                 else:
@@ -189,7 +193,7 @@ class Block(nn.Module):
                         rowscale=rowscale2,
                         prenorm=True,
                         residual_in_fp32=self.residual_in_fp32,
-                        is_rms_norm=isinstance(self.norm2, RMSNorm)
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                     )
                 hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
@@ -212,7 +216,9 @@ class Block(nn.Module):
                 else:
                     rowscale1 = self.drop_path1(
                         torch.ones(
-                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
+                            mixer_out.shape[:-1],
+                            device=mixer_out.device,
+                            dtype=mixer_out.dtype,
                         )
                     )
                 hidden_states = layer_norm_fn(
@@ -224,7 +230,7 @@ class Block(nn.Module):
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
-                   is_rms_norm=isinstance(self.norm1, RMSNorm)
+                   is_rms_norm=isinstance(self.norm1, RMSNorm),
                )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
@@ -242,7 +248,9 @@ class Block(nn.Module):
                     else:
                         rowscale2 = self.drop_path2(
                             torch.ones(
-                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
+                                mlp_out.shape[:-1],
+                                device=mlp_out.device,
+                                dtype=mlp_out.dtype,
                             )
                         )
                     hidden_states = layer_norm_fn(
@@ -254,7 +262,7 @@ class Block(nn.Module):
                         dropout_p=self.dropout2.p if self.training else 0.0,
                         rowscale=rowscale2,
                         prenorm=False,
-                        is_rms_norm=isinstance(self.norm2, RMSNorm)
+                        is_rms_norm=isinstance(self.norm2, RMSNorm),
                     )
             return hidden_states
 
@@ -333,7 +341,9 @@ class ParallelBlock(nn.Module):
                 p._shared_params = True
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+        return self.mixer.allocate_inference_cache(
+            batch_size, max_seqlen, dtype=dtype, **kwargs
+        )
 
     def forward(
         self,
@@ -373,7 +383,9 @@ class ParallelBlock(nn.Module):
                 residual = residual.to(torch.float32)
         else:
             weight2, bias2 = (
-                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
+                (self.norm2.weight, self.norm2.bias)
+                if not self.tied_norm
+                else (None, None)
             )
             hidden_states1, *rest, residual = layer_norm_fn(
                 hidden_states1,
@@ -387,14 +399,14 @@ class ParallelBlock(nn.Module):
                 dropout_p=self.dropout1.p if self.training else 0.0,
                 prenorm=True,
                 residual_in_fp32=self.residual_in_fp32,
-                is_rms_norm=isinstance(self.norm1, RMSNorm)
+                is_rms_norm=isinstance(self.norm1, RMSNorm),
             )
             if self.tied_norm:
                 hidden_states2 = hidden_states1
             else:
-                hidden_states2, = rest
+                (hidden_states2,) = rest
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
         hidden_states2 = self.mlp(hidden_states2)
-        return hidden_states1, hidden_states2, residual
+        return hidden_states1, hidden_states2, residual
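In the fused (`fused_dropout_add_ln`) branches above, stochastic depth is not applied to the branch output directly; instead `drop_path1` / `drop_path2` are applied to a tensor of ones and the result is handed to `layer_norm_fn` as `rowscale`. A small sketch of why that works (illustrative only; it assumes stochastic_depth.py from this commit is importable from the working directory): on a ones tensor, `StochasticDepth(mode="row")` yields a per-row factor of either 0 or 1/(1-p), which the fused kernel can fold into the branch output by broadcasting.

import torch

from stochastic_depth import StochasticDepth  # vendored module added by this commit

torch.manual_seed(0)
drop_path = StochasticDepth(p=0.2, mode="row")
drop_path.train()

mixer_out = torch.randn(4, 8, 16)  # (batch, seqlen, hidden)

# Per-row scale with shape (batch, seqlen): 0 for dropped rows, 1 / (1 - p) = 1.25 for kept rows.
rowscale = drop_path(
    torch.ones(mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
)
print(rowscale[:, 0])  # one factor per batch row, e.g. tensor([1.2500, 0.0000, 1.2500, 1.2500])

# Broadcasting that factor over the hidden dimension reproduces what drop_path would
# have done to mixer_out with the same random mask.
scaled = mixer_out * rowscale.unsqueeze(-1)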
modeling_xlm_roberta.py CHANGED

@@ -42,6 +42,7 @@ from .block import Block
 from .embedding import XLMRobertaEmbeddings
 from .mha import MHA
 from .mlp import FusedMLP, Mlp
+from .stochastic_depth import StochasticDepth
 
 
 try:
@@ -69,10 +70,16 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     rotary_kwargs = {}
     if config.position_embedding_type == "rotary":
-        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_dim"] = getattr(
+            config, "rotary_emb_dim", config.hidden_size
+        )
         rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
-        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
-        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(
+            config, "rotary_emb_scale_base", None
+        )
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(
+            config, "rotary_emb_interleaved", False
+        )
     mixer_cls = partial(
         MHA,
         num_heads=config.num_attention_heads,
@@ -183,7 +190,9 @@ class XLMRobertaEncoder(nn.Module):
         """
         if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = (
-                {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
+                {"key_padding_mask": key_padding_mask.bool()}
+                if key_padding_mask is not None
+                else None
             )
             for layer in self.layers:
                 if self._grad_checkpointing:
@@ -191,7 +200,7 @@ class XLMRobertaEncoder(nn.Module):
                         layer,
                         hidden_states,
                         use_reentrant=False,
-                        mixer_kwargs=mixer_kwargs
+                        mixer_kwargs=mixer_kwargs,
                     )
                 else:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -210,7 +219,7 @@ class XLMRobertaEncoder(nn.Module):
                         layer,
                         hidden_states,
                         use_reentrant=False,
-                        mixer_kwargs=mixer_kwargs
+                        mixer_kwargs=mixer_kwargs,
                     )
                 else:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -222,7 +231,7 @@ class XLMRobertaEncoder(nn.Module):
                        layer,
                        hidden_states,
                        use_reentrant=False,
-                       mixer_kwargs=mixer_kwargs
+                       mixer_kwargs=mixer_kwargs,
                    )
                 else:
                     hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
@@ -230,15 +239,19 @@ class XLMRobertaEncoder(nn.Module):
                 subset_idx = torch.nonzero(
                     subset_mask[key_padding_mask], as_tuple=False
                 ).flatten()
-                subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
+                subset_seqlens = (subset_mask & key_padding_mask).sum(
+                    dim=-1, dtype=torch.int32
+                )
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
+                    (1, 0),
                 )
             else:
                 subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                 subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                 subset_cu_seqlens = F.pad(
-                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32), (1, 0)
+                    torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
+                    (1, 0),
                 )
             hidden_states_subset, hidden_states = index_first_axis_residual(
                 hidden_states, subset_idx
@@ -256,10 +269,12 @@ class XLMRobertaEncoder(nn.Module):
                     self.layers[-1],
                     hidden_states_subset,
                     use_reentrant=False,
-                    mixer_kwargs=mixer_kwargs
+                    mixer_kwargs=mixer_kwargs,
                 )
             else:
-                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
+                hidden_states = self.layers[-1](
+                    hidden_states_subset, mixer_kwargs=mixer_kwargs
+                )
         return hidden_states
 
 
@@ -308,7 +323,10 @@ class XLMRobertaPredictionHeadTransform(nn.Module):
             hidden_states = self.layer_norm(hidden_states)
         else:
             hidden_states = layer_norm_fn(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
+                hidden_states,
+                self.layer_norm.weight,
+                self.layer_norm.bias,
+                eps=self.layer_norm.eps,
             )
         return hidden_states
 
@@ -349,6 +367,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for dowloading and loading pretrained models.
     """
+
     config_class = XLMRobertaFlashConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
@@ -358,7 +377,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
             module.gradient_checkpointing = value
 
 
-
 class XLMRobertaModel(XLMRobertaPreTrainedModel):
     def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
         super().__init__(config)
@@ -370,7 +388,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln and layer_norm_fn is None:
             raise ImportError("Triton is not installed")
-        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+        assert config.hidden_act in [
+            "gelu",
+            "gelu_new",
+            "gelu_fast",
+            "gelu_pytorch_tanh",
+        ]
 
         self.embeddings = XLMRobertaEmbeddings(
             config.hidden_size,
@@ -386,7 +409,6 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
 
-
     def forward(
         self,
         input_ids,
@@ -406,9 +428,14 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         if kwargs:
             for key, value in kwargs.items():
                 if value is not None:
-                    logger.warning('Flash attention implementation does not support kwargs: %s', key)
+                    logger.warning(
+                        'Flash attention implementation does not support kwargs: %s',
+                        key,
+                    )
 
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
@@ -439,17 +466,23 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         )
 
         if masked_tokens_mask is None:
-            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+            pooled_output = (
+                self.pooler(sequence_output) if self.pooler is not None else None
+            )
         else:
             # TD [2022-03-01]: the indexing here is very tricky.
             if attention_mask is not None:
                 subset_idx = subset_mask[attention_mask]
                 pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
-                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
+                sequence_output = sequence_output[
+                    masked_tokens_mask[attention_mask][subset_idx]
+                ]
             else:
                 pool_input = sequence_output[first_col_mask[subset_mask]]
                 sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
-            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None
+            pooled_output = (
+                self.pooler(pool_input, pool=False) if self.pooler is not None else None
+            )
 
         if not return_dict:
             return sequence_output, pooled_output
@@ -487,7 +520,6 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
 
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -511,7 +543,9 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
         kwargs (`Dict[str, any]`, optional, defaults to *{}*):
             Used to hide legacy arguments that have been deprecated.
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         outputs = self.roberta(
             input_ids,
@@ -534,11 +568,15 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
             # move labels to correct device to enable model parallelism
             labels = labels.to(prediction_scores.device)
             loss_fct = CrossEntropyLoss()
-            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+            )
 
         if not return_dict:
             output = (prediction_scores,) + outputs[2:]
-            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            return (
+                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            )
 
         return MaskedLMOutput(
             loss=masked_lm_loss,
@@ -656,7 +694,9 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
         key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
         return key
 
-    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
+    state_dict = OrderedDict(
+        (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
+    )
 
     # Layers
     def key_mapping_layers(key):
@@ -715,12 +755,18 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
             state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                 [Wq, Wk, Wv], dim=0
            )
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat(
+                [bq, bk, bv], dim=0
+            )
         else:
             state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat(
+                [Wk, Wv], dim=0
+            )
             state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
+            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat(
+                [bk, bv], dim=0
+            )
 
     def key_mapping_attn(key):
         return re.sub(
@@ -734,7 +780,9 @@ def remap_state_dict(state_dict, config: PretrainedConfig):
     def key_mapping_decoder_bias(key):
         return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
 
-    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
+    state_dict = OrderedDict(
+        (key_mapping_decoder_bias(k), v) for k, v in state_dict.items()
+    )
 
     # Word embedding
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
@@ -774,51 +822,59 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
         : config.orig_vocab_size, :
     ]
-    state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
-    state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
+    state_dict["cls.predictions.decoder.weight"] = decoder_weight[
+        : config.orig_vocab_size, :
+    ]
+    state_dict["cls.predictions.decoder.bias"] = decoder_bias[
+        : config.orig_vocab_size
+    ]
 
     for d in range(config.num_hidden_layers):
         last_layer_subset = getattr(config, "last_layer_subset", False)
         if not last_layer_subset or d != (config.num_hidden_layers - 1):
             Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
             Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
-                : Wqkv_weights.shape[0] // 3, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.weight"
+            ] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.weight"
+            ] = Wqkv_weights[
                 Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
             ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
-                2 * Wqkv_weights.shape[0] // 3 :, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
-                : Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
-                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
-                2 * Wqkv_biases.shape[0] // 3 :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.weight"
+            ] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.bias"
+            ] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.bias"
+            ] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.bias"
+            ] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
         else:
             Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
             Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
             Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
             Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
-                : Wkv_weights.shape[0] // 2, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
-                Wkv_weights.shape[0] // 2 :, :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.query.weight"
+            ] = Wq_weight
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.key.weight"
+            ] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.weight"
+            ] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
             state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
             state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                 : Wkv_biases.shape[0] // 2
            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
-                Wkv_biases.shape[0] // 2 :
-            ]
+            state_dict[
+                f"bert.encoder.layers.{d}.attention.self.value.bias"
+            ] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
 
     def inv_key_mapping_ln(key):
         key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
@@ -870,14 +926,18 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
     def inv_key_mapping_decoder_bias(key):
         return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
 
-    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
+    state_dict = OrderedDict(
+        (inv_key_mapping_ln(key), value) for key, value in state_dict.items()
+    )
     state_dict = OrderedDict(
         (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
     )
     state_dict = OrderedDict(
         (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
     )
-    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
+    state_dict = OrderedDict(
+        (inv_key_mapping_mlp(key), value) for key, value in state_dict.items()
+    )
     state_dict = OrderedDict(
         (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
     )
@@ -885,4 +945,4 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
         (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
     )
 
-    return state_dict
+    return state_dict
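The checkpoint-conversion helpers touched above (`remap_state_dict` / `inv_remap_state_dict`) repeatedly rebuild the state dict by passing every key through a small regex mapping, which is what most of the reformatted hunks wrap onto multiple lines. A self-contained sketch of that pattern with a toy two-entry checkpoint (the keys and tensors are made up; the gamma/beta renaming mirrors the `key_mapping_ln_gamma_beta` mapping shown in the diff):

import re
from collections import OrderedDict

import torch

# Toy checkpoint with old-style LayerNorm parameter names.
state_dict = OrderedDict(
    [
        ("bert.embeddings.LayerNorm.gamma", torch.ones(4)),
        ("bert.embeddings.LayerNorm.beta", torch.zeros(4)),
    ]
)


def key_mapping_ln_gamma_beta(key):
    # Same renaming as in the diff: gamma/beta -> weight/bias.
    key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
    key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
    return key


# One remapping pass: rewrite every key, keep every value, preserve order.
state_dict = OrderedDict(
    (key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()
)
print(list(state_dict))
# ['bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias']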
stochastic_depth.py ADDED

@@ -0,0 +1,97 @@
+# Implementation modified from torchvision:
+# https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
+#
+# License:
+# BSD 3-Clause License
+#
+# Copyright (c) Soumith Chintala 2016,
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+import torch.fx
+from torch import nn, Tensor
+
+
+def stochastic_depth(
+    input: Tensor, p: float, mode: str, training: bool = True
+) -> Tensor:
+    """
+    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
+    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
+    branches of residual architectures.
+
+    Args:
+        input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
+            being its batch i.e. a batch with ``N`` rows.
+        p (float): probability of the input to be zeroed.
+        mode (str): ``"batch"`` or ``"row"``.
+            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
+            randomly selected rows from the batch.
+        training: apply stochastic depth if is ``True``. Default: ``True``
+
+    Returns:
+        Tensor[N, ...]: The randomly zeroed tensor.
+    """
+    if p < 0.0 or p > 1.0:
+        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
+    if mode not in ["batch", "row"]:
+        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
+    if not training or p == 0.0:
+        return input
+
+    survival_rate = 1.0 - p
+    if mode == "row":
+        size = [input.shape[0]] + [1] * (input.ndim - 1)
+    else:
+        size = [1] * input.ndim
+    noise = torch.empty(size, dtype=input.dtype, device=input.device)
+    noise = noise.bernoulli_(survival_rate)
+    if survival_rate > 0.0:
+        noise.div_(survival_rate)
+    return input * noise
+
+
+torch.fx.wrap("stochastic_depth")
+
+
+class StochasticDepth(nn.Module):
+    """
+    See :func:`stochastic_depth`.
+    """
+
+    def __init__(self, p: float, mode: str) -> None:
+        super().__init__()
+        self.p = p
+        self.mode = mode
+
+    def forward(self, input: Tensor) -> Tensor:
+        return stochastic_depth(input, self.p, self.mode, self.training)
+
+    def __repr__(self) -> str:
+        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
+        return s
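A short usage sketch of the module defined above (assuming the file is importable as `stochastic_depth` from the working directory): in training mode each row of the input is either zeroed or rescaled by 1/(1-p), so the expected value is preserved, and in eval mode the module is an identity.

import torch

from stochastic_depth import StochasticDepth

sd = StochasticDepth(p=0.5, mode="row")
x = torch.ones(6, 3)

sd.train()
print(sd(x))  # each of the 6 rows is either all 0.0 or all 2.0 (= 1 / (1 - p))

sd.eval()
print(torch.equal(sd(x), x))  # True: no-op at inference time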