kunjcr2 committed on
Commit
43e40b8
·
verified ·
1 Parent(s): 50c8b86

Update modeling_medassistgpt.py

Browse files
Files changed (1) hide show
  1. modeling_medassistgpt.py +14 -0
modeling_medassistgpt.py CHANGED
@@ -105,6 +105,20 @@ class MedAssistGPTModel(PreTrainedModel):
105
  elif isinstance(module, nn.Embedding):
106
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def forward(self, input_ids, labels=None):
109
  h = self.embed(input_ids)
110
  for blk in self.blocks:
 
105
  elif isinstance(module, nn.Embedding):
106
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
107
 
108
+ def prepare_inputs_for_generation(
109
+ self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
110
+ ):
111
+ # If past_key_values is provided, only use the last token
112
+ if past_key_values is not None:
113
+ input_ids = input_ids[:, -1:]
114
+
115
+ return {
116
+ "input_ids": input_ids,
117
+ "past_key_values": past_key_values,
118
+ "use_cache": use_cache,
119
+ "attention_mask": attention_mask,
120
+ }
121
+
122
  def forward(self, input_ids, labels=None):
123
  h = self.embed(input_ids)
124
  for blk in self.blocks: