File size: 5,935 Bytes
f1b856f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import os
import shutil
import warnings
from pathlib import Path
from typing import Literal, get_args

import numpy as np
from pydub import AudioSegment

try:
    from trackio.media.media import TrackioMedia
    from trackio.media.utils import check_ffmpeg_installed, check_path
except ImportError:
    from media.media import TrackioMedia
    from media.utils import check_ffmpeg_installed, check_path
SUPPORTED_FORMATS = ["wav", "mp3"]
AudioFormatType = Literal["wav", "mp3"]
TrackioAudioSourceType = str | Path | np.ndarray
class TrackioAudio(TrackioMedia):
    """
    Audio media that can be logged with `trackio.log`.

    Example:
    ```python
    import trackio
    import numpy as np

    # Generate a 1-second 440 Hz sine wave (mono)
    sr = 16000
    t = np.linspace(0, 1, sr, endpoint=False)
    wave = 0.2 * np.sin(2 * np.pi * 440 * t)
    audio = trackio.Audio(wave, caption="A4 sine", sample_rate=sr, format="wav")
    trackio.log({"tone": audio})

    # Stereo from numpy array (shape: samples, 2)
    stereo = np.stack([wave, wave], axis=1)
    audio = trackio.Audio(stereo, caption="Stereo", sample_rate=sr, format="mp3")
    trackio.log({"stereo": audio})

    # From an existing file
    audio = trackio.Audio("path/to/audio.wav", caption="From file")
    trackio.log({"file_audio": audio})
    ```

    Args:
        value (`str`, `Path`, or `numpy.ndarray`):
            A path to an audio file, or a numpy array.
            The array should be shaped `(samples,)` for mono or `(samples, 2)` for stereo.
            Float arrays will be peak-normalized and converted to 16-bit PCM; integer
            arrays (int8, uint8, int16, uint16, int32) will be rescaled to 16-bit PCM
            as needed. The caller's array is never modified.
        caption (`str`, *optional*):
            A string caption for the audio.
        sample_rate (`int`, *optional*):
            Sample rate in Hz. Required when `value` is a numpy array.
        format (`Literal["wav", "mp3"]`, *optional*):
            Audio format used when `value` is a numpy array. Default is "wav".

    Raises:
        ValueError: If `value` is a numpy array and `sample_rate` is missing or
            `format` is not one of the supported formats.
    """

    TYPE = "trackio.audio"

    def __init__(
        self,
        value: TrackioAudioSourceType,
        caption: str | None = None,
        sample_rate: int | None = None,
        format: AudioFormatType | None = None,
    ):
        super().__init__(value, caption)
        if isinstance(value, np.ndarray):
            if sample_rate is None:
                raise ValueError("Sample rate is required when value is an ndarray")
            if format is None:
                format = "wav"
            # Validate here rather than at save time, so the error points at the
            # call that supplied the bad format.
            if format not in SUPPORTED_FORMATS:
                raise ValueError(
                    f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}"
                )
        # For path sources these are stored as given and not used by _save_media.
        self._format = format
        self._sample_rate = sample_rate

    def _save_media(self, file_path: Path):
        """
        Write this audio to `file_path`.

        Array sources are encoded via `write_audio`; path sources are copied as-is.

        Raises:
            ValueError: If a path source does not point at an existing file.
            TypeError: If the stored value is neither an array nor a path.
        """
        if isinstance(self._value, np.ndarray):
            TrackioAudio.write_audio(
                data=self._value,
                sample_rate=self._sample_rate,
                filename=file_path,
                format=self._format,
            )
        elif isinstance(self._value, str | Path):
            if os.path.isfile(self._value):
                shutil.copy(self._value, file_path)
            else:
                raise ValueError(f"File not found: {self._value}")
        else:
            # Previously this fell through silently and no file was written.
            raise TypeError(
                f"Unsupported audio value type: {type(self._value).__name__}"
            )

    @staticmethod
    def ensure_int16_pcm(data: np.ndarray) -> np.ndarray:
        """
        Convert input audio array to contiguous int16 PCM.

        Floating inputs go through `np.nan_to_num` (NaN -> 0, +/-inf -> the
        dtype's largest finite values), are peak-normalized to 1.0, and are then
        scaled to the int16 range. Integer inputs (int8, uint8, int16, uint16,
        int32, in any byte order) are rescaled linearly onto the int16 range.
        The input array is never modified.

        Args:
            data: Samples shaped `(samples,)` or `(samples, channels)`.

        Returns:
            A C-contiguous native-endian `int16` array with the same shape as `data`.

        Raises:
            ValueError: If `data` is not 1D or 2D.
            TypeError: If `data` has an unsupported dtype.
        """
        arr = np.asarray(data)
        if arr.ndim not in (1, 2):
            raise ValueError("Audio data must be 1D (mono) or 2D ([samples, channels])")

        # Byte-swapped dtypes (e.g. '>i2') compare unequal to the native dtypes
        # used as keys below. Normalize first so they are converted, not rejected.
        if not arr.dtype.isnative:
            arr = arr.astype(arr.dtype.newbyteorder("="))

        # Intermediate math is done in int32 so the offsets/products (e.g.
        # uint8 * 257) cannot overflow the source dtype.
        converters = {
            np.dtype(np.int16): lambda a: a,
            np.dtype(np.int32): lambda a: (
                (a.astype(np.int32) // 65536).astype(np.int16, copy=False)
            ),
            np.dtype(np.uint16): lambda a: (
                (a.astype(np.int32) - 32768).astype(np.int16, copy=False)
            ),
            np.dtype(np.uint8): lambda a: (
                (a.astype(np.int32) * 257 - 32768).astype(np.int16, copy=False)
            ),
            np.dtype(np.int8): lambda a: (
                (a.astype(np.int32) * 256).astype(np.int16, copy=False)
            ),
        }
        is_float = np.issubdtype(arr.dtype, np.floating)

        # Reject unsupported dtypes before warning about a conversion that will
        # never happen.
        if not is_float and arr.dtype not in converters:
            raise TypeError(f"Unsupported audio dtype: {arr.dtype}")

        if arr.dtype != np.int16:
            warnings.warn(
                f"Converting {arr.dtype} audio to int16 PCM; pass int16 to avoid conversion.",
                stacklevel=2,
            )

        if is_float:
            # nan_to_num copies by default. Never write into the caller's array,
            # which may be shared with the caller or read-only.
            arr = np.nan_to_num(arr)
            max_abs = float(np.max(np.abs(arr))) if arr.size else 0.0
            if max_abs > 0.0:
                arr = arr / max_abs
            out = (arr * 32767.0).clip(-32768, 32767).astype(np.int16, copy=False)
        else:
            out = converters[arr.dtype](arr)
        return np.ascontiguousarray(out)

    @staticmethod
    def write_audio(
        data: np.ndarray,
        sample_rate: int,
        filename: str | Path,
        format: AudioFormatType = "wav",
    ) -> None:
        """
        Encode `data` as 16-bit PCM and write it to `filename` in `format`.

        Args:
            data: Samples shaped `(samples,)` or `(samples, channels)`; see
                `ensure_int16_pcm` for the accepted dtypes.
            sample_rate: Positive sample rate in Hz (Python or NumPy integer).
            filename: Destination path, validated with `check_path`.
            format: One of `SUPPORTED_FORMATS`. Non-wav formats require ffmpeg.

        Raises:
            ValueError: On an invalid sample rate, unsupported format, invalid
                array shape, or an array with zero channels.
            TypeError: If `data` has an unsupported dtype.
        """
        # Accept NumPy integer scalars (e.g. np.int64) as well as int. Reject bool,
        # which subclasses int but is never a meaningful rate.
        if (
            isinstance(sample_rate, bool)
            or not isinstance(sample_rate, int | np.integer)
            or sample_rate <= 0
        ):
            raise ValueError(f"Invalid sample_rate: {sample_rate}")
        if format not in SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}"
            )
        check_path(filename)
        # Check the encoder dependency before doing any conversion work.
        if format != "wav":
            check_ffmpeg_installed()
        pcm = TrackioAudio.ensure_int16_pcm(data)
        channels = 1 if pcm.ndim == 1 else pcm.shape[1]
        if channels == 0:
            raise ValueError("Audio data must have at least one channel")
        # pcm is C-contiguous (samples, channels), so tobytes() yields
        # interleaved frames.
        audio = AudioSegment(
            pcm.tobytes(),
            frame_rate=int(sample_rate),
            sample_width=2,  # int16
            channels=channels,
        )
        exported = audio.export(str(filename), format=format)
        exported.close()
|