# Rosa-V1/examples/emotion_model.py
# ROSA: Recursive Ontology of Semantic Affect
# Sublime Emotional System by Willinton Triana Cardona

import torch
import torch.nn as nn
from transformers import BertModel


class Rosa(nn.Module):
    def __init__(
        self,
        model_name="bert-base-uncased",
        num_emotions=28,  # 27 GoEmotions emotions + neutral, matching emotion_labels below
        latent_dim=None,
        return_vector=False,
        emotion_labels=None
    ):
        super().__init__()
        self.heart = BertModel.from_pretrained(model_name)
        self.grace = nn.Dropout(0.3)
        self.bloom = nn.Linear(self.heart.config.hidden_size, num_emotions)
        self.return_vector = return_vector
        self.latent_dim = latent_dim
        self.emotion_labels = emotion_labels or [
            "admiration", "amusement", "anger", "annoyance", "approval", "caring",
            "confusion", "curiosity", "desire", "disappointment", "disapproval",
            "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
            "joy", "love", "nervousness", "optimism", "pride", "realization", "relief",
            "remorse", "sadness", "surprise", "neutral"
        ]
        if latent_dim:
            self.rosa_embedding = nn.Linear(self.heart.config.hidden_size, latent_dim)
        self.loss_fct = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        petals = self.heart(input_ids=input_ids, attention_mask=attention_mask)
        pooled = petals.pooler_output
        softened = self.grace(pooled)
        if self.return_vector and hasattr(self, 'rosa_embedding'):
            embedding = self.rosa_embedding(softened)
            return {"embedding": embedding}
        logits = self.bloom(softened)
        if labels is not None:
            loss = self.loss_fct(logits, labels.float())
            return {"loss": loss, "logits": logits}
        else:
            return {"logits": logits}