k050506koch committed
Commit 60976eb (verified) · 1 parent: 67f3f99

Uploaded custom arch


This file tells Transformers that the model uses a custom architecture, so the custom classes are picked up automatically at inference time.
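For reference, a minimal consumer-side sketch (not part of the commit): assuming the repository's config.json maps the "gpt3dev" model type to the classes in modeling_gpt3dev.py via an auto_map entry, the standard Auto classes can load the model by passing trust_remote_code=True. The repo id below is a placeholder.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "k050506koch/<model-repo>"  # placeholder -- substitute the actual repository id

# trust_remote_code=True lets Transformers download and execute the
# repository's modeling_gpt3dev.py instead of a built-in architecture.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)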

Files changed (1)
  1. modeling_gpt3dev.py (+220, -0)
modeling_gpt3dev.py ADDED
@@ -0,0 +1,220 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+ from transformers.models.gpt2.modeling_gpt2 import (
+     GPT2LMHeadModel,
+     GPT2Model,
+     GPT2Block,
+     GPT2Attention,
+     GPT2MLP,
+ )
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForCausalLM,
+ )
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ # Custom Configuration Class
+ class GPT3DevConfig(GPT2Config):
+     model_type = "gpt3dev"
+
+     def __init__(self, use_pre_layernorm=True, **kwargs):
+         super().__init__(**kwargs)
+         self.use_pre_layernorm = use_pre_layernorm
+
+ # Register the configuration with AutoConfig (this also populates CONFIG_MAPPING)
+ AutoConfig.register("gpt3dev", GPT3DevConfig)
+
+ # Custom Attention Module
+ class GPT3DevAttention(GPT2Attention):
+     def __init__(self, config, is_cross_attention=False):
+         super().__init__(config, is_cross_attention)
+
+         # Replace GPT-2's Conv1D projections with nn.Linear (biases included).
+         # Note: nn.Linear stores its weight transposed relative to Conv1D, so
+         # these layers are not weight-compatible with stock GPT-2 checkpoints.
+         self.c_attn = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True)
+         self.c_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
+
+ # Custom MLP Module
+ class GPT3DevMLP(GPT2MLP):
+     def __init__(self, intermediate_size, config):
+         super().__init__(intermediate_size, config)
+         self.c_fc = nn.Linear(config.hidden_size, intermediate_size, bias=True)
+         self.c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=True)
+         self.act = nn.GELU()  # use standard GELU
+
+ # Custom Transformer Block
+ class GPT3DevBlock(GPT2Block):
+     def __init__(self, config):
+         super().__init__(config)
+         self.use_pre_layernorm = config.use_pre_layernorm
+         self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+         self.attn = GPT3DevAttention(config)
+         self.mlp = GPT3DevMLP(4 * config.hidden_size, config)
+         self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+     def forward(
+         self,
+         hidden_states,
+         layer_past=None,
+         attention_mask=None,
+         head_mask=None,
+         encoder_hidden_states=None,
+         encoder_attention_mask=None,
+         use_cache=None,
+         output_attentions=False,
+     ):
+         if self.use_pre_layernorm:
+             # Pre-LayerNorm: normalize before attention and before the MLP
+             residual = hidden_states
+             hidden_states = self.ln_1(hidden_states)
+             attn_outputs = self.attn(
+                 hidden_states,
+                 layer_past=layer_past,
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+             )
+             attn_output = attn_outputs[0]
+             outputs = attn_outputs[1:]  # present, (attentions)
+
+             hidden_states = residual + attn_output
+
+             residual = hidden_states
+             hidden_states = self.ln_2(hidden_states)
+             feed_forward_hidden_states = self.mlp(hidden_states)
+             hidden_states = residual + feed_forward_hidden_states
+         else:
+             # Original GPT-2 Post-LayerNorm: normalize after each residual connection
+             residual = hidden_states
+             attn_outputs = self.attn(
+                 hidden_states,
+                 layer_past=layer_past,
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+             )
+             attn_output = attn_outputs[0]
+             outputs = attn_outputs[1:]  # present, (attentions)
+
+             hidden_states = residual + attn_output
+             hidden_states = self.ln_1(hidden_states)
+
+             residual = hidden_states
+             feed_forward_hidden_states = self.mlp(hidden_states)
+             hidden_states = residual + feed_forward_hidden_states
+             hidden_states = self.ln_2(hidden_states)
+
+         if use_cache:
+             outputs = (hidden_states,) + outputs
+         else:
+             outputs = (hidden_states,) + outputs[1:]
+
+         return outputs  # hidden_states, present, (attentions)
+
+ # Custom Transformer Model
+ class GPT3DevModel(GPT2Model):
+     config_class = GPT3DevConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.wpe = nn.Embedding(config.n_positions, config.hidden_size)
+         self.drop = nn.Dropout(config.embd_pdrop)
+         self.h = nn.ModuleList(
+             [GPT3DevBlock(config) for _ in range(config.num_hidden_layers)]
+         )
+         self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+         # Initialize weights
+         self.post_init()
+
+ # Custom LM Head Model
+ class GPT3DevLMHeadModel(GPT2LMHeadModel):
+     config_class = GPT3DevConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.transformer = GPT3DevModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         past_key_values=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.transformer(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = transformer_outputs[0]
+
+         lm_logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict token n
+             shift_logits = lm_logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)),
+                 shift_labels.view(-1),
+             )
+
+         if not return_dict:
+             output = (lm_logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=loss,
+             logits=lm_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+             cross_attentions=transformer_outputs.cross_attentions,
+         )
+
+ # Register the custom model with AutoModel and AutoModelForCausalLM
+ AutoModel.register(GPT3DevConfig, GPT3DevModel)
+ AutoModelForCausalLM.register(GPT3DevConfig, GPT3DevLMHeadModel)
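
Usage note (a minimal sketch under stated assumptions, not part of the commit): because importing this module runs the AutoConfig/AutoModel/AutoModelForCausalLM registrations above, a checkpoint whose config.json has model_type "gpt3dev" can be loaded through the Auto classes after a plain local import. The checkpoint path and prompt below are placeholders.

import modeling_gpt3dev  # noqa: F401 -- the import executes the register() calls above
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "path/to/gpt3dev-checkpoint"  # placeholder: local directory or Hub repo id
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))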