JesseStover committed on
Commit
6b27180
·
verified ·
1 Parent(s): 7ffe52f

Upload code/inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/inference.py +92 -2
code/inference.py CHANGED
@@ -1,11 +1,101 @@
 
 
1
  from sagemaker_inference import encoder
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForMultipleChoice
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def model_fn(model_dir):
7
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
8
- model = AutoModelForMultipleChoice.from_pretrained(model_dir)
9
  return {"model": model, "tokenizer": tokenizer}
10
 
11
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
  from sagemaker_inference import encoder
4
  import torch
5
+ from torch import nn
6
+ from torch.nn import CrossEntropyLoss
7
+ from transformers import AutoTokenizer, BertPreTrainedModel
8
+ from transformers.models.bert import BertModel
9
+ from transformers.modeling_outputs import ModelOutput
10
+
11
+
12
@dataclass
class MultipleChoiceModelOutput(ModelOutput):
    """Output container for the multiple-choice head.

    Mirrors `transformers.modeling_outputs.MultipleChoiceModelOutput`,
    vendored here so inference does not depend on that exact import path.
    """

    # Cross-entropy classification loss; populated only when `labels` are given.
    loss: Optional[torch.FloatTensor] = None
    # Pre-softmax choice scores, reshaped to (batch_size, num_choices) by forward().
    logits: torch.FloatTensor = None
    # Per-layer hidden states — only filled when output_hidden_states=True; TODO confirm against BertModel.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer attention weights — only filled when output_attentions=True; TODO confirm against BertModel.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
18
+
19
+
20
class BertForMultipleChoice(BertPreTrainedModel):
    """BERT encoder with a multiple-choice head (dropout + Linear(hidden, 1)).

    Every (example, choice) pair is encoded independently; the scalar score of
    each pair is then regrouped into a (batch_size, num_choices) logits matrix
    and treated as a classification over the choices.
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        # Prefer a classifier-specific dropout rate; fall back to the generic
        # hidden-layer rate when none is configured.
        if config.classifier_dropout is not None:
            drop_rate = config.classifier_dropout
        else:
            drop_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(drop_rate)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        """Score each choice; compute a cross-entropy loss when `labels` given.

        Tensor inputs are expected with a leading (batch, num_choices, ...)
        layout — TODO(review) confirm callers tokenize that way. `labels`
        holds the index of the correct choice per example.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # The choice dimension is axis 1 of whichever input form was supplied.
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        def _flat2d(tensor):
            # (batch, choices, seq) -> (batch * choices, seq); None passes through.
            return None if tensor is None else tensor.view(-1, tensor.size(-1))

        flat_ids = _flat2d(input_ids)
        flat_mask = _flat2d(attention_mask)
        flat_types = _flat2d(token_type_ids)
        flat_positions = _flat2d(position_ids)
        # Embeddings keep their trailing hidden dimension when flattened.
        flat_embeds = (
            None
            if inputs_embeds is None
            else inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
        )

        encoded = self.bert(
            flat_ids,
            attention_mask=flat_mask,
            token_type_ids=flat_types,
            position_ids=flat_positions,
            head_mask=head_mask,
            inputs_embeds=flat_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled representation of each flattened pair.
        pooled = self.dropout(encoded[1])
        per_pair_scores = self.classifier(pooled)
        # Regroup the flat per-pair scores back to one row per example.
        choice_logits = per_pair_scores.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(choice_logits, labels)

        if not return_dict:
            # Tuple form: (loss?, logits, hidden_states?, attentions?).
            tail = (choice_logits,) + encoded[2:]
            return tail if loss is None else (loss,) + tail

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=choice_logits,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
93
+
94
 
95
 
96
def model_fn(model_dir):
    """SageMaker model-loading hook: build the tokenizer and model from `model_dir`.

    Parameters
    ----------
    model_dir : str
        Directory containing the saved tokenizer and model weights.

    Returns
    -------
    dict
        {"model": BertForMultipleChoice, "tokenizer": AutoTokenizer} — consumed
        by the downstream predict/transform hooks.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = BertForMultipleChoice.from_pretrained(model_dir)
    # This endpoint only serves predictions: switch to eval mode so the
    # classifier's nn.Dropout is disabled and outputs are deterministic.
    model.eval()
    return {"model": model, "tokenizer": tokenizer}
100
 
101