Example code fails

#4
Opened by RaivisDejus

The example code fails on a regular laptop. Environment:

torch 2.7.1
torchaudio 2.7.1
transformers 4.53.2

You are using a model of type gigaam-rnnt to instantiate a model of type gigaam. This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):
  File "/home/raivis/Code/buzz/test.py", line 19, in <module>
    pred_ids = model.generate(**input_features)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/gigaam_transformers.py", line 442, in generate
    encoder_out, encoded_lengths = self.model.encoder(input_features, input_lengths)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 597, in forward
    audio_signal = layer(
                   ^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 480, in forward
    x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 264, in forward
    query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/.cache/huggingface/modules/transformers_modules/waveletdeboshir/gigaam-rnnt/da91709246e4987257758882a11026674a25fb00/encoder.py", line 38, in apply_rotary_pos_emb
    return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
            ~~^~~~~
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 308, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 149, in _fn
    result = fn(**bound.arguments)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 1086, in _ref
    output = prim(a, b)
             ^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 1693, in mul
    return prims.mul(a, b)
           ^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_library/fake_impl.py", line 95, in meta_kernel
    return fake_impl_holder.kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_library/utils.py", line 32, in __call__
    return self.func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/library.py", line 1383, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 622, in fake_impl
    return self._abstract_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims/__init__.py", line 404, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/home/raivis/Code/buzz/venv/lib/python3.12/site-packages/torch/_prims_common/__init__.py", line 779, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device cpu!

Hello. Yes, this is a known problem with transformers >= 4.50, and I haven't tried to fix it yet. As a workaround, use transformers==4.49.0.
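The version constraint from the reply can be checked before loading the model, so the meta-device error is caught up front rather than deep inside `generate`. Below is a minimal sketch using only the standard library. The helper names `major_minor`, `is_affected` and `check_installed` are hypothetical, and the 4.50 cutoff is taken from the reply rather than verified independently.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Assumption: per the reply, transformers >= 4.50 triggers the
# "Tensor on device meta" error with this model; 4.49.0 is the suggested pin.
FIRST_AFFECTED = (4, 50)


def major_minor(v: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string such as '4.53.2'."""
    m = re.match(r"(\d+)\.(\d+)", v)
    if m is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    return int(m.group(1)), int(m.group(2))


def is_affected(v: str) -> bool:
    """True if the version falls in the range reported as broken."""
    return major_minor(v) >= FIRST_AFFECTED


def check_installed() -> None:
    """Print a hint based on the locally installed transformers version."""
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        print("transformers is not installed")
        return
    if is_affected(installed):
        print(f"transformers {installed} is in the affected range; "
              "try: pip install transformers==4.49.0")
    else:
        print(f"transformers {installed} is not in the reported affected range")


if __name__ == "__main__":
    check_installed()
```

With the environment from the original report (transformers 4.53.2), `is_affected("4.53.2")` returns `True`, while the suggested `4.49.0` returns `False`. Comparing only major and minor keeps the check tolerant of patch and pre-release suffixes.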
