import os import shutil import warnings from pathlib import Path from typing import Literal import numpy as np from pydub import AudioSegment try: from trackio.media.media import TrackioMedia from trackio.media.utils import check_ffmpeg_installed, check_path except ImportError: from media.media import TrackioMedia from media.utils import check_ffmpeg_installed, check_path SUPPORTED_FORMATS = ["wav", "mp3"] AudioFormatType = Literal["wav", "mp3"] TrackioAudioSourceType = str | Path | np.ndarray class TrackioAudio(TrackioMedia): """ Initializes an Audio object. Example: ```python import trackio import numpy as np # Generate a 1-second 440 Hz sine wave (mono) sr = 16000 t = np.linspace(0, 1, sr, endpoint=False) wave = 0.2 * np.sin(2 * np.pi * 440 * t) audio = trackio.Audio(wave, caption="A4 sine", sample_rate=sr, format="wav") trackio.log({"tone": audio}) # Stereo from numpy array (shape: samples, 2) stereo = np.stack([wave, wave], axis=1) audio = trackio.Audio(stereo, caption="Stereo", sample_rate=sr, format="mp3") trackio.log({"stereo": audio}) # From an existing file audio = trackio.Audio("path/to/audio.wav", caption="From file") trackio.log({"file_audio": audio}) ``` Args: value (`str`, `Path`, or `numpy.ndarray`, *optional*): A path to an audio file, or a numpy array. The array should be shaped `(samples,)` for mono or `(samples, 2)` for stereo. Float arrays will be peak-normalized and converted to 16-bit PCM; integer arrays will be converted to 16-bit PCM as needed. caption (`str`, *optional*): A string caption for the audio. sample_rate (`int`, *optional*): Sample rate in Hz. Required when `value` is a numpy array. format (`Literal["wav", "mp3"]`, *optional*): Audio format used when `value` is a numpy array. Default is "wav". """ TYPE = "trackio.audio" def __init__( self, value: TrackioAudioSourceType, caption: str | None = None, sample_rate: int | None = None, format: AudioFormatType | None = None, ): super().__init__(value, caption) if isinstance(value, np.ndarray): if sample_rate is None: raise ValueError("Sample rate is required when value is an ndarray") if format is None: format = "wav" self._format = format self._sample_rate = sample_rate def _save_media(self, file_path: Path): if isinstance(self._value, np.ndarray): TrackioAudio.write_audio( data=self._value, sample_rate=self._sample_rate, filename=file_path, format=self._format, ) elif isinstance(self._value, str | Path): if os.path.isfile(self._value): shutil.copy(self._value, file_path) else: raise ValueError(f"File not found: {self._value}") @staticmethod def ensure_int16_pcm(data: np.ndarray) -> np.ndarray: """ Convert input audio array to contiguous int16 PCM. Peak normalization is applied to floating inputs. """ arr = np.asarray(data) if arr.ndim not in (1, 2): raise ValueError("Audio data must be 1D (mono) or 2D ([samples, channels])") if arr.dtype != np.int16: warnings.warn( f"Converting {arr.dtype} audio to int16 PCM; pass int16 to avoid conversion.", stacklevel=2, ) arr = np.nan_to_num(arr, copy=False) # Floating types: normalize to peak 1.0, then scale to int16 if np.issubdtype(arr.dtype, np.floating): max_abs = float(np.max(np.abs(arr))) if arr.size else 0.0 if max_abs > 0.0: arr = arr / max_abs out = (arr * 32767.0).clip(-32768, 32767).astype(np.int16, copy=False) return np.ascontiguousarray(out) converters: dict[np.dtype, callable] = { np.dtype(np.int16): lambda a: a, np.dtype(np.int32): lambda a: ( (a.astype(np.int32) // 65536).astype(np.int16, copy=False) ), np.dtype(np.uint16): lambda a: ( (a.astype(np.int32) - 32768).astype(np.int16, copy=False) ), np.dtype(np.uint8): lambda a: ( (a.astype(np.int32) * 257 - 32768).astype(np.int16, copy=False) ), np.dtype(np.int8): lambda a: ( (a.astype(np.int32) * 256).astype(np.int16, copy=False) ), } conv = converters.get(arr.dtype) if conv is not None: out = conv(arr) return np.ascontiguousarray(out) raise TypeError(f"Unsupported audio dtype: {arr.dtype}") @staticmethod def write_audio( data: np.ndarray, sample_rate: int, filename: str | Path, format: AudioFormatType = "wav", ) -> None: if not isinstance(sample_rate, int) or sample_rate <= 0: raise ValueError(f"Invalid sample_rate: {sample_rate}") if format not in SUPPORTED_FORMATS: raise ValueError( f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}" ) check_path(filename) pcm = TrackioAudio.ensure_int16_pcm(data) if format != "wav": check_ffmpeg_installed() channels = 1 if pcm.ndim == 1 else pcm.shape[1] audio = AudioSegment( pcm.tobytes(), frame_rate=sample_rate, sample_width=2, # int16 channels=channels, ) file = audio.export(str(filename), format=format) file.close()