import torch
import torch.nn as nn
import torch.nn.functional as F
import wget
import json
import os
import math
import sentencepiece as spm
import re
CODEGEN_FOLDER = "./CodeGenModel"
CODEGEN_MODEL_NAME = "codegen-350M-multi"
CODEGEN_MODEL_WEIGHTS = "pytorch_model.bin"
CODEGEN_CONFIG = "config.json"
CODEGEN_VOCAB = "vocab.json"
CODEGEN_MERGES = "merges.txt"
CODEGEN_MODEL_WEIGHTS_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/pytorch_model.bin"
CODEGEN_CONFIG_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/config.json"
CODEGEN_VOCAB_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/vocab.json"
CODEGEN_MERGES_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/merges.txt"
CODEGEN_FILES_URLS = [
    (CODEGEN_MODEL_WEIGHTS_URL, CODEGEN_MODEL_WEIGHTS),
    (CODEGEN_CONFIG_URL, CODEGEN_CONFIG),
    (CODEGEN_VOCAB_URL, CODEGEN_VOCAB),
    (CODEGEN_MERGES_URL, CODEGEN_MERGES),
]
CODEGEN_SPM_URL = "https://huggingface.co/Salesforce/codegen-350M-multi/resolve/main/spm.model"
CODEGEN_SPM = "spm.model"
def ensure_codegen_files_exist():
    os.makedirs(CODEGEN_FOLDER, exist_ok=True)
    for url, filename in CODEGEN_FILES_URLS:
        filepath = os.path.join(CODEGEN_FOLDER, filename)
        if not os.path.exists(filepath):
            wget.download(url, out=filepath)
    filepath_spm = os.path.join(CODEGEN_FOLDER, CODEGEN_SPM)
    if not os.path.exists(filepath_spm):
        wget.download(CODEGEN_SPM_URL, out=filepath_spm)
class CodeGenConfig:
    def __init__(self, vocab_size, n_positions=2048, n_ctx=2048, n_embd=1024, n_layer=24,
                 n_head=16, n_inner=None, activation_function="gelu_new", resid_pdrop=0.1,
                 embd_pdrop=0.1, attn_pdrop=0.1, layer_norm_epsilon=1e-05, initializer_range=0.02,
                 scale_attn_weights=True, use_cache=True, bos_token_id=50256, eos_token_id=50256,
                 **kwargs):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config_dict):
        return cls(**config_dict)
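
# --- Usage sketch (an illustrative addition, not part of the original listing): load the
# downloaded config.json into a CodeGenConfig. Assumes ensure_codegen_files_exist() has
# already fetched the files into CODEGEN_FOLDER; any extra keys in config.json are
# absorbed through **kwargs.
def load_codegen_config():
    ensure_codegen_files_exist()
    with open(os.path.join(CODEGEN_FOLDER, CODEGEN_CONFIG), "r") as f:
        config_dict = json.load(f)
    return CodeGenConfig.from_dict(config_dict)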
class CodeGenForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transformer = CodeGenModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        transformer_outputs = self.transformer(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(transformer_outputs)
        return logits
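
# --- Weight-loading sketch (an assumption, not part of the original listing). The
# downloaded pytorch_model.bin uses the Hugging Face CodeGen parameter names, which do
# not map one-to-one onto this simplified module (the official model uses rotary
# position embeddings and differently named projections), so strict=False is used here
# and any unmatched keys are reported rather than loaded.
def load_codegen_weights(model, weights_path=None):
    if weights_path is None:
        weights_path = os.path.join(CODEGEN_FOLDER, CODEGEN_MODEL_WEIGHTS)
    state_dict = torch.load(weights_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    return model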
class CodeGenModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([CodeGenBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states.view(*output_shape)
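
# --- Shape smoke test (an illustrative addition, not part of the original listing):
# run random token ids through a deliberately tiny config and confirm the model returns
# one hidden vector of width n_embd per input position.
def smoke_test_codegen_model():
    tiny_config = CodeGenConfig(vocab_size=128, n_positions=64, n_ctx=64,
                                n_embd=32, n_layer=2, n_head=4)
    model = CodeGenModel(tiny_config)
    input_ids = torch.randint(0, tiny_config.vocab_size, (2, 16))
    hidden_states = model(input_ids)
    assert hidden_states.shape == (2, 16, tiny_config.n_embd)
    return hidden_states.shape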
class CodeGenBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CodeGenAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = CodeGenMLP(config)

    def forward(self, hidden_states, attention_mask=None):
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(hidden_states, attention_mask=attention_mask)
        hidden_states = residual + attn_outputs
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feedforward_hidden_states = self.mlp(hidden_states)
        hidden_states = residual + feedforward_hidden_states
        return hidden_states
class CodeGenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # n_inner defaults to None in the config; fall back to the usual 4 * n_embd width.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, inner_dim)
        self.c_proj = nn.Linear(inner_dim, config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
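
# --- Quick check (an illustrative addition, not part of the original listing): with
# n_inner left at None the MLP expands to 4 * n_embd internally and projects back down,
# so the output shape always matches the input shape.
def check_mlp_shapes():
    tiny_config = CodeGenConfig(vocab_size=128, n_embd=32)
    mlp = CodeGenMLP(tiny_config)
    x = torch.randn(2, 16, tiny_config.n_embd)
    assert mlp(x).shape == x.shape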
class CodeGenAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.n_head = config.n_head
        self.embed_dim = config.n_embd
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale_attn_weights = config.scale_attn_weights
        self.use_cache = config.use_cache
        # Lower-triangular causal mask, broadcastable to (batch, head, query, key).
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((config.n_ctx, config.n_ctx), dtype=torch.uint8)).view(
                (1, 1, config.n_ctx, config.n_ctx)
            ),
        )

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / math.sqrt(value.size(-1))
        # Apply the causal mask, then any user-supplied additive attention mask
        # (0 for positions to keep, a large negative value for positions to drop).
        mask = self.bias[:, :, :attn_weights.size(-2), :attn_weights.size(-1)]
        attn_weights = torch.where(mask.bool(), attn_weights, torch.tensor(-1e4, device=attn_weights.device))
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output

    def _split_heads(self, tensor, num_heads, attn_head_size):
        # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(*new_shape)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, past_key_value=None, use_cache=False):
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query = self._split_heads(query, self.n_head, self.embed_dim // self.n_head)
        key = self._split_heads(key, self.n_head, self.embed_dim // self.n_head)
        value = self._split_heads(value, self.n_head, self.embed_dim // self.n_head)
        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present_key_value = (key, value) if use_cache else None
        attn_output = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.n_head, self.embed_dim // self.n_head)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        # The key/value cache is assembled above but only the attention output is returned.
        outputs = (attn_output, present_key_value)
        return outputs[0]
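
# --- Greedy decoding sketch (an illustrative addition, not part of the original
# listing). Tokenization via vocab.json/merges.txt is not implemented above, so this
# works directly on token ids: at each step the model scores the running sequence and
# the highest-probability next token is appended, stopping at eos_token_id.
@torch.no_grad()
def greedy_generate(model, config, input_ids, max_new_tokens=32):
    model.eval()
    for _ in range(max_new_tokens):
        logits = model(input_ids)               # (batch, seq, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if (next_token == config.eos_token_id).all():
            break
    return input_ids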