Commit 742a8e3
Parent: 988edaf

support generation on backpacks by overloading prepare_inputs_for_generation.

Files changed: modeling_backpack_gpt2.py (+42 -2)
modeling_backpack_gpt2.py (CHANGED)

@@ -153,7 +153,7 @@ class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
     def get_sense_network(self):
         return self.sense_network
 
-    def forward(self, input_ids, position_ids):
+    def forward(self, input_ids, position_ids, **kwargs):
         # Compute senses
         sense_input_embeds = self.word_embeddings(input_ids)
         senses = self.sense_network(sense_input_embeds) # (bs, nv, s, d)

@@ -205,8 +205,48 @@ class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel):
 
     def get_lm_head(self):
         return self.lm_head
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, **kwargs):
+        # prepare_inputs_for_generation needs to be overwritten to support generation
+        # this is inspired from the one in GPT2LMHeadModel: https://github.com/huggingface/transformers/blob/d533465150532b0c5de167b574e59f64c68b1154/src/transformers/models/gpt2/modeling_gpt2.py#L1007C4-L1007C4
+
+        token_type_ids = kwargs.get("token_type_ids", None)
+        # only last token for inputs_ids if past is defined in kwargs
+        if past_key_values:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs
 
-    def forward(self, input_ids, position_ids=None):
+    def forward(self, input_ids, position_ids=None, **kwargs):
         outputs = self.backpack(input_ids, position_ids=position_ids)
         hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
         lm_logits = self.lm_head(hidden_states) # (bs, s, V)
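For reference, the position_ids construction in the new prepare_inputs_for_generation is the usual masked-cumsum trick for batch generation with left padding. A small standalone sketch (illustration only, not part of the commit) of what it computes:

import torch

# Two sequences: the first is left-padded with two pad positions.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum gives each real token its position within its own sequence;
# padded slots are then overwritten with a dummy index of 1 so they remain valid.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])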
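A minimal usage sketch of what this commit enables. Assumptions (not part of the commit): the checkpoint's auto_map exposes BackpackGPT2LMHeadModel through AutoModelForCausalLM, and the Hub id below is only illustrative.

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "stanfordnlp/backpack-gpt2"  # assumed checkpoint id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt")
# generate() calls prepare_inputs_for_generation() at each decoding step to assemble
# the arguments passed to forward(); the override added in this commit is what makes
# that call work for the Backpack LM head model.
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))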