|
|
|
|
|
|
|
|
|
|
|
import warnings

import torch
import torch.nn as nn
from transformers import BertModel
|
|
|
|
|
class Rosa(nn.Module):
    """BERT-based multi-label emotion classifier with an optional latent projection.

    The encoder's pooled output goes through dropout and then either:

    * the linear classification head ``bloom`` (one logit per emotion), or
    * when ``return_vector`` is set and ``latent_dim`` was given, the linear
      projection ``rosa_embedding`` (a ``latent_dim``-wide vector).
    """

    # Label names used when ``emotion_labels`` is not supplied (28 entries).
    DEFAULT_EMOTION_LABELS = (
        "admiration", "amusement", "anger", "annoyance", "approval", "caring",
        "confusion", "curiosity", "desire", "disappointment", "disapproval",
        "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
        "joy", "love", "nervousness", "optimism", "pride", "realization", "relief",
        "remorse", "sadness", "surprise", "neutral",
    )

    def __init__(
        self,
        model_name="bert-base-uncased",
        num_emotions=29,
        latent_dim=None,
        return_vector=False,
        emotion_labels=None,
        dropout=0.3,
    ):
        """Build the encoder, dropout, classification head and optional projection.

        Args:
            model_name: Name or path passed to ``BertModel.from_pretrained``.
            num_emotions: Output width of the classification head ``bloom``.
            latent_dim: If truthy, adds ``rosa_embedding``, a linear map from
                the encoder hidden size to ``latent_dim``.
            return_vector: If True and ``rosa_embedding`` exists, ``forward``
                returns the projection instead of logits.
            emotion_labels: Label names; falls back to a list copy of
                ``DEFAULT_EMOTION_LABELS`` when None or empty.
            dropout: Dropout probability applied to the pooled output.
        """
        super().__init__()

        self.heart = BertModel.from_pretrained(model_name)
        hidden_size = self.heart.config.hidden_size

        self.grace = nn.Dropout(dropout)
        self.bloom = nn.Linear(hidden_size, num_emotions)

        self.return_vector = return_vector
        self.latent_dim = latent_dim
        self.emotion_labels = emotion_labels or list(self.DEFAULT_EMOTION_LABELS)

        # NOTE(review): the default label list has 28 names while the default
        # num_emotions is 29, so one logit has no label. The default width is
        # kept so existing checkpoints still load; warn so the misalignment
        # is visible rather than silent.
        if len(self.emotion_labels) != num_emotions:
            warnings.warn(
                f"Rosa: {len(self.emotion_labels)} emotion labels but "
                f"num_emotions={num_emotions}; logits and labels will not align.",
                stacklevel=2,
            )

        # Always define the attribute (None when unused) so forward() can test
        # it directly instead of relying on hasattr(). Assigned after bloom to
        # keep the original submodule/parameter registration order.
        self.rosa_embedding = (
            nn.Linear(hidden_size, latent_dim) if latent_dim else None
        )

        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(
        self,
        input_ids,
        attention_mask,
        labels=None,
        *,
        token_type_ids=None,
        **kwargs,
    ):
        """Encode a batch and return logits (plus loss) or a latent embedding.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            labels: Optional multi-hot targets; when given, ``BCEWithLogitsLoss``
                is computed against the logits.
            token_type_ids: Optional segment ids forwarded to the encoder
                (``None`` lets the encoder use its default).
            **kwargs: Accepted and ignored (e.g. extra keys from a collator).

        Returns:
            ``{"embedding": ...}`` when ``return_vector`` is set and the
            projection exists (``labels`` are ignored on this path); otherwise
            ``{"logits": ...}``, plus ``"loss"`` when ``labels`` is given.
        """
        petals = self.heart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        softened = self.grace(petals.pooler_output)

        if self.return_vector and self.rosa_embedding is not None:
            return {"embedding": self.rosa_embedding(softened)}

        logits = self.bloom(softened)
        if labels is None:
            return {"logits": logits}

        # Cast targets to the logits' dtype (float32 in the default setup,
        # same as the former .float()) so non-fp32 logits don't mismatch.
        loss = self.loss_fct(logits, labels.to(logits.dtype))
        return {"loss": loss, "logits": logits}