| import torch | |
| import torch.nn as nn | |
class FeedForward(nn.Module):
    """SiLU-gated feed-forward layer (SwiGLU-style) with bias-free projections.

    Computes ``fc3(silu(fc1(x)) * fc2(x))``. Two parallel projections lift
    the input from ``emb_dim`` to ``hidden_dim``. The first is passed through
    SiLU and gates the second elementwise. The product is then projected back
    down to ``emb_dim``.

    Args:
        cfg: Mapping that must provide ``"emb_dim"`` (input/output feature
            size), ``"hidden_dim"`` (inner feature size) and ``"dtype"``
            (parameter dtype forwarded to every ``nn.Linear``).
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = cfg["hidden_dim"]
        dtype = cfg["dtype"]
        # Keep the registration order fc1 -> fc2 -> fc3 so that parameter
        # order, state_dict keys and seeded initialization are unchanged.
        self.fc1 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)  # gate branch (SiLU applied)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)  # branch multiplied by the gate
        self.fc3 = nn.Linear(hidden_dim, emb_dim, dtype=dtype, bias=False)  # projection back to emb_dim

    def forward(self, x):
        """Apply the gated feed-forward transform over the last dimension of ``x``.

        The last dimension of ``x`` must equal ``emb_dim``. The result has the
        same shape as ``x``.
        """
        gate = nn.functional.silu(self.fc1(x))
        return self.fc3(gate * self.fc2(x))