Example code fails
#4
by
RaivisDejus
- opened
Example code fails on a regular (CPU-only) laptop
torch 2.7.1
torchaudio 2.7.1
transformers 4.53.2
You are using a model of type gigaam-rnnt to instantiate a model of type gigaam. This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):
File "/home/raivis/Code/buzz/test.py", line 19, in <module>
pred_ids = model.generate(**input_features)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/gigaam_transformers.py", line 442, in generate
encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 597, in forward
audio_signal = layer(
^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 480, in forward
x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 264, in forward
query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 38, in apply_rotary_pos_emb
return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
~~^~~~~
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 308, in _fn
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 149, in _fn
result = fn(**bound.arguments)
^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 1086, in _ref
output = prim(a, b)
^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 1693, in mul
return prims.mul(a, b)
^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_library/fake_impl.py", line 95, in meta_kernel
return fake_impl_holder.kernel(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_library/utils.py", line 32, in __call__
return self.func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/library.py", line 1383, in inner
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 622, in fake_impl
return self._abstract_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims/__init__.py", line 404, in _prim_elementwise_meta
utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims_common/__init__.py", line 779, in check_same_device
raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device cpu!
Hello. Yes, this is a known problem with transformers >= 4.50. I haven't had a chance to fix it yet. As a workaround, you can pin transformers==4.49.0.