# Pocket-Gen/models/esm2adapter.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
import math
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
import esm
from esm.modules import (
TransformerLayer,
LearnedPositionalEmbedding,
SinusoidalPositionalEmbedding,
RobertaLMHead,
ESM1bLayerNorm,
ContactPredictionHead,
ESM1LayerNorm,
FeedForwardNetwork,
NormalizedResidualBlock,
gelu,
)
from esm.multihead_attention import MultiheadAttention
def Cfg(**kwds):
return OmegaConf.create(kwds)
def merge_config(default_cfg, override_cfg):
return OmegaConf.merge(default_cfg, override_cfg)
class ESM2WithStructuralAdatper(nn.Module):
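    """ESM-2 language model augmented with structural adapter layers.

    When built via ``from_pretrained``, the pretrained ESM-2 weights are loaded
    and frozen; only the adapter sub-modules (cross-attention to a structure
    encoder plus a bottleneck FFN) remain trainable. Which transformer layers
    carry an adapter is controlled by ``args.adapter_layer_indices``.
    """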
@classmethod
def from_pretrained(cls, args, override_args=None, name='esm2_t33_650M_UR50D'):
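        """Build the model from a pretrained ESM-2 checkpoint.

        Loads ``name`` via ``esm.pretrained``, merges its dimensions into
        ``args``, inserts the structural adapter into the selected layers, and
        freezes every parameter whose name does not contain ``'adapter'``.
        """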
pretrained_model, alphabet = esm.pretrained.load_model_and_alphabet_hub(name)
pretrained_args = Cfg(
num_layers=pretrained_model.num_layers,
embed_dim=pretrained_model.embed_dim,
attention_heads=pretrained_model.attention_heads,
token_dropout=pretrained_model.token_dropout,
)
args = merge_config(pretrained_args, args)
        # Insert the structural adapter into the last transformer layer only
        # (earlier variants used layers [6, 20, 32]); negative indices are
        # normalized to absolute layer indices below.
        args.adapter_layer_indices = [-1]
        args.adapter_layer_indices = [
            (args.num_layers + i) % args.num_layers
            for i in args.adapter_layer_indices
        ]
model = cls(args, deepcopy(alphabet))
model.load_state_dict(pretrained_model.state_dict(), strict=False)
del pretrained_model
# freeze pretrained parameters
for pname, param in model.named_parameters():
if 'adapter' not in pname:
param.requires_grad = False
return model
def __init__(
self,
args,
alphabet: Union[esm.data.Alphabet, str] = "ESM-1b",
# num_layers: int = 33,
# embed_dim: int = 1280,
# attention_heads: int = 20,
# token_dropout: bool = True,
):
super().__init__()
self.args = args
self.num_layers = args.num_layers
self.embed_dim = args.embed_dim
self.attention_heads = args.attention_heads
if not isinstance(alphabet, esm.data.Alphabet):
alphabet = esm.data.Alphabet.from_architecture(alphabet)
self.alphabet = alphabet
self.alphabet_size = len(alphabet)
self.padding_idx = alphabet.padding_idx
self.mask_idx = alphabet.mask_idx
self.cls_idx = alphabet.cls_idx
self.eos_idx = alphabet.eos_idx
self.prepend_bos = alphabet.prepend_bos
self.append_eos = alphabet.append_eos
self.token_dropout = args.token_dropout
self._init_submodules()
def _init_submodules(self):
self.embed_scale = 1
self.embed_tokens = nn.Embedding(
self.alphabet_size,
self.embed_dim,
padding_idx=self.padding_idx,
)
self.layers = nn.ModuleList(
[
self._init_layer(_)
for _ in range(self.num_layers)
]
)
self.contact_head = ContactPredictionHead(
self.num_layers * self.attention_heads,
self.prepend_bos,
self.append_eos,
eos_idx=self.eos_idx,
)
self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
self.lm_head = RobertaLMHead(
embed_dim=self.embed_dim,
output_dim=self.alphabet_size,
weight=self.embed_tokens.weight,
)
def _init_layer(self, layer_idx):
if layer_idx in self.args.adapter_layer_indices:
            layer = TransformerLayerWithStructuralAdapter(
self.embed_dim,
4 * self.embed_dim,
self.attention_heads,
add_bias_kv=False,
use_esm1b_layer_norm=True,
use_rotary_embeddings=True,
encoder_embed_dim=self.args.encoder.d_model,
dropout=self.args.dropout
)
else:
layer = TransformerLayer(
self.embed_dim,
4 * self.embed_dim,
self.attention_heads,
add_bias_kv=False,
use_esm1b_layer_norm=True,
use_rotary_embeddings=True,
)
return layer
    def forward_layers(self, x, encoder_out, padding_mask, repr_layers=(), hidden_representations=None, need_head_weights=False, attn_weights=None):
        # Avoid shared mutable default arguments: build fresh containers per call.
        hidden_representations = {} if hidden_representations is None else hidden_representations
        attn_weights = [] if attn_weights is None else attn_weights
for layer_idx, layer in enumerate(self.layers):
if layer_idx in self.args.adapter_layer_indices:
x, attn = layer(
x, encoder_out, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
)
else:
x, attn = layer(
x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
)
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x.transpose(0, 1)
if need_head_weights:
# (H, B, T, T) => (B, H, T, T)
attn_weights.append(attn.transpose(1, 0))
return x, hidden_representations, attn_weights, layer_idx
    def forward(self, tokens, encoder_out, repr_layers=(), need_head_weights=False, return_contacts=False):
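        """Forward pass over ``tokens`` with structural conditioning.

        tokens: LongTensor of shape (B, T) produced by the ESM alphabet's
            batch converter.
        encoder_out: dict from the structure encoder; only
            ``encoder_out['feats']`` (batch-first, last dim ``encoder.d_model``)
            is consumed, by the adapter layers.
        repr_layers: layer indices whose hidden states are returned under
            ``result['representations']``.
        """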
if return_contacts:
need_head_weights = True
assert tokens.ndim == 2
padding_mask = tokens.eq(self.padding_idx) # B, T
x = self.embed_scale * self.embed_tokens(tokens)
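        # ESM-2 token dropout: zero the embeddings of <mask> tokens, then
        # rescale by the train-time/observed mask-ratio correction factor.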
if self.token_dropout:
x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
# x: B x T x C
mask_ratio_train = 0.15 * 0.8
src_lengths = (~padding_mask).sum(-1)
mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
repr_layers = set(repr_layers)
hidden_representations = {}
if 0 in repr_layers:
hidden_representations[0] = x
if need_head_weights:
attn_weights = []
# (B, T, E) => (T, B, E)
x = x.transpose(0, 1)
if not padding_mask.any():
padding_mask = None
x, hidden_representations, attn_weights, layer_idx = self.forward_layers(
x, encoder_out, padding_mask,
repr_layers=repr_layers,
hidden_representations=hidden_representations,
need_head_weights=need_head_weights,
attn_weights=attn_weights if need_head_weights else None
)
x = self.emb_layer_norm_after(x)
x = x.transpose(0, 1) # (T, B, E) => (B, T, E)
# last hidden representation should have layer norm applied
if (layer_idx + 1) in repr_layers:
hidden_representations[layer_idx + 1] = x
x = self.lm_head(x)
result = {"logits": x, "representations": hidden_representations}
if need_head_weights:
# attentions: B x L x H x T x T
attentions = torch.stack(attn_weights, 1)
if padding_mask is not None:
attention_mask = 1 - padding_mask.type_as(attentions)
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
attentions = attentions * attention_mask[:, None, None, :, :]
result["attentions"] = attentions
if return_contacts:
contacts = self.contact_head(tokens, attentions)
result["contacts"] = contacts
return result
    def predict_contacts(self, tokens, encoder_out):
        return self(tokens, encoder_out, return_contacts=True)["contacts"]
class TransformerLayerWithStructuralAdapter(nn.Module):
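    """ESM-2 transformer layer with an added structural adapter.

    The standard self-attention/FFN block is kept intact; on top of it, a
    cross-attention block (queries from the sequence states, keys/values from
    the structure encoder features) and a bottleneck FFN are applied.
    """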
def __init__(
self,
embed_dim,
ffn_embed_dim,
attention_heads,
encoder_embed_dim,
add_bias_kv=True,
use_esm1b_layer_norm=False,
use_rotary_embeddings: bool = False,
dropout=0.1,
):
super().__init__()
self.embed_dim = embed_dim
self.ffn_embed_dim = ffn_embed_dim
self.attention_heads = attention_heads
self.use_rotary_embeddings = use_rotary_embeddings
self.encoder_embed_dim = encoder_embed_dim
self.dropout = dropout
self._init_submodules(add_bias_kv, use_esm1b_layer_norm)
def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm
self.self_attn = MultiheadAttention(
self.embed_dim,
self.attention_heads,
add_bias_kv=add_bias_kv,
add_zero_attn=False,
use_rotary_embeddings=self.use_rotary_embeddings,
)
self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
self.final_layer_norm = BertLayerNorm(self.embed_dim)
# structural adapter
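        # Pre-norm residual blocks: cross-attention to the encoder features,
        # then a down-projected (embed_dim -> embed_dim // 2) feed-forward bottleneck.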
self.structural_adapter_attn = NormalizedResidualBlock(
layer=MultiheadAttention(
self.embed_dim,
self.attention_heads,
kdim=self.encoder_embed_dim,
vdim=self.encoder_embed_dim,
add_bias_kv=add_bias_kv,
add_zero_attn=False,
use_rotary_embeddings=True,
),
embedding_dim=self.embed_dim,
dropout=self.dropout
)
self.structural_adapter_ffn = NormalizedResidualBlock(
layer=FeedForwardNetwork(
self.embed_dim,
self.embed_dim // 2, # NOTE: bottleneck FFN is important
# self.ffn_embed_dim,
activation_dropout=self.dropout
),
embedding_dim=self.embed_dim,
dropout=self.dropout
)
def forward(
self, x, encoder_out, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
):
residual = x
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
need_weights=True,
need_head_weights=need_head_weights,
attn_mask=self_attn_mask,
)
x = residual + x
        # Alternative placement (kept for reference, unused): run the adapter right
        # after self-attention instead of after the feed-forward block.
        # x = self.forward_adapter(x, encoder_out, attn_mask=self_attn_mask, attn_padding_mask=self_attn_padding_mask)
residual = x
x = self.final_layer_norm(x)
x = gelu(self.fc1(x))
x = self.fc2(x)
x = residual + x
x = x + self.forward_adapter(x, encoder_out, attn_mask=self_attn_mask, attn_padding_mask=self_attn_padding_mask)
return x, attn
def forward_adapter(self, x, encoder_out, attn_mask, attn_padding_mask):
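        """Attend from the sequence states ``x`` (T, B, E) to the structure
        encoder features ``encoder_out['feats']``, then apply the bottleneck FFN."""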
encoder_feats = encoder_out['feats']
encoder_feats = encoder_feats.transpose(0, 1)
x = self.structural_adapter_attn(
x,
key=encoder_feats,
value=encoder_feats,
key_padding_mask=attn_padding_mask,
attn_mask=attn_mask,
need_weights=False
)[0]
x = self.structural_adapter_ffn(x)
return x
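

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. Assumptions:
    # network access to download the ESM-2 650M checkpoint, and the fair-esm and
    # omegaconf packages installed. The config fields `encoder.d_model` and
    # `dropout` are the ones this module reads from `args`; the random `feats`
    # tensor merely stands in for a real structure encoder output.
    args = Cfg(dropout=0.1, encoder={"d_model": 128})
    model = ESM2WithStructuralAdatper.from_pretrained(args).eval()

    # Tokenize a single sequence with the model's own alphabet.
    batch_converter = model.alphabet.get_batch_converter()
    _, _, tokens = batch_converter([("protein", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")])

    # Fake structural features, one vector per token position (batch-first).
    encoder_out = {"feats": torch.randn(tokens.size(0), tokens.size(1), 128)}

    with torch.no_grad():
        out = model(tokens, encoder_out, repr_layers=[model.num_layers])
    print(out["logits"].shape)                               # (B, T, alphabet_size)
    print(out["representations"][model.num_layers].shape)    # (B, T, embed_dim)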