# app.py — Voice Clarity Booster with Presets, CPU/GPU-smart Dual-Stage,
# A/B alternating, Loudness Match, and a *polished Delta* (noise-only) option.
#
# New:
# - Delta Mode: Raw Difference | Spectral Residual (noise-only)
# - Delta Alignment (cross-correlation) to reduce phase/latency smear
# - Delta Gain (dB) + HPF/LPF + RMS leveling for listenable delta

import os
import tempfile
from typing import Tuple, Optional, Dict, Any

# ---- Quiet noisy deprecation warnings (optional) ----
import warnings

warnings.filterwarnings(
    "ignore",
    message=".*torchaudio._backend.list_audio_backends has been deprecated.*",
)
warnings.filterwarnings(
    "ignore",
    module=r"speechbrain\..*",
    category=UserWarning,
)

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio

# Optional LUFS matching (falls back to RMS if unavailable)
try:
    import pyloudnorm as pyln
    _HAVE_PYLN = True
except Exception:
    _HAVE_PYLN = False

# Prefer new SpeechBrain API; fall back for older versions
try:
    from speechbrain.inference import SpectralMaskEnhancement
except Exception:  # < 1.0
    from speechbrain.pretrained import SpectralMaskEnhancement  # type: ignore

try:
    from speechbrain.inference import SepformerSeparation
except Exception:
    from speechbrain.pretrained import SepformerSeparation  # type: ignore


# -----------------------------
# Environment / runtime limits
# -----------------------------
USE_GPU = torch.cuda.is_available()

# On CPU, SepFormer is extremely slow; avoid for long clips (or disable).
MAX_SEPFORMER_SEC_CPU = float(os.getenv("MAX_SEPFORMER_SEC_CPU", 12))
MAX_SEPFORMER_SEC_GPU = float(os.getenv("MAX_SEPFORMER_SEC_GPU", 180))
ALLOW_SEPFORMER_CPU = os.getenv("ALLOW_SEPFORMER_CPU", "0") == "1"

_DEVICE = "cuda" if USE_GPU else "cpu"

# Lazily-created model singletons (loaded on first use).
_ENHANCER_METRICGAN: Optional[SpectralMaskEnhancement] = None
_ENHANCER_SEPFORMER: Optional[SepformerSeparation] = None


def _get_metricgan() -> SpectralMaskEnhancement:
    """Return the cached MetricGAN+ enhancer, loading it on first call."""
    global _ENHANCER_METRICGAN
    if _ENHANCER_METRICGAN is None:
        _ENHANCER_METRICGAN = SpectralMaskEnhancement.from_hparams(
            source="speechbrain/metricgan-plus-voicebank",
            savedir="pretrained/metricgan_plus_voicebank",
            run_opts={"device": _DEVICE},
        )
    return _ENHANCER_METRICGAN


def _get_sepformer() -> SepformerSeparation:
    """Return the cached SepFormer enhancer, loading it on first call."""
    global _ENHANCER_SEPFORMER
    if _ENHANCER_SEPFORMER is None:
        _ENHANCER_SEPFORMER = SepformerSeparation.from_hparams(
            source="speechbrain/sepformer-whamr-enhancement",
            savedir="pretrained/sepformer_whamr_enh",
            run_opts={"device": _DEVICE},
        )
    return _ENHANCER_SEPFORMER


# -----------------------------
# Audio helpers
# -----------------------------
def _to_float32(wav: Any) -> np.ndarray:
    """Convert samples to float32, scaling integer PCM into [-1, 1].

    Float input is passed through (cast only). Signed integers are divided by
    the magnitude of their most negative value; unsigned integers are centred
    on their midpoint first.
    """
    arr = np.asarray(wav)
    if np.issubdtype(arr.dtype, np.integer):
        info = np.iinfo(arr.dtype)
        if info.min < 0:
            return (arr.astype(np.float32) / float(-info.min)).astype(np.float32)
        mid = (float(info.max) + 1.0) / 2.0
        return ((arr.astype(np.float32) - mid) / mid).astype(np.float32)
    return np.asarray(arr, dtype=np.float32)


def _to_mono(wav: np.ndarray) -> np.ndarray:
    """Robust mono: accepts [T], [T,C], [C,T]; treats dim<=8 as channels.

    Integer PCM input (e.g. int16) is rescaled to [-1, 1] before mixing so the
    rest of the pipeline always sees normalised float32 samples.
    """
    wav = _to_float32(wav)
    if wav.ndim == 1:
        return wav
    if wav.ndim == 2:
        t, u = wav.shape
        if 1 in (t, u):
            return wav.reshape(-1).astype(np.float32)
        if u <= 8:  # [T, C]
            return wav.mean(axis=1).astype(np.float32)
        if t <= 8:  # [C, T]
            return wav.mean(axis=0).astype(np.float32)
        return wav.mean(axis=1).astype(np.float32)
    return wav.reshape(-1).astype(np.float32)


def _sanitize(x: np.ndarray) -> np.ndarray:
    """Replace NaN/±inf with 0 and return float32."""
    return np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor:
    """Resample along the last dimension; no-op when the rates are equal."""
    if sr_in == sr_out:
        return wav
    return torchaudio.functional.resample(wav, sr_in, sr_out)


def _highpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor:
    """Apply a biquad high-pass; disabled when cutoff is None or <= 0."""
    if cutoff_hz is None or cutoff_hz <= 0:
        return wav
    return torchaudio.functional.highpass_biquad(wav, sr, cutoff_hz)


def _lowpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor:
    """Apply a biquad low-pass; disabled when cutoff is None or <= 0."""
    if cutoff_hz is None or cutoff_hz <= 0:
        return wav
    return torchaudio.functional.lowpass_biquad(wav, sr, cutoff_hz)


def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor:
    """Peaking EQ at 4.5 kHz (Q=0.707); skipped for a ~0 dB gain."""
    if abs(gain_db) < 1e-6:
        return wav
    center = 4500.0
    q = 0.707
    return torchaudio.functional.equalizer_biquad(wav, sr, center, q, gain_db)


def _limit_peak(wav: torch.Tensor, target_dbfs: float = -1.0) -> torch.Tensor:
    """Scale down (never up) so the peak is at most target_dbfs, then clamp."""
    if wav.numel() == 0:
        # torch.max raises on an empty tensor; nothing to limit anyway.
        return wav
    target_amp = 10.0 ** (target_dbfs / 20.0)
    peak = torch.max(torch.abs(wav)).item()
    if peak > 0:
        wav = wav * min(1.0, target_amp / peak)
    return torch.clamp(wav, -1.0, 1.0)


def _align_lengths(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Truncate both arrays to the shorter length."""
    n = min(len(a), len(b))
    return a[:n], b[:n]


def _rms(x: np.ndarray, eps: float = 1e-9) -> float:
    """Root-mean-square of x (eps added under the root to avoid zero)."""
    return float(np.sqrt(np.mean(x**2) + eps))


def _rms_target(x: np.ndarray, target_dbfs: float = -20.0) -> np.ndarray:
    """Scale to approx target dBFS RMS, then hard-limit peaks."""
    target_amp = 10.0 ** (target_dbfs / 20.0)
    cur = _rms(x)
    if cur > 0:
        x = x * (target_amp / cur)
    x = np.clip(x, -1.0, 1.0).astype(np.float32)
    return x


def _loudness_match_to_ref(ref: np.ndarray, cand: np.ndarray, sr: int) -> Tuple[np.ndarray, str]:
    """Match cand loudness to ref (LUFS if available, else RMS).

    Returns (adjusted_cand, human-readable description). Clips shorter than
    100 ms are returned unchanged.
    """
    if len(ref) < sr // 10 or len(cand) < sr // 10:
        return cand, "skipped (clip too short)"

    if _HAVE_PYLN:
        try:
            meter = pyln.Meter(sr)
            l_ref = meter.integrated_loudness(ref.astype(np.float64))
            l_cand = meter.integrated_loudness(cand.astype(np.float64))
            gain_db = l_ref - l_cand
            # A silent clip measures -inf LUFS; a non-finite gain would turn the
            # whole signal into inf/NaN, so use the RMS path instead.
            if np.isfinite(gain_db):
                cand_adj = cand * (10.0 ** (gain_db / 20.0))
                return cand_adj.astype(np.float32), f"LUFS matched (Δ {gain_db:+.2f} dB)"
        except Exception:
            # Best effort: any metering failure falls through to RMS matching.
            pass

    # RMS fallback
    eps = 1e-9
    rms_ref = np.sqrt(np.mean(ref**2) + eps)
    rms_cand = np.sqrt(np.mean(cand**2) + eps)
    gain = rms_ref / (rms_cand + eps)
    cand_adj = cand * gain
    gain_db = 20.0 * np.log10(gain + eps)
    return cand_adj.astype(np.float32), f"RMS matched (Δ {gain_db:+.2f} dB)"


def _make_ab_alternating(orig: np.ndarray, enh: np.ndarray, sr: int, seg_sec: float = 2.0) -> np.ndarray:
    """A/B track flips Original→Enhanced every seg_sec."""
    seg_n = max(1, int(seg_sec * sr))
    orig, enh = _align_lengths(orig, enh)
    n = len(orig)
    if n == 0:
        # np.concatenate rejects an empty list.
        return np.zeros(0, dtype=np.float32)
    # Even-numbered segments come from the original, odd ones from the enhanced.
    chunks = [
        (orig if idx % 2 == 0 else enh)[start:start + seg_n]
        for idx, start in enumerate(range(0, n, seg_n))
    ]
    return np.concatenate(chunks, axis=0).astype(np.float32)


# -----------------------------
# Alignment for delta (cross-correlation)
# -----------------------------
def _next_pow_two(n: int) -> int:
    """Smallest power of two >= n (returns 1 for n <= 1)."""
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


def _align_by_xcorr(a: np.ndarray, b: np.ndarray, max_shift: int) -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Align b to a using FFT cross-correlation. Only accept shifts within ±max_shift.
    Returns (a_aligned, b_aligned, shift) where positive shift means b lags a and
    is shifted forward; negative shift means a lagged and was shifted forward.
    """
    n = max(len(a), len(b))
    if n < 2 or min(len(a), len(b)) == 0 or max_shift <= 0:
        # Nothing to search over; just equalise the lengths.
        a2, b2 = _align_lengths(a, b)
        return a2, b2, 0

    # rfft zero-pads to nfft >= 2n-1, so the circular correlation is linear.
    nfft = _next_pow_two(2 * n - 1)
    spec_a = np.fft.rfft(a, nfft)
    spec_b = np.fft.rfft(b, nfft)
    # corr[k] = sum_m a[m + k] * b[m]
    corr = np.fft.irfft(spec_a * np.conj(spec_b), nfft)

    # lags: 0..N-1, convert so center at zero lag
    corr = np.concatenate((corr[-(n - 1):], corr[:n]))
    lags = np.arange(-(n - 1), n)

    # Limit to window
    window = np.abs(lags) <= max_shift
    lag = int(lags[window][np.argmax(corr[window])])

    if lag < 0:
        # Peak at k < 0 means a[m] ~= b[m + |k|]: b lags a -> advance b.
        d = -lag
        b_shift = np.concatenate((b[d:], np.zeros(d, dtype=np.float32)))
        a_shift = a[:len(b_shift)]
        b_shift = b_shift[:len(a_shift)]
        return a_shift, b_shift, d
    if lag > 0:
        # Peak at k > 0 means a[m + k] ~= b[m]: a lags b -> advance a.
        d = lag
        a_shift = np.concatenate((a[d:], np.zeros(d, dtype=np.float32)))
        b_shift = b[:len(a_shift)]
        a_shift = a_shift[:len(b_shift)]
        return a_shift, b_shift, -d

    # no shift
    a2, b2 = _align_lengths(a, b)
    return a2, b2, 0


# -----------------------------
# Model runners (with guards)
# -----------------------------
def _first_channel_2d(t: torch.Tensor) -> torch.Tensor:
    """Normalise a model output to a CPU float32 tensor shaped [1, T].

    Handles [T], [C, T] (keeps the first row) and 3-D outputs, which are
    assumed to be [batch, time, sources] (keeps batch 0 / source 0).
    """
    t = t.detach().cpu().to(torch.float32)
    if t.dim() == 3:
        t = t[0, :, 0]
    if t.dim() == 1:
        return t.unsqueeze(0)
    if t.dim() == 2 and t.shape[0] > 1:
        return t[:1, :]
    return t


def _remove_quietly(path: str) -> None:
    """Delete a temp file, ignoring filesystem errors (best-effort cleanup)."""
    try:
        os.remove(path)
    except OSError:
        pass


def _write_temp_wav(samples: np.ndarray, sr: int) -> str:
    """Write samples to a new temporary 16-bit PCM WAV and return its path.

    The path is reserved and the handle closed before writing; if the write
    fails the file is removed and the exception re-raised. The caller owns the
    returned file and must delete it.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        path = tmp.name
    try:
        sf.write(path, samples, sr, subtype="PCM_16")
    except Exception:
        _remove_quietly(path)
        raise
    return path


def _run_metricgan(path_16k: str) -> torch.Tensor:
    """Run MetricGAN+ on a 16 kHz WAV file; returns a CPU tensor [1, T]."""
    enh = _get_metricgan()
    with torch.no_grad():
        out = enh.enhance_file(path_16k)
    return _first_channel_2d(out)


def _run_sepformer(path_16k: str, dur_sec: float) -> Tuple[Optional[torch.Tensor], Optional[str]]:
    """Return (tensor, fallback_msg). If not safe to run, returns (None, reason).

    On success the tensor is a CPU float32 [1, T] signal at 16 kHz.
    """
    if USE_GPU:
        if dur_sec > MAX_SEPFORMER_SEC_GPU:
            return None, f"SepFormer skipped (GPU clip {dur_sec:.1f}s > {MAX_SEPFORMER_SEC_GPU:.0f}s limit)"
    else:
        if not ALLOW_SEPFORMER_CPU:
            return None, "SepFormer disabled on CPU (set ALLOW_SEPFORMER_CPU=1 to force)"
        if dur_sec > MAX_SEPFORMER_SEC_CPU:
            return None, f"SepFormer skipped (CPU clip {dur_sec:.1f}s > {MAX_SEPFORMER_SEC_CPU:.0f}s limit)"
    try:
        sep = _get_sepformer()
        # NOTE(review): the model may run at its own rate (hparams.sample_rate)
        # rather than 16 kHz; when the attribute is missing we assume 16 kHz,
        # which matches the previous behaviour.
        model_sr = int(getattr(getattr(sep, "hparams", None), "sample_rate", 16000) or 16000)
        with torch.no_grad():
            out = sep.separate_file(path=path_16k)

        if isinstance(out, torch.Tensor):
            t = _first_channel_2d(out)
        elif hasattr(out, "numpy"):
            t = _first_channel_2d(torch.from_numpy(out.numpy()))
        elif isinstance(out, (list, tuple)):
            t = torch.tensor(
                out[0] if isinstance(out[0], (np.ndarray, list)) else out,
                dtype=torch.float32,
            )
            if t.dim() == 1:
                t = t.unsqueeze(0)
        else:
            return None, "SepFormer returned unexpected format; skipped"
        return _resample_torch(t, model_sr, 16000), None
    except Exception as e:
        return None, f"SepFormer error: {e.__class__.__name__}"


def _run_dual_stage(path_16k: str, dur_sec: float) -> Tuple[Optional[torch.Tensor], Optional[str]]:
    """SepFormer → MetricGAN+. Applies same guards; returns (tensor, msg)."""
    stage1, msg = _run_sepformer(path_16k, dur_sec)
    if stage1 is None:
        return None, msg or "SepFormer unavailable"

    # Save stage1 to temp for MetricGAN
    mid_path = _write_temp_wav(stage1.squeeze(0).numpy(), 16000)
    try:
        stage2 = _run_metricgan(mid_path)
        return stage2, None
    except Exception as e:
        return None, f"MetricGAN after SepFormer failed: {e.__class__.__name__}"
    finally:
        _remove_quietly(mid_path)


# -----------------------------
# Spectral residual delta (cleaner noise-only preview)
# -----------------------------
def _delta_spectral_residual(orig: np.ndarray, enh: np.ndarray, sr: int) -> np.ndarray:
    """
    Build a noise-focused residual via STFT magnitudes:
        R_mag = ReLU(|X| - |Y|)
        use original phase for iSTFT reconstruction
    Then gentle HPF/LPF and RMS leveling for listenability.

    Clips too short for a centred STFT fall back to a plain time-domain
    difference; empty input yields an empty array.
    """
    orig, enh = _align_lengths(orig, enh)
    if len(orig) == 0:
        return np.zeros(0, dtype=np.float32)

    # Torch tensors
    x = torch.from_numpy(np.ascontiguousarray(orig, dtype=np.float32))
    y = torch.from_numpy(np.ascontiguousarray(enh, dtype=np.float32))

    n_fft = 1024
    hop = 256

    if x.numel() <= n_fft // 2:
        # center=True reflect-pads by n_fft // 2, which needs a longer signal.
        r = x - y
    else:
        win = torch.hann_window(n_fft)

        # STFTs
        X = torch.stft(x, n_fft=n_fft, hop_length=hop, window=win, return_complex=True, center=True)
        Y = torch.stft(y, n_fft=n_fft, hop_length=hop, window=win, return_complex=True, center=True)

        # Positive residual magnitudes
        R_mag = torch.relu(torch.abs(X) - torch.abs(Y))

        # Mild temporal smoothing (moving average across time)
        R_mag = torch.nn.functional.avg_pool1d(
            R_mag.unsqueeze(0), kernel_size=3, stride=1, padding=1
        ).squeeze(0)

        # Reconstruct residual with original phase
        phase = torch.angle(X)
        R_complex = torch.polar(R_mag, phase)
        r = torch.istft(R_complex, n_fft=n_fft, hop_length=hop, window=win, length=len(orig))

    # HPF/LPF + light RMS leveling for comfort
    r_t = r.unsqueeze(0)
    r_t = _highpass(r_t, sr, cutoff_hz=80.0)
    r_t = _lowpass(r_t, sr, cutoff_hz=9000.0)
    r_np = r_t.squeeze(0).numpy().astype(np.float32)
    r_np = _rms_target(r_np, target_dbfs=-24.0)
    return r_np


# -----------------------------
# Core pipeline
# -----------------------------
def _enhance_numpy_audio(
    audio: Tuple[int, np.ndarray],
    mode: str = "MetricGAN+ (denoise)",
    dry_wet: float = 1.0,  # 0..1
    presence_db: float = 0.0,
    lowcut_hz: float = 0.0,
    out_sr: Optional[int] = None,
    loudness_match: bool = True,
) -> Tuple[int, np.ndarray, str]:
    """
    Run the selected enhancement mode on (sample_rate, samples) audio.

    The input is mixed to mono, processed at 16 kHz, polished (low-cut,
    presence EQ, peak limit), resampled to the output rate, blended with the
    dry signal and optionally loudness-matched to it.

    Returns: (sr_out, enhanced, metrics_text)
    """
    sr_in, wav_np = audio
    wav_mono = _sanitize(_to_mono(wav_np))
    if wav_mono.size < 32:
        sr_out = sr_in if sr_in else 16000
        silence = np.zeros(int(sr_out * 1.0), dtype=np.float32)
        return sr_out, silence, "Input too short; returned silence."

    dry_t = torch.from_numpy(wav_mono).unsqueeze(0)  # [1, T @ sr_in]
    wav_16k = _resample_torch(dry_t, sr_in, 16000)
    dur_sec = float(wav_16k.shape[-1]) / 16000.0

    # Write temp input for model runners
    path_16k = _write_temp_wav(wav_16k.squeeze(0).numpy(), 16000)

    fallback_note = None
    try:
        if mode.startswith("MetricGAN"):
            proc = _run_metricgan(path_16k)
        elif mode.startswith("SepFormer"):
            proc, msg = _run_sepformer(path_16k, dur_sec)
            if proc is None:
                proc = wav_16k  # bypass
                fallback_note = f"[Fallback→Bypass] {msg}"
        elif mode.startswith("Dual-Stage"):
            proc, msg = _run_dual_stage(path_16k, dur_sec)
            if proc is None:
                # fall back to MetricGAN if SepFormer not possible
                try:
                    proc = _run_metricgan(path_16k)
                    fallback_note = f"[Fallback→MetricGAN+] {msg}"
                except Exception as e:
                    proc = wav_16k  # ultimate fallback: bypass
                    fallback_note = f"[Fallback→Bypass] {msg or ''} / MetricGAN error: {e.__class__.__name__}"
        else:
            # Bypass (EQ only)
            proc = wav_16k
    finally:
        _remove_quietly(path_16k)

    # Polish on processed only
    proc = _highpass(proc, 16000, lowcut_hz)
    proc = _presence_boost(proc, 16000, presence_db)
    proc = _limit_peak(proc, target_dbfs=-1.0)

    # Resample both to output rate for mixing & export
    sr_out = sr_in if (out_sr is None or out_sr <= 0) else int(out_sr)
    proc_out = _resample_torch(proc, 16000, sr_out).squeeze(0).numpy().astype(np.float32)
    dry_out = _resample_torch(dry_t, sr_in, sr_out).squeeze(0).numpy().astype(np.float32)

    # Mix dry/wet
    proc_out, dry_out = _align_lengths(proc_out, dry_out)
    dry_wet = float(np.clip(dry_wet, 0.0, 1.0))
    enhanced = proc_out * dry_wet + dry_out * (1.0 - dry_wet)

    # Loudness match
    loud_text = "off"
    if loudness_match:
        enhanced, loud_text = _loudness_match_to_ref(dry_out, enhanced, sr_out)

    enhanced = _sanitize(enhanced)

    # Metrics
    eps = 1e-9
    rms_delta_hint = np.sqrt(np.mean((dry_out - enhanced)**2) + eps)
    metrics = (
        f"Mode: {mode} | Dry/Wet: {dry_wet*100:.0f}% | Presence: {presence_db:+.1f} dB | "
        f"Low-cut: {lowcut_hz:.0f} Hz | Loudness match: {loud_text} | Device: {'GPU' if USE_GPU else 'CPU'} | "
        f"Clip @16k: {dur_sec:.2f}s"
    )
    if fallback_note:
        metrics += f"\n{fallback_note}"
    metrics += f"\nΔ (raw) RMS: {20*np.log10(rms_delta_hint+eps):+.2f} dBFS"

    return sr_out, enhanced, metrics


# -----------------------------
# Presets
# -----------------------------
PRESETS: Dict[str, Dict[str, Any]] = {
    "Ultimate Clean Voice": {
        "mode": "Dual-Stage (SepFormer → MetricGAN+)",
        "dry_wet": 0.92,
        "presence_db": 1.5,
        "lowcut_hz": 80.0,
        "loudness_match": True,
    },
    "Natural Speech": {
        "mode": "MetricGAN+ (denoise)",
        "dry_wet": 0.85,
        "presence_db": 1.0,
        "lowcut_hz": 50.0,
        "loudness_match": True,
    },
    "Podcast Studio": {
        "mode": "MetricGAN+ (denoise)",
        "dry_wet": 0.90,
        "presence_db": 2.0,
        "lowcut_hz": 75.0,
        "loudness_match": True,
    },
    "Room Dereverb": {
        "mode": "SepFormer (dereverb+denoise)",
        "dry_wet": 0.70,
        "presence_db": 0.5,
        "lowcut_hz": 60.0,
        "loudness_match": True,
    },
    "Music + Voice Safe": {
        "mode": "MetricGAN+ (denoise)",
        "dry_wet": 0.60,
        "presence_db": 0.0,
        "lowcut_hz": 40.0,
        "loudness_match": True,
    },
    "Phone Call Rescue": {
        "mode": "MetricGAN+ (denoise)",
        "dry_wet": 0.88,
        "presence_db": 2.0,
        "lowcut_hz": 100.0,
        "loudness_match": True,
    },
    "Gentle Denoise": {
        "mode": "MetricGAN+ (denoise)",
        "dry_wet": 0.65,
        "presence_db": 0.0,
        "lowcut_hz": 0.0,
        "loudness_match": True,
    },
    "Custom": {}
}


def _apply_preset(preset_name: str):
    """Return Gradio updates for (mode, dry/wet %, presence, low-cut, loudness).

    Unknown presets and "Custom" leave all five controls untouched.
    """
    cfg = PRESETS.get(preset_name, {})

    def upd(val=None):
        return gr.update(value=val) if val is not None else gr.update()

    if not cfg or preset_name == "Custom":
        return upd(), upd(), upd(), upd(), upd()
    return (
        upd(cfg["mode"]),
        upd(int(round(cfg["dry_wet"] * 100))),
        upd(float(cfg["presence_db"])),
        upd(float(cfg["lowcut_hz"])),
        upd(bool(cfg["loudness_match"])),
    )


# -----------------------------
# Gradio UI
# -----------------------------
def gradio_enhance(
    audio: Tuple[int, np.ndarray],
    mode: str,
    dry_wet_pct: float,
    presence_db: float,
    lowcut_hz: float,
    output_sr: str,
    loudness_match: bool,
    delta_mode: str,
    delta_align: bool,
    delta_gain_db: float,
):
    """Gradio callback: enhance the clip and build the A/B and delta previews.

    Returns ((sr, enhanced), (sr, ab_alternating), (sr, delta), metrics_text),
    or (None, None, None, message) when no audio was supplied.
    """
    if audio is None:
        return None, None, None, "No audio provided."

    out_sr = None
    if output_sr in {"44100", "48000"}:
        out_sr = int(output_sr)

    # Enhance
    sr_out, enhanced, metrics = _enhance_numpy_audio(
        audio,
        mode=mode,
        dry_wet=dry_wet_pct / 100.0,
        presence_db=float(presence_db),
        lowcut_hz=float(lowcut_hz),
        out_sr=out_sr,
        loudness_match=bool(loudness_match),
    )

    # Build A/B and Delta (polished)
    sr_in, wav_np = audio
    orig_mono = _sanitize(_to_mono(wav_np))
    orig_at_out = _resample_torch(
        torch.from_numpy(orig_mono).unsqueeze(0), sr_in, sr_out
    ).squeeze(0).numpy().astype(np.float32)

    # Optional alignment to reduce phase/latency offsets
    a_for_ab, b_for_ab = _align_lengths(orig_at_out, enhanced)
    if delta_align:
        max_shift = int(0.05 * sr_out)  # up to 50 ms
        a_for_ab, b_for_ab, lag = _align_by_xcorr(a_for_ab, b_for_ab, max_shift=max_shift)
        metrics += f"\nDelta alignment: shift={lag} samples"

    # A/B alternating
    ab_alt = _make_ab_alternating(a_for_ab, b_for_ab, sr_out, seg_sec=2.0)

    # Delta (noise-focused if selected)
    if delta_mode.startswith("Spectral"):
        delta = _delta_spectral_residual(a_for_ab, b_for_ab, sr_out)
    else:
        delta = a_for_ab - b_for_ab
        # Gentle polish on raw difference
        d_t = torch.from_numpy(delta).unsqueeze(0)
        d_t = _highpass(d_t, sr_out, cutoff_hz=80.0)
        d_t = _lowpass(d_t, sr_out, cutoff_hz=9000.0)
        delta = d_t.squeeze(0).numpy().astype(np.float32)
        delta = _rms_target(delta, target_dbfs=-24.0)

    # Apply user delta gain
    delta *= 10.0 ** (delta_gain_db / 20.0)
    delta = np.clip(delta, -1.0, 1.0).astype(np.float32)

    return (sr_out, enhanced), (sr_out, ab_alt), (sr_out, delta), metrics


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"## Voice Clarity Booster — Presets, A/B, *Polished Delta*, Loudness Match \n"
        f"**Device:** {'GPU' if USE_GPU else 'CPU'} · "
        f"SepFormer limits — CPU≤{MAX_SEPFORMER_SEC_CPU:.0f}s, GPU≤{MAX_SEPFORMER_SEC_GPU:.0f}s"
        + ("" if USE_GPU or ALLOW_SEPFORMER_CPU else " · (SepFormer disabled on CPU)")
    )

    with gr.Row():
        with gr.Column(scale=1):
            in_audio = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Input",
            )
            preset = gr.Dropdown(
                choices=list(PRESETS.keys()),
                value="Ultimate Clean Voice",
                label="Preset",
            )
            mode = gr.Radio(
                choices=[
                    "MetricGAN+ (denoise)",
                    "SepFormer (dereverb+denoise)",
                    "Dual-Stage (SepFormer → MetricGAN+)",
                    "Bypass (EQ only)"
                ],
                value="Dual-Stage (SepFormer → MetricGAN+)",
                label="Mode",
            )
            dry_wet = gr.Slider(
                minimum=0, maximum=100, value=92, step=1,
                label="Dry/Wet Mix (%) — lower to reduce artifacts"
            )
            presence = gr.Slider(
                minimum=-12, maximum=12, value=1.5, step=0.5,
                label="Presence Boost (dB)"
            )
            lowcut = gr.Slider(
                minimum=0, maximum=200, value=80, step=5,
                label="Low-Cut (Hz)"
            )
            loudmatch = gr.Checkbox(value=True, label="Loudness-match enhanced to original")
            out_sr = gr.Radio(
                choices=["Original", "44100", "48000"],
                value="Original",
                label="Output Sample Rate",
            )

            # Delta controls
            gr.Markdown("### Delta (what changed)")
            delta_mode = gr.Dropdown(
                choices=["Spectral Residual (noise-only)", "Raw Difference"],
                value="Spectral Residual (noise-only)",
                label="Delta Mode",
            )
            delta_align = gr.Checkbox(value=True, label="Align original & enhanced for delta (recommended)")
            delta_gain = gr.Slider(minimum=-12, maximum=24, value=6, step=1, label="Delta Gain (dB)")

            preset.change(
                _apply_preset,
                inputs=[preset],
                outputs=[mode, dry_wet, presence, lowcut, loudmatch],
            )

            btn = gr.Button("Enhance", variant="primary")

        with gr.Column(scale=1):
            out_audio = gr.Audio(type="numpy", label="Enhanced (autoplay)", autoplay=True)
            ab_audio = gr.Audio(type="numpy", label="A/B Alternating (2s O → 2s E)")
            delta_audio = gr.Audio(type="numpy", label="Delta (polished)")
            metrics = gr.Markdown("")

    btn.click(
        gradio_enhance,
        inputs=[in_audio, mode, dry_wet, presence, lowcut, out_sr, loudmatch, delta_mode, delta_align, delta_gain],
        outputs=[out_audio, ab_audio, delta_audio, metrics],
    )

# Launch unguarded so Spaces initializes
demo.launch()