Update modeling_medassistgpt.py
Browse files- modeling_medassistgpt.py +14 -0
modeling_medassistgpt.py
CHANGED
|
@@ -105,6 +105,20 @@ class MedAssistGPTModel(PreTrainedModel):
|
|
| 105 |
elif isinstance(module, nn.Embedding):
|
| 106 |
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
def forward(self, input_ids, labels=None):
|
| 109 |
h = self.embed(input_ids)
|
| 110 |
for blk in self.blocks:
|
|
|
|
| 105 |
elif isinstance(module, nn.Embedding):
|
| 106 |
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 107 |
|
| 108 |
+
def prepare_inputs_for_generation(
|
| 109 |
+
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
|
| 110 |
+
):
|
| 111 |
+
# If past_key_values is provided, only use the last token
|
| 112 |
+
if past_key_values is not None:
|
| 113 |
+
input_ids = input_ids[:, -1:]
|
| 114 |
+
|
| 115 |
+
return {
|
| 116 |
+
"input_ids": input_ids,
|
| 117 |
+
"past_key_values": past_key_values,
|
| 118 |
+
"use_cache": use_cache,
|
| 119 |
+
"attention_mask": attention_mask,
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
def forward(self, input_ids, labels=None):
|
| 123 |
h = self.embed(input_ids)
|
| 124 |
for blk in self.blocks:
|