MLSpeech committed
Commit 4fe4892 · verified · 1 Parent(s): 33b8135

Allow upload of all file types

Files changed (1)
  1. app.py +414 -414
app.py CHANGED
@@ -1,414 +1,414 @@
import argparse
import logging
from pathlib import Path
from tqdm import tqdm
import torch
import torchaudio
import soundfile as sf
import time
from typing import TypedDict
from enum import Enum
import gradio as gr

SR = 16000
VAD_EXPAND_HEAD_SEC = 0.2
VAD_EXPAND_TAIL_SEC = 0.2


class SPEECH_ARRAY_INDEX(TypedDict):
    """
    TypedDict for representing speech segments in audio.
    This dictionary contains the start and end indices of a speech segment retrieved from VAD processing.

    Args:
        start (float): Start index of the speech segment in samples.
        end (float): End index of the speech segment in samples.
    """
    start: float
    end: float


class SilenceTrimMode(Enum):
    """
    Enumeration for different silence trimming modes in audio processing.

    This enum defines various options for trimming silence from audio segments,
    allowing fine-grained control over which parts of the audio should have
    silence removed.

    Attributes:
        LEADING (str): Remove silence only from the beginning of the audio.
        TRAILING (str): Remove silence only from the end of the audio.
        EDGES (str): Remove silence from both the beginning and end of the audio.
        ALL (str): Remove all silence segments throughout the entire audio.
    """

    LEADING = "leading"
    TRAILING = "trailing"
    EDGES = "edges"
    ALL = "all"


class VAD:
    def __init__(
        self,
        sr: int,
        remove_short: bool = False,
        pad_segments: bool = True,
        expand_head_sec: float = VAD_EXPAND_HEAD_SEC,
        expand_tail_sec: float = VAD_EXPAND_TAIL_SEC,
        trim_mode: SilenceTrimMode = SilenceTrimMode.EDGES,
    ):
        """Initialize the VAD processor.

        Args:
            sr (int): Sampling rate of input audio.
            remove_short (bool): Whether to remove short speech segments. Default is False.
            pad_segments (bool): Whether to expand detected segments with padding. Default is True.
            expand_head_sec (float): Padding in seconds to add before each segment. Default is 0.2.
            expand_tail_sec (float): Padding in seconds to add after each segment. Default is 0.2.
            trim_mode (SilenceTrimMode): Mode to use for trimming silence. Default is trim silence from edges. Options are:
                - SilenceTrimMode.LEADING: Remove silence only from the beginning.
                - SilenceTrimMode.TRAILING: Remove silence only from the end.
                - SilenceTrimMode.EDGES: Remove silence from both the beginning and end.
                - SilenceTrimMode.ALL: Remove all silence segments throughout the audio.
        """
        self.sr = sr
        self.pad_segments = pad_segments
        self.remove_short = remove_short
        self.expand_head_sec = expand_head_sec
        self.expand_tail_sec = expand_tail_sec
        self.trim_mode = trim_mode
        self.min_segment_dur = 1.0

        vad_components = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            trust_repo=True,
            skip_validation=True,
        )
        self.vad_model, utils = vad_components  # type: ignore
        self._detect_speech, _, _, *_ = utils

    def _remove_short_segments(self, segments: list[SPEECH_ARRAY_INDEX]) -> list[SPEECH_ARRAY_INDEX]:
        """Remove speech segments shorter than the configured minimum duration."""
        return [s for s in segments if s["end"] - s["start"] > self.min_segment_dur * self.sr]

    def _expand_segments(
        self, segments: list[SPEECH_ARRAY_INDEX], expand_head: int, expand_tail: int, total_length: int
    ) -> list[SPEECH_ARRAY_INDEX]:
        """Expand speech segments with padding before and after, constrained by surrounding segments and total length.

        Args:
            segments (list[SPEECH_ARRAY_INDEX]): List of speech segments.
            expand_head (int): Padding to add before each segment in samples.
            expand_tail (int): Padding to add after each segment in samples.
            total_length (int): Total length of the audio in samples.

        Returns:
            list[SPEECH_ARRAY_INDEX]: Expanded list of speech segments.
        """
        results = []
        for i, t in enumerate(segments):
            start = max(t["start"] - expand_head, segments[i - 1]["end"] if i > 0 else 0)
            end = min(t["end"] + expand_tail, segments[i + 1]["start"] if i < len(segments) - 1 else total_length)
            results.append({"start": start, "end": end})
        return results

    def _postprocess_segments(
        self, segments: list[SPEECH_ARRAY_INDEX], audio_len: int
    ) -> list[SPEECH_ARRAY_INDEX]:
        """Apply filtering and padding to detected speech segments. If no segments are detected, return a default segment covering the entire audio.

        Args:
            segments (list[SPEECH_ARRAY_INDEX]): Detected speech segments.
            audio_len (int): Length of the audio signal in samples. Used to ensure segments do not exceed audio length.

        Returns:
            list[SPEECH_ARRAY_INDEX]: Postprocessed speech segments.
        """
        if self.remove_short:
            segments = self._remove_short_segments(segments)
        if self.pad_segments:
            expand_head = int(self.expand_head_sec * self.sr)
            expand_tail = int(self.expand_tail_sec * self.sr)
            segments = self._expand_segments(segments, expand_head, expand_tail, audio_len)
        return segments if segments else [{"start": 0, "end": audio_len}]

    def _trim_audio(self, audio: torch.Tensor, segments: list[SPEECH_ARRAY_INDEX]) -> torch.Tensor:
        """Trim the input audio tensor according to the configured silence trim mode.

        Args:
            audio (torch.Tensor): Input audio tensor.
            segments (list[SPEECH_ARRAY_INDEX]): Processed speech segments.

        Returns:
            torch.Tensor: Trimmed audio tensor.
        """
        if not segments:
            return audio.unsqueeze(0)

        if self.trim_mode is SilenceTrimMode.ALL:
            speech = torch.cat([audio[int(s["start"]):int(s["end"])] for s in segments])
        else:
            first_start = int(segments[0]["start"])
            last_end = int(segments[-1]["end"])
            if self.trim_mode is SilenceTrimMode.LEADING:
                speech = audio[first_start:]
            elif self.trim_mode is SilenceTrimMode.TRAILING:
                speech = audio[:last_end]
            elif self.trim_mode is SilenceTrimMode.EDGES:
                speech = audio[first_start:last_end]
            else:
                raise ValueError(f"Unsupported trim_mode: {self.trim_mode}")

        return speech.unsqueeze(0)

    def __call__(self, audio: torch.Tensor) -> torch.Tensor:
        """Apply VAD processing and silence trimming to an audio tensor.

        Args:
            audio (torch.Tensor): Audio tensor, either [samples] or [1, samples].

        Returns:
            torch.Tensor: Trimmed audio tensor with silence removed.
        """
        if audio.dim() == 2:
            audio = audio[0]

        tic = time.time()
        segments = self._detect_speech(audio, model=self.vad_model, sampling_rate=self.sr)
        segments = self._postprocess_segments(segments, len(audio))
        logging.debug(f"Detected speech in {time.time() - tic:.1f} sec")

        return self._trim_audio(audio, segments)


def preprocess_input_lst(input_lst_path: str) -> list[Path]:
    """
    Load a list of audio file paths from a text file.

    Args:
        input_lst_path (str): Path to a text file containing audio file paths, one per line.

    Returns:
        list[Path]: List of audio file paths.
    """
    with open(input_lst_path, "r") as f:
        return [Path(line.strip()) for line in f if line.strip()]


def preprocess_input_dir(input_dir: Path) -> list[Path]:
    """
    Recursively collect all .wav audio file paths from a directory.

    Args:
        input_dir (Path): Path to the base directory to search for .wav files.

    Returns:
        list[Path]: List of full paths to .wav files.
    """
    return list(input_dir.rglob("*.wav"))


def setup_logger(log_file: Path, verbose: bool = False) -> None:
    """
    Configure the logging module to write to file and stdout.

    Args:
        log_file (Path): Path to the log file.
        verbose (bool, optional): Whether to enable verbose logging. Defaults to False.
    """
    log_file.parent.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO if not verbose else logging.DEBUG,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.FileHandler(log_file, mode="w"), logging.StreamHandler()],
    )


def apply_vad(
    input_lst: list[Path],
    output_dir: Path,
    input_base_dir: str | Path | None = None,
    expand_head_sec: float = VAD_EXPAND_HEAD_SEC,
    expand_tail_sec: float = VAD_EXPAND_TAIL_SEC,
    trim_mode: SilenceTrimMode = SilenceTrimMode.EDGES,
) -> None:
    """
    Apply VAD to a list of input audio files and save the processed outputs.

    Args:
        input_lst (list[Path]): List of audio file paths to process.
        output_dir (Path): Directory to save the processed audio files.
        input_base_dir (str | Path | None, optional): If provided, preserve directory structure relative to this base.
    """
    logging.info(f"Processing {len(input_lst)} files from {input_base_dir} to {output_dir}")
    logging.info(f"Creating VAD model with sampling rate {SR} and expand head {expand_head_sec} sec")
    vad = VAD(
        sr=SR, pad_segments=True, expand_head_sec=expand_head_sec, expand_tail_sec=expand_tail_sec, trim_mode=trim_mode
    )
    for wav_file in tqdm(input_lst, desc="Applying VAD"):
        try:
            if input_base_dir is not None:
                # Keep tree hierarchy relative to base dir
                rel_path = wav_file.relative_to(input_base_dir)
                out_file = output_dir / rel_path
            else:
                # Copy to output dir as is (just the filename)
                out_file = output_dir / (wav_file.stem + "_vad" + wav_file.suffix)

            out_file.parent.mkdir(parents=True, exist_ok=True)

            audio, sr = torchaudio.load(str(wav_file))
            if sr != SR:
                audio = torchaudio.functional.resample(audio, sr, SR)
                sr = SR

            audio_vad = vad(audio)
            sf.write(out_file, audio_vad.squeeze().numpy(), sr)
            logging.debug(f"Saved: {out_file}")

        except Exception as e:
            logging.error(f"Failed to process {wav_file}: {e}")
    print(f"VAD processing complete. Processed {len(input_lst)} files. Outputs saved to {output_dir}")


def apply_vad_gradio(wav_file):
    vad = VAD(sr=SR, pad_segments=True, expand_head_sec=0.2, expand_tail_sec=0.2, trim_mode=SilenceTrimMode.EDGES)
    audio, sr = torchaudio.load(str(wav_file))
    if sr != SR:
        audio = torchaudio.functional.resample(audio, sr, SR)
        sr = SR
    audio_vad = vad(audio)
    sf.write("output.wav", audio_vad.squeeze().numpy(), sr)
    return 'output.wav'


def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments for the VAD processing script.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Apply VAD to all .wav files in a directory tree.")
    parser.add_argument(
        "--input_dir",
        type=Path,
        help="Path to input directory. Also used as the base input directory for relative paths.",
    )
    parser.add_argument("--input_lst", type=Path, help="Path to input list file with audio paths")
    parser.add_argument("--output_dir", type=Path, help="Path to output directory")
    parser.add_argument("--debug_file", type=Path, help="Optional: Path to a single file to test VAD on")
    parser.add_argument("--expand_head_sec", type=float, default=VAD_EXPAND_HEAD_SEC)
    parser.add_argument("--expand_tail_sec", type=float, default=VAD_EXPAND_TAIL_SEC)
    parser.add_argument(
        "--trim_mode",
        type=str,
        default=SilenceTrimMode.EDGES.value,
        choices=[m.value for m in SilenceTrimMode],
        help="Silence trim mode: " + ", ".join(m.value for m in SilenceTrimMode),
    )
    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
    args = parser.parse_args()

    # Validation logic
    if args.debug_file:
        # Debug mode - only debug_file is needed
        if args.input_dir or args.input_lst or args.output_dir:
            parser.error("When using --debug_file, do not provide --input_dir, --input_lst, or --output_dir.")
    else:
        # Normal mode - need output_dir and either input_dir or input_lst
        if not args.output_dir:
            parser.error("--output_dir is required when not using --debug_file.")
        if not args.input_dir and not args.input_lst:
            parser.error("Either --input_dir or --input_lst must be provided when not using --debug_file.")
    args.trim_mode = SilenceTrimMode(args.trim_mode)
    return args


def run_debug_file(
    debug_file: str,
    expand_head_sec: float = VAD_EXPAND_HEAD_SEC,
    expand_tail_sec: float = VAD_EXPAND_TAIL_SEC,
    trim_mode: SilenceTrimMode = SilenceTrimMode.EDGES,
) -> None:
    """
    Run VAD on a single debug audio file and save the result.

    Args:
        debug_file (str): Path to the debug audio file.
        expand_head_sec (float): Padding duration in seconds before each segment.
        expand_tail_sec (float): Padding duration in seconds after each segment.
    """
    debug_path = Path(debug_file).resolve()

    logging.info(f"Running VAD debug on: {debug_path}")
    audio, sr = torchaudio.load(debug_path)

    if sr != SR:
        logging.info(f"Resampling from {sr} → {SR}")
        audio = torchaudio.functional.resample(audio, sr, SR)
        sr = SR

    vad = VAD(
        sr=SR, pad_segments=True, expand_head_sec=expand_head_sec, expand_tail_sec=expand_tail_sec, trim_mode=trim_mode
    )
    audio_vad = vad(audio)

    out_path = debug_path.with_name(debug_path.stem + "_vad.wav")
    sf.write(out_path, audio_vad.squeeze().numpy(), sr)
    logging.info(f"Saved VAD output to: {out_path}")


with gr.Blocks() as demo:
    with gr.Row():
-        inputFile = gr.File(label="wav files", file_count="single", file_types=[".wav"])
+        inputFile = gr.File(label="wav files", file_count="single")
        runbtn = gr.Button("Run")
    audio = gr.Audio(label="output")
    runbtn.click(fn=apply_vad_gradio, inputs=[inputFile], outputs=audio)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
    # Optional: override args for debugging
    #import sys

    # sys.argv = [
    # "script.py",
    # "--output_dir",
    # "/mlspeech/data/eyalcohen/datasets/intelligibility/sandi2025_challenge/tts_data/debug_train_files/with_vad_head_03_tail_03",
    # "--input_lst",
    # "/mlspeech/data/eyalcohen/datasets/intelligibility/sandi2025_challenge/tts_data/with_vad/normalized/debug_train_files.txt",
    # "--expand_head_sec",
    # "0.3",
    # "--expand_tail_sec",
    # "0.3",
    # "--verbose",
    # # "--debug_file",
    # # "/mlspeech/data/eyalcohen/datasets/intelligibility/sandi2025_challenge/tts_data/no_vad/normalized/train/sla-P1/SI137O-00982-P10005-AM_FENRIR.wav",
    # ]



    #
    # args = parse_args()
    # log_file = args.output_dir / "vad_processing.log"
    # setup_logger(log_file, verbose=args.verbose)
    # if args.debug_file:
    # run_debug_file(args.debug_file, args.expand_head_sec, args.expand_tail_sec, args.trim_mode)
    # else:
    # if args.input_lst:
    # input_lst = preprocess_input_lst(args.input_lst)
    # elif args.input_dir:
    # input_lst = preprocess_input_dir(args.input_dir)
    # else:
    # raise ValueError("Either --input_lst or --input_dir must be provided.")
    # apply_vad(
    # input_lst,
    # args.output_dir,
    # args.input_dir,
    # args.expand_head_sec,
    # args.expand_tail_sec,
    # trim_mode=args.trim_mode,
    # )
 