import torch
from torch import nn
import timm

import config as CFG


class TextEncoder(nn.Module):
    """
    Text/Poem encoder used in PoemTextModel and CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model
        The text/poem encoder model
    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of texts/poems) (of the CLS token)
    __init__()
        creates the encoder model using huggingface transformers,
        also freezes the model if it's not trainable.
    """

    def __init__(self, encoder_model, encoder_pretrained_name, pretrained, trainable):
        """
        Creates the poem/text encoder model using huggingface transformers and loads weights from a pretrained model if needed.
        Also freezes the model if it's not trainable.
        Parameters:
        -----------
        pretrained: bool
            if pretrained=True, load the pretrained model's weights; else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        encoder_model: str
            text/poem encoder model name, used as a key to get the right model class from configs.
        encoder_pretrained_name: str
            pretrained text/poem encoder checkpoint to load weights from. (not used when pretrained=False)
        """
        super().__init__()
        if pretrained:
            self.model = CFG.encoders[encoder_model].from_pretrained(encoder_pretrained_name)
        else:
            self.model = CFG.encoders[encoder_model](config=CFG.configs[encoder_model]())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # Using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """
        Forwards and calculates embeddings of the input using the attention mask.
        Parameters:
        -----------
        input_ids: input ids (output of the tokenizer)
        attention_mask: input masks (e.g. for padding; pad tokens will be masked)
        Returns:
        --------
        the embedding of the CLS (or target) token of the encoder's last hidden state
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
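
# Example usage (a minimal sketch; assumes config.py maps the key "bert" to transformers'
# BertModel/BertConfig in CFG.encoders/CFG.configs, and that a matching tokenizer is used):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   batch = tokenizer(["a poem about rain"], padding=True, return_tensors="pt")
#   text_encoder = TextEncoder("bert", "bert-base-uncased", pretrained=True, trainable=False)
#   cls_emb = text_encoder(batch["input_ids"], batch["attention_mask"])  # (1, hidden_size)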


class ProjectionHead(nn.Module):
    """
    Projection head used to project embeddings from each encoder to a shared embedding space
    ...
    Attributes:
    -----------
    projection : torch.nn.Linear
        The main dense projection (from the encoder's embedding dim to the shared projection dim)
    gelu: torch.nn.GELU
        activation function
    fc: torch.nn.Linear
        a dense layer after projection (projection_dim to projection_dim)
    dropout: torch.nn.Dropout
        dropout after fc
    layer_norm: torch.nn.LayerNorm
        layer norm after dropout
    Methods:
    --------
    forward(x)
        returns projected embeddings from x (encoder output embeddings)
    __init__()
        creates the projection head
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        """
        Creates the projection head used after an encoder.
        Parameters:
        -----------
        embedding_dim: int
            dimension of the output embeddings of the encoder.
        projection_dim: int, optional
            dimension to project embeddings to.
        dropout: float
            fraction of the output of the fc layer to be zeroed.
        """
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """
        Forwards and calculates projected embeddings from encoder embeddings.
        Parameters:
        -----------
        x: input (of shape (batch_size, embedding_dim))
            the output embedding of this projection head's encoder
        Returns:
        --------
        the embeddings in a shared embedding space (of shape (batch_size, projection_dim))
        """
        projected = self.projection(x)  # main projection layer
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # the projected outputs are added to x as a residual connection
        x = x + projected
        x = self.layer_norm(x)
        return x
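
# Example usage (a minimal sketch; assumes config.py defines CFG.projection_dim and
# CFG.dropout, e.g. 256 and 0.1, and that the encoder outputs 768-dim embeddings):
#
#   head = ProjectionHead(embedding_dim=768)
#   text_emb = torch.randn(4, 768)   # e.g. CLS embeddings from TextEncoder
#   shared = head(text_emb)          # (4, CFG.projection_dim), in the shared space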


class ImageEncoder(nn.Module):
    """
    Image encoder used in CLIPModel
    ...
    Attributes:
    -----------
    model : a torch.nn.Module model from timm (pytorch-image-models)
        The image encoder model
    Methods:
    --------
    forward(x)
        returns model embeddings of x (batch of images)
    __init__()
        creates the encoder model using timm and loads a fine-tuned model's state dict if needed.
        also freezes the model if it's not trainable.
    """

    def __init__(
        self, pretrained, trainable, model_name=CFG.image_encoder_model
    ):
        """
        Creates the encoder model using timm and loads a fine-tuned model's state dict if needed.
        Also freezes the model if it's not trainable.
        Parameters:
        -----------
        pretrained: bool
            if pretrained=True, get SOTA pretrained weights (or the weights saved in image_encoder_weights_load_path);
            else create a fresh untrained model.
        trainable: bool
            if trainable=False, the model's weights will be frozen.
        model_name: str
            image encoder model name used as input to timm.create_model.
        """
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        if pretrained and CFG.image_encoder_weights_load_path:
            self.model.load_state_dict(torch.load(CFG.image_encoder_weights_load_path, map_location=CFG.device))

        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        """
        Forwards and calculates embeddings of the input.
        Parameters:
        -----------
        x: input (batch of transformed images)
        Returns:
        --------
        embeddings of the model for the input (of shape (batch_size, image_embedding))
        """
        return self.model(x)
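
# Example usage (a minimal sketch; assumes CFG.image_encoder_model is a timm model name
# such as "resnet50", and that CFG.image_encoder_weights_load_path / CFG.device are set in config.py):
#
#   image_encoder = ImageEncoder(pretrained=True, trainable=False)
#   images = torch.randn(4, 3, 224, 224)   # batch of transformed images
#   image_emb = image_encoder(images)      # (4, image_embedding_dim)
#   # In CLIPModel, the outputs of ImageEncoder and TextEncoder are each passed through a
#   # ProjectionHead so that both modalities live in the same shared embedding space.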