katrjohn committed on
Commit b5fd657 · verified · 1 Parent(s): bcebe31

Upload 4 files

.gitattributes ADDED
@@ -0,0 +1 @@
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
config_tiny_greek_news_bert.py ADDED
@@ -0,0 +1,19 @@
+ # configuration_tiny_greek_news_bert.py
+ from transformers import BertConfig
+
+ class TinyGreekNewsBertConfig(BertConfig):
+     model_type = "tiny_greek_news_bert"
+     def __init__(
+         self,
+         num_labels_class=19,
+         num_labels_ner=32,
+         ner_loss_weight=3.0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_labels_class = num_labels_class
+         self.num_labels_ner = num_labels_ner
+         self.ner_loss_weight = ner_loss_weight
+
+ # 👇 this writes the AutoConfig mapping when you save_pretrained()
+ TinyGreekNewsBertConfig.register_for_auto_class()
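
Because of the `register_for_auto_class()` call above, the saved config.json carries an `auto_map` entry, so the custom config class resolves through `AutoConfig` once the repo is on the Hub. A minimal loading sketch; the repo id below is a placeholder, not confirmed by this commit:

from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "katrjohn/tiny-greek-news-bert",  # placeholder repo id
    trust_remote_code=True,           # the config class lives in this repo, not in transformers
)
print(type(config).__name__)          # TinyGreekNewsBertConfig
print(config.num_labels_class, config.num_labels_ner, config.ner_loss_weight)  # 19 32 3.0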
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:408cce07304b82dc981b3014f463d1d6305366ce83a3c3168f6ac31612125f2b
+ size 56478996
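
This is a Git LFS pointer file: the ~56 MB weights live in LFS storage and are identified by the sha256 oid above. A small sketch for checking a downloaded `model.safetensors` against that oid, assuming the real file (not the pointer) is on disk:

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file so a 56 MB checkpoint is hashed without loading it whole
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "408cce07304b82dc981b3014f463d1d6305366ce83a3c3168f6ac31612125f2b"
assert sha256_of("model.safetensors") == expected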
modeling_tiny_greek_news_bert.py ADDED
@@ -0,0 +1,79 @@
+ import torch.nn as nn
+ from transformers import BertModel, BertPreTrainedModel
+
+ class TinyGreekNewsBert(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         num_labels_class = config.num_labels_class
+         num_labels_ner = config.num_labels_ner
+         self.ner_loss_weight = getattr(config, "ner_loss_weight", 3.0)
+         self.bert = BertModel(config)
+
+         # Classification head
+         self.class_dropout = nn.Dropout(0.3)
+         self.class_fc = nn.Linear(config.hidden_size, 768)
+         self.class_relu = nn.ReLU()
+         self.classifier = nn.Linear(768, num_labels_class)
+
+         # NER head
+         self.ner_classifier = nn.Linear(config.hidden_size, num_labels_ner)
+
+         self.init_weights()
+         # First-step losses, kept for normalizing the two task losses
+         self.initial_cls_loss = None
+         self.initial_ner_loss = None
+
+     def forward(self, input_ids, attention_mask=None, token_type_ids=None,
+                 labels_class=None, labels_ner=None):
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids
+         )
+         sequence_output = outputs.last_hidden_state  # (batch_size, seq_length, hidden_size)
+         pooled_output = outputs.pooler_output        # (batch_size, hidden_size)
+
+         # Classification branch
+         pooled_output = self.class_dropout(pooled_output)
+         x = self.class_fc(pooled_output)
+         x = self.class_relu(x)
+         logits_class = self.classifier(x)
+
+         # NER branch
+         logits_ner = self.ner_classifier(sequence_output)  # (batch_size, seq_length, num_labels_ner)
+
+         loss = None
+         if labels_class is not None and labels_ner is not None:
+             # Classification loss
+             loss_fct_class = nn.CrossEntropyLoss()
+             loss_class = loss_fct_class(logits_class, labels_class)
+
+             # NER loss: cross-entropy with ignore_index=-100, summed then averaged over non-pad tokens
+             loss_fct_ner = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
+             ner_loss_sum = loss_fct_ner(
+                 logits_ner.view(-1, logits_ner.shape[-1]),
+                 labels_ner.view(-1)
+             )
+             mask = (labels_ner != -100).view(-1).float()
+             loss_ner = ner_loss_sum / (mask.sum() + 1e-9)
+
+             # Store the initial values on the first training step
+             if self.initial_cls_loss is None and self.training:
+                 self.initial_cls_loss = loss_class.item()
+             if self.initial_ner_loss is None and self.training:
+                 self.initial_ner_loss = loss_ner.item()
+
+             # Normalize each loss by its initial value so both start near 1.0
+             if (self.initial_cls_loss is not None) and (self.initial_ner_loss is not None):
+                 norm_cls_loss = loss_class / (self.initial_cls_loss + 1e-8)
+                 norm_ner_loss = loss_ner / (self.initial_ner_loss + 1e-8)
+             else:
+                 norm_cls_loss = loss_class
+                 norm_ner_loss = loss_ner
+
+             # Combine losses with the NER weight
+             loss = norm_cls_loss + self.ner_loss_weight * norm_ner_loss
+             return (loss, logits_class, logits_ner)
+         else:
+             return (logits_class, logits_ner)
+
+ # 👇 writes the AutoModel mapping into config.json on save_pretrained()
+ TinyGreekNewsBert.register_for_auto_class("AutoModel")
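
With the `AutoModel` mapping registered above, the dual-head model loads via `trust_remote_code`. A minimal inference sketch; the repo id is a placeholder, and it assumes tokenizer files are also available in the repo (this commit only uploads the four files listed):

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "katrjohn/tiny-greek-news-bert"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumes tokenizer files exist
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# "The government announced new measures for the economy."
enc = tokenizer("Η κυβέρνηση ανακοίνωσε νέα μέτρα για την οικονομία.", return_tensors="pt")
with torch.no_grad():
    # without labels, forward returns (logits_class, logits_ner)
    logits_class, logits_ner = model(**enc)

print(logits_class.shape)  # (1, 19): document-level class logits
print(logits_ner.shape)    # (1, seq_len, 32): per-token NER logits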
training_args.bin ADDED
Binary file (5.3 kB).