Update ultravox_model.py
Browse files- ultravox_model.py +4 -1
ultravox_model.py
CHANGED
|
@@ -412,7 +412,10 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
|
|
| 412 |
cls, config: UltravoxConfig
|
| 413 |
) -> "UltravoxProjector":
|
| 414 |
projector = UltravoxProjector(config)
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
| 416 |
return projector
|
| 417 |
|
| 418 |
@classmethod
|
|
|
|
| 412 |
cls, config: UltravoxConfig
|
| 413 |
) -> "UltravoxProjector":
|
| 414 |
projector = UltravoxProjector(config)
|
| 415 |
+
dtype = config.torch_dtype
|
| 416 |
+
if isinstance(dtype, str):
|
| 417 |
+
dtype = getattr(torch, dtype)
|
| 418 |
+
projector.to(dtype)
|
| 419 |
return projector
|
| 420 |
|
| 421 |
@classmethod
|