File size: 5,935 Bytes
f1b856f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import shutil
import warnings
from pathlib import Path
from typing import Literal

import numpy as np
from pydub import AudioSegment

try:
    from trackio.media.media import TrackioMedia
    from trackio.media.utils import check_ffmpeg_installed, check_path
except ImportError:
    from media.media import TrackioMedia
    from media.utils import check_ffmpeg_installed, check_path

SUPPORTED_FORMATS = ["wav", "mp3"]
AudioFormatType = Literal["wav", "mp3"]
TrackioAudioSourceType = str | Path | np.ndarray


class TrackioAudio(TrackioMedia):
    """
    Audio media that can be logged with `trackio.log`.

    Example:
        ```python
        import trackio
        import numpy as np

        # Generate a 1-second 440 Hz sine wave (mono)
        sr = 16000
        t = np.linspace(0, 1, sr, endpoint=False)
        wave = 0.2 * np.sin(2 * np.pi * 440 * t)
        audio = trackio.Audio(wave, caption="A4 sine", sample_rate=sr, format="wav")
        trackio.log({"tone": audio})

        # Stereo from numpy array (shape: samples, 2)
        stereo = np.stack([wave, wave], axis=1)
        audio = trackio.Audio(stereo, caption="Stereo", sample_rate=sr, format="mp3")
        trackio.log({"stereo": audio})

        # From an existing file
        audio = trackio.Audio("path/to/audio.wav", caption="From file")
        trackio.log({"file_audio": audio})
        ```

    Args:
        value (`str`, `Path`, or `numpy.ndarray`):
            A path to an audio file, or a numpy array.
            The array should be shaped `(samples,)` for mono or `(samples, channels)`
            (e.g. `(samples, 2)` for stereo).
            Float arrays will be peak-normalized and converted to 16-bit PCM; signed or
            unsigned integer arrays (8 to 64 bits) are rescaled to 16-bit PCM.
        caption (`str`, *optional*):
            A string caption for the audio.
        sample_rate (`int`, *optional*):
            Sample rate in Hz. Required when `value` is a numpy array.
        format (`Literal["wav", "mp3"]`, *optional*):
            Audio format used when `value` is a numpy array. Default is "wav".
    """

    TYPE = "trackio.audio"

    def __init__(
        self,
        value: TrackioAudioSourceType,
        caption: str | None = None,
        sample_rate: int | None = None,
        format: AudioFormatType | None = None,
    ):
        super().__init__(value, caption)
        if isinstance(value, np.ndarray):
            # Raw samples carry no metadata, so the rate must be given explicitly.
            if sample_rate is None:
                raise ValueError("Sample rate is required when value is an ndarray")
            if format is None:
                format = "wav"
        # For path inputs these are stored but not used: _save_media copies the file.
        self._format = format
        self._sample_rate = sample_rate

    def _save_media(self, file_path: Path):
        """
        Write this audio to `file_path`.

        Array values are encoded with `write_audio`; path values are copied as-is.

        Raises:
            ValueError: If a path value does not point to an existing file.
            TypeError: If the stored value is neither an ndarray nor a str/Path.
        """
        if isinstance(self._value, np.ndarray):
            TrackioAudio.write_audio(
                data=self._value,
                sample_rate=self._sample_rate,
                filename=file_path,
                format=self._format,
            )
        elif isinstance(self._value, str | Path):
            if not os.path.isfile(self._value):
                raise ValueError(f"File not found: {self._value}")
            shutil.copy(self._value, file_path)
        else:
            # Fail loudly instead of silently leaving nothing at file_path.
            raise TypeError(
                f"Unsupported audio value type: {type(self._value).__name__}"
            )

    @staticmethod
    def ensure_int16_pcm(data: np.ndarray) -> np.ndarray:
        """
        Convert input audio array to contiguous int16 PCM.

        Floating inputs have NaN replaced by 0 and +/-inf by the largest finite
        float (via `np.nan_to_num`), are peak-normalized to 1.0 and scaled to int16.
        Integer inputs (signed or unsigned, 8 to 64 bits, any byte order) are
        rescaled from their native range onto the int16 range. The input array is
        never modified.

        Args:
            data: Array-like shaped `(samples,)` or `(samples, channels)`.

        Returns:
            A C-contiguous native-endian `np.int16` array with the same shape.

        Raises:
            ValueError: If `data` is not 1D or 2D.
            TypeError: If the dtype is neither floating point nor integer.
        """
        arr = np.asarray(data)
        if arr.ndim not in (1, 2):
            raise ValueError("Audio data must be 1D (mono) or 2D ([samples, channels])")

        kind = arr.dtype.kind  # "f" float, "i" signed int, "u" unsigned int
        if kind not in ("f", "i", "u"):
            raise TypeError(f"Unsupported audio dtype: {arr.dtype}")

        if arr.dtype != np.int16:
            warnings.warn(
                f"Converting {arr.dtype} audio to int16 PCM; pass int16 to avoid conversion.",
                stacklevel=2,
            )

        if kind == "f":
            # Work on a private float64 copy: never mutate the caller's array (or
            # fail on a read-only one), and avoid float16 rounding while scaling.
            work = np.nan_to_num(arr.astype(np.float64), copy=False)
            peak = float(np.max(np.abs(work))) if work.size else 0.0
            if peak > 0.0:
                work /= peak
            out = (work * 32767.0).clip(-32768, 32767).astype(np.int16)
            return np.ascontiguousarray(out)

        # Convert to native byte order so e.g. big-endian ">i2" is handled by value.
        arr = arr.astype(arr.dtype.newbyteorder("="), copy=False)
        # Positive: low-order bits to drop; negative: only for 8-bit input.
        shift = arr.dtype.itemsize * 8 - 16
        if kind == "i":
            if shift > 0:
                # Arithmetic right shift == floor division by 2**shift.
                out = arr >> arr.dtype.type(shift)
            elif shift == 0:
                out = arr
            else:
                # int8: scale -128..127 by 256.
                out = arr.astype(np.int32) * 256
        elif shift >= 0:
            # Unsigned 16+ bits: keep the top 16 bits, then re-center around zero.
            out = (arr >> arr.dtype.type(shift)).astype(np.int32) - 32768
        else:
            # uint8: map 0..255 onto the full -32768..32767 range.
            out = arr.astype(np.int32) * 257 - 32768
        return np.ascontiguousarray(out.astype(np.int16, copy=False))

    @staticmethod
    def write_audio(
        data: np.ndarray,
        sample_rate: int,
        filename: str | Path,
        format: AudioFormatType = "wav",
    ) -> None:
        """
        Encode `data` as 16-bit PCM and write it to `filename`.

        Args:
            data: Samples shaped `(samples,)` or `(samples, channels)`; converted
                with `ensure_int16_pcm`.
            sample_rate: Positive sample rate in Hz (Python or NumPy integer).
            filename: Destination path.
            format: One of `SUPPORTED_FORMATS`.

        Raises:
            ValueError: On an invalid sample rate, unsupported format, bad array
                rank, or an array with zero channels.
            TypeError: On an unsupported array dtype.
        """
        # bool is an int subclass but never a meaningful rate; NumPy integer
        # scalars are accepted and normalized to a plain int.
        if (
            isinstance(sample_rate, bool)
            or not isinstance(sample_rate, (int, np.integer))
            or sample_rate <= 0
        ):
            raise ValueError(f"Invalid sample_rate: {sample_rate}")
        sample_rate = int(sample_rate)
        if format not in SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}"
            )

        check_path(filename)

        pcm = TrackioAudio.ensure_int16_pcm(data)
        channels = 1 if pcm.ndim == 1 else pcm.shape[1]
        if channels < 1:
            raise ValueError("Audio data must have at least one channel")

        # Only non-WAV output goes through the ffmpeg availability check.
        if format != "wav":
            check_ffmpeg_installed()

        # Row-major (samples, channels) int16 bytes interleave channels frame by frame.
        audio = AudioSegment(
            pcm.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,  # int16
            channels=channels,
        )

        # export() returns a file object for the written output; close it here.
        file = audio.export(str(filename), format=format)
        file.close()