Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model (#19)
- Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model (775f6527d3cfd402c46b03c5fbf355b4f262b705)
Co-authored-by: Tomer Ronen <[email protected]>
modeling_decilm.py CHANGED (+45 -1)
@@ -25,7 +25,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers import GenerationConfig
-from transformers.generation.utils import GenerationMixin,
+from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
     add_start_docstrings,
@@ -1311,6 +1311,50 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
         )
         return model_inputs
 
+    def _maybe_initialize_input_ids_for_generation(
+            self,
+            inputs: Optional[torch.Tensor] = None,
+            bos_token_id: Optional[torch.Tensor] = None,
+            model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+    ) -> torch.LongTensor:
+        """
+        Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
+        """
+        input_ids = super()._maybe_initialize_input_ids_for_generation(
+            inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs)
+        if (
+                "inputs_embeds" in model_kwargs
+                and input_ids is not None
+                and input_ids.shape[1] == 0
+        ):
+            batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
+            input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
+        return input_ids
+
+    def generate(
+            self,
+            inputs: Optional[torch.Tensor] = None,
+            *args,
+            **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        """
+        Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
+        """
+        only_passed_inputs_embeds = (
+                "inputs_embeds" in kwargs and
+                "input_ids" not in kwargs and
+                inputs is None
+        )
+        if only_passed_inputs_embeds:
+            input_sequence_length = kwargs["inputs_embeds"].shape[1]
+
+        generation_output = super().generate(inputs=inputs, *args, **kwargs)
+
+        if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
+            generation_output = generation_output[:, input_sequence_length:]
+
+        return generation_output
+
 
 @add_start_docstrings(
     """
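
For reference, a minimal usage sketch of the code path this patch addresses: calling generate() with only inputs_embeds and no input_ids. The checkpoint name and prompt below are placeholders, not taken from this commit.

# Sketch of generation with only inputs_embeds (the path patched above).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "path/to/decilm-checkpoint"  # placeholder, not from this repo
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

enc = tokenizer("Hello, world", return_tensors="pt")
# Build the embeddings explicitly and pass them instead of input_ids,
# so generate() only receives inputs_embeds.
inputs_embeds = model.get_input_embeddings()(enc["input_ids"])

output_ids = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=enc["attention_mask"],
    max_new_tokens=20,
)
# With the patched generate(), the zero-filled placeholder input_ids that were
# created to size the cache are sliced off again, so output_ids holds only the
# newly generated tokens.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))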