lpepino committed on
Commit d1983e2 · 1 Parent(s): b31535a

Create config.gin

Files changed (1)
  1. config.gin +57 -0
config.gin ADDED
@@ -0,0 +1,57 @@
+NUM_ENCODEC_TARGETS=8
+NUM_TOTAL_TARGETS=8
+NUM_TARGET_TOKENS=1024
+MASK_AMOUNT=150
+MASK_GAP_SIZE=15
+MASK_PROP=0.5
+MODEL_DIM=768
+NUM_ENCODER_LAYERS=10
+NUM_ENCODER_HEADS=12
+NUM_DECODER_LAYERS=2
+NUM_DECODER_HEADS=12
+MASKED_LOSS_WEIGHT=0.9
+
+models.EncodecMAE:
+    wav_encoder = @models.encodecmae.encoders.EncodecEncoder
+    target_encoder = @models.encodecmae.targets.EncodecQuantizer
+    masker = @models.encodecmae.masking.TimeGapMask
+    visible_encoder = @encoder/models.transformers.TransformerEncoder
+    positional_encoder = @models.transformers.SinusoidalPositionalEmbeddings
+    decoder = @decoder/models.transformers.TransformerEncoder
+    head = @models.encodecmae.heads.FrameLevelClassificationHead
+
+    lr_scheduler=None
+    masked_weight=%MASKED_LOSS_WEIGHT
+    quantizer_weights=[0.22407463, 0.1759858, 0.14499009, 0.12150037, 0.10315603, 0.08831368, 0.07608274, 0.06589669, 1.0]
+    n_extra_targets=1
+torch.optim.AdamW:
+    lr=%PRETRAIN_MAX_LR
+    betas=(0.9,0.95)
+    weight_decay=0.05
+models.encodecmae.targets.EncodecQuantizer:
+    n = %NUM_ENCODEC_TARGETS
+models.encodecmae.masking.TimeGapMask:
+    mask_amount = %MASK_AMOUNT
+    gap_size = %MASK_GAP_SIZE
+    mask_prop = %MASK_PROP
+encoder/models.transformers.TransformerEncoder:
+    model_dim=%MODEL_DIM
+    num_layers=%NUM_ENCODER_LAYERS
+    attention_layer=@encoder/models.transformers.MultiHeadAttention
+    compile=True
+encoder/models.transformers.MultiHeadAttention:
+    model_dim=%MODEL_DIM
+    num_heads=%NUM_ENCODER_HEADS
+decoder/models.transformers.TransformerEncoder:
+    model_dim=%MODEL_DIM
+    num_layers=%NUM_DECODER_LAYERS
+    attention_layer=@decoder/models.transformers.MultiHeadAttention
+    compile=True
+decoder/models.transformers.MultiHeadAttention:
+    model_dim=%MODEL_DIM
+    num_heads=%NUM_DECODER_HEADS
+models.transformers.SinusoidalPositionalEmbeddings.embedding_dim = %MODEL_DIM
+models.encodecmae.heads.FrameLevelClassificationHead:
+    model_dim=%MODEL_DIM
+    num_tokens=%NUM_TARGET_TOKENS
+    num_streams=%NUM_TOTAL_TARGETS
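
Not part of the commit: a minimal Python sketch for pulling the top-level NAME=value constants (MODEL_DIM, NUM_ENCODER_LAYERS, etc.) out of this file for quick inspection without loading the training framework. The file path and helper name are assumptions for illustration, not repository code.

import re
from pathlib import Path

# Hypothetical helper, not part of the repository: collect top-level
# uppercase NAME=value constants from a gin-style config file.
def read_constants(path="config.gin"):
    constants = {}
    for line in Path(path).read_text().splitlines():
        # Only unindented ALL_CAPS assignments; indented block members,
        # dotted bindings, and blank lines are skipped.
        m = re.match(r"^([A-Z][A-Z0-9_]*)\s*=\s*(.+)$", line)
        if m:
            constants[m.group(1)] = m.group(2).strip()
    return constants

if __name__ == "__main__":
    for name, value in read_constants().items():
        print(f"{name} = {value}")
    # For the config above this prints, among others:
    # MODEL_DIM = 768, NUM_ENCODER_LAYERS = 10, NUM_DECODER_LAYERS = 2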