Spaces:
Running
Running
| import os | |
| import sys | |
| import time | |
| import subprocess | |
| import threading | |
| import shutil | |
| from pathlib import Path | |
| from typing import Dict, Tuple, Optional, List | |
| from faster_whisper import WhisperModel | |
| import ctranslate2 as ct2 | |
# Directory layout: everything lives next to this source file.
BASE_DIR = Path(__file__).resolve().parent
UPLOAD_DIR = BASE_DIR / "uploads"
DEFAULT_OUTPUT_DIR = BASE_DIR / "outputs"

# Make sure the working directories exist before any job touches them.
for _dir in (UPLOAD_DIR, DEFAULT_OUTPUT_DIR):
    _dir.mkdir(exist_ok=True)
def _bundle_base() -> Path:
    """Return the application root directory (the `models/` folder lives under it)."""
    return BASE_DIR
def _local_model_path(choice: str) -> Optional[Path]:
    """Resolve a quality profile to a bundled model directory, if one exists.

    Maps "fast" -> models/small, "balanced" -> models/medium,
    "best" -> models/large-v2. Returns None for an unknown profile or
    when the corresponding folder is not present on disk.
    """
    profile_dirs = {"fast": "small", "balanced": "medium", "best": "large-v2"}
    folder = profile_dirs.get((choice or "").lower().strip())
    if folder is None:
        return None
    candidate = _bundle_base() / "models" / folder
    return candidate if candidate.exists() else None
| def _detect_device() -> str: | |
| try: | |
| return "cuda" if ct2.get_cuda_device_count() > 0 else "cpu" | |
| except Exception: | |
| return "cpu" | |
| def _compute_candidates(profile: str, device: str) -> list[str]: | |
| if device == "cuda": | |
| base = {"fast": "int8_float16", "balanced": "float16", "best": "float32"} \ | |
| .get(profile, "float16") | |
| return [base, "float16", "float32"] | |
| else: | |
| base = {"fast": "int8", "balanced": "int16", "best": "float32"} \ | |
| .get(profile, "int16") | |
| return [base, "int16", "float32"] | |
# Process-wide caches: one loaded WhisperModel (and its metadata dict) per
# normalized profile key ("fast" / "balanced" / "best"), filled by get_model.
# NOTE(review): get_model mutates these without holding a lock even though
# jobs run on worker threads — confirm duplicate concurrent loads are acceptable.
_model_cache: Dict[str, WhisperModel] = {}
_model_meta: Dict[str, dict] = {}
def get_model(model_choice: str) -> Tuple[WhisperModel, dict]:
    """Return a cached (model, metadata) pair for the requested quality profile.

    Unknown profiles fall back to "balanced"/medium. The model is loaded
    from a bundled local folder when present, otherwise by hub identifier,
    trying each compute type from _compute_candidates until one succeeds.

    Raises RuntimeError when no compute type can be loaded on the detected
    device; the last underlying error is included in the message.
    """
    key = (model_choice or "fast").lower().strip()
    cached = _model_cache.get(key)
    if cached is not None:
        return cached, _model_meta[key]

    sizes = {"fast": "small", "balanced": "medium", "best": "large-v2"}
    if key in sizes:
        size = sizes[key]
    else:
        # Unrecognized choice: normalize to the middle profile.
        key, size = "balanced", "medium"

    local = _local_model_path(key)
    model_id = str(local) if local else size
    device = _detect_device()

    last_err = None
    for compute in _compute_candidates(key, device):
        print(f"[GETSUBTITLES] Loading '{model_id}' device={device} compute_type='{compute}'")
        try:
            model = WhisperModel(model_id, device=device, compute_type=compute)
        except Exception as e:
            print(f"[GETSUBTITLES] Failed compute_type={compute}: {e}")
            last_err = e
            continue
        meta = {
            "model_choice": key,
            "model_name": model_id,
            "compute_type": compute,
            "device": device,
            "source": "local" if local else "hub",
        }
        _model_cache[key] = model
        _model_meta[key] = meta
        return model, meta

    raise RuntimeError(f"Could not load model on {device} with any compute type. Last error: {last_err}")
| def _ffmpeg_path() -> str: | |
| return shutil.which("ffmpeg") or "ffmpeg" | |
def to_wav16k_mono(src: Path, dst: Path):
    """Convert *src* to a 16 kHz mono WAV at *dst* using ffmpeg.

    Raises subprocess.CalledProcessError when ffmpeg exits non-zero;
    ffmpeg's stderr is printed first to aid debugging.
    """
    cmd = [
        _ffmpeg_path(),
        "-y",            # overwrite dst if it already exists
        "-i", str(src),
        "-ac", "1",      # downmix to mono
        "-ar", "16000",  # resample to 16 kHz
        str(dst),
    ]
    try:
        subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        print(f"FFmpeg error: {e.stderr.decode()}")
        raise
def write_srt_from_segments(segments, out_path: Path):
    """Write segments (objects with .start/.end/.text) to *out_path* as SRT."""

    def fmt(total: float) -> str:
        # hh:mm:ss,mmm with milliseconds truncated (not rounded).
        hours, rem = divmod(total, 3600)
        minutes, secs = divmod(rem, 60)
        whole = int(secs)
        millis = int((secs - whole) * 1000)
        return f"{int(hours):02d}:{int(minutes):02d}:{whole:02d},{millis:03d}"

    with out_path.open("w", encoding="utf-8") as f:
        for idx, seg in enumerate(segments, 1):
            f.write(f"{idx}\n{fmt(seg.start)} --> {fmt(seg.end)}\n{seg.text.strip()}\n\n")
# Tuning knobs for "vertical" (short-form / portrait video) subtitle blocks.
VERT_MAX_CHARS_PER_LINE = 38   # hard cap on characters per subtitle line
VERT_MAX_WORDS_PER_BLOCK = 10  # hard cap on words per subtitle block
VERT_MAX_DURATION_S = 2.2      # force a break once a block spans this long
VERT_MIN_DURATION_S = 0.7      # extend a too-short block with extra words up to this
PUNCT_BREAK = {".", ",", "!", "?", "…", ":", ";", "—", "–"}  # trailing chars that invite a break
| def _fmt_time(seconds: float) -> str: | |
| h = int(seconds // 3600) | |
| m = int((seconds % 3600) // 60) | |
| s = int(seconds % 60) | |
| ms = int(round((seconds - int(seconds)) * 1000)) | |
| return f"{h:02}:{m:02}:{s:02},{ms:03}" | |
| def _clean_spaces(s: str) -> str: | |
| return (s.replace(" ,", ",") | |
| .replace(" .", ".") | |
| .replace(" !", "!") | |
| .replace(" ?", "?") | |
| .replace(" …", "…") | |
| .replace(" :", ":") | |
| .replace(" ;", ";")) | |
def _should_break(line: str, block_start: float, last_end: float, last_token: str) -> bool:
    """Decide whether the current subtitle block should end after *last_token*.

    Breaks when the block already spans VERT_MAX_DURATION_S, or when the
    line is nearly full (within 3 chars of the cap) and the last word
    ends in breakable punctuation.
    """
    if last_end is not None and block_start is not None:
        duration = last_end - block_start
    else:
        duration = 0.0
    if duration >= VERT_MAX_DURATION_S:
        return True
    nearly_full = len(line) >= VERT_MAX_CHARS_PER_LINE - 3
    ends_on_punct = bool(last_token) and last_token[-1] in PUNCT_BREAK
    return nearly_full and ends_on_punct
def build_vertical_blocks(words: List[dict]) -> List[dict]:
    """Greedily group word timestamps into short subtitle blocks.

    Each word is a dict with "text", "start", "end" (seconds). Returns a
    list of blocks with the same keys, where "text" is the joined,
    punctuation-cleaned line. Blocks respect VERT_MAX_CHARS_PER_LINE,
    VERT_MAX_WORDS_PER_BLOCK and the duration limits.
    """
    blocks = []
    i, n = 0, len(words)
    while i < n:
        line = ""
        block_start = words[i]["start"]
        block_end = words[i]["end"]
        count = 0
        j = i
        # Phase 1: greedily take words until a hard cap (chars/words) or
        # a soft break (_should_break: duration / punctuation) is hit.
        # The first word is always taken (the char check requires count > 0),
        # so `token` is guaranteed to be bound after this loop.
        while j < n:
            w = words[j]
            token = w["text"]
            # +1 for the joining space when the line is non-empty.
            extra_len = (1 if line else 0) + len(token)
            if len(line) + extra_len > VERT_MAX_CHARS_PER_LINE and count > 0:
                break
            if count >= VERT_MAX_WORDS_PER_BLOCK:
                break
            line = f"{line} {token}" if line else token
            block_end = w["end"]
            count += 1
            j += 1
            if _should_break(line, block_start, block_end, token):
                break
        # Phase 2: if the block is shorter than VERT_MIN_DURATION_S, pad it
        # with more words — but never past sentence-ending punctuation or
        # the character cap. NOTE: this can exceed VERT_MAX_WORDS_PER_BLOCK
        # by design (readability beats the word cap for very short blocks).
        while j < n:
            duration = block_end - block_start
            if duration >= VERT_MIN_DURATION_S:
                break
            # j >= i + 1 here, so words[j-1] is the last word already taken.
            last_token = words[j-1]["text"] if j-1 >= 0 else ""
            if last_token and last_token[-1] in {".", "!", "?"}:
                break
            next_token = words[j]["text"]
            extra_len = (1 if line else 0) + len(next_token)
            if len(line) + extra_len > VERT_MAX_CHARS_PER_LINE:
                break
            line = f"{line} {next_token}"
            block_end = words[j]["end"]
            count += 1
            j += 1
        if line.strip():
            blocks.append({
                "start": block_start,
                "end": block_end,
                "text": _clean_spaces(line.strip())
            })
        # Guard against an infinite loop if no word was consumed.
        i = j if j > i else i + 1
    return blocks
def write_srt_from_blocks(blocks: List[dict], out_path: Path):
    """Write blocks ({"start", "end", "text"} dicts) to *out_path* as UTF-8 SRT."""
    with out_path.open("w", encoding="utf-8") as f:
        for number, block in enumerate(blocks, start=1):
            entry = (
                f"{number}\n"
                f"{_fmt_time(block['start'])} --> {_fmt_time(block['end'])}\n"
                f"{block['text']}\n\n"
            )
            f.write(entry)
class Job:
    # Mutable state for one transcription job; created by the caller and
    # updated by the worker thread in run_transcription. Access from
    # multiple threads is expected to go through JOBS_LOCK.
    def __init__(self, job_id: str, original_name: str, out_dir: Path):
        """Initialize a queued job for *original_name*, writing into *out_dir*."""
        self.job_id = job_id
        self.original_name = original_name
        self.out_dir = out_dir
        # Progress tracking, filled in as transcription runs.
        self.duration: float = 0.0
        self.progress: float = 0.0  # 0..1
        self.status: str = "queued"  # queued|loading_model|running|done|error
        self.error_msg: Optional[str] = None
        # Model/detection metadata, set once the model is loaded.
        self.language: Optional[str] = None
        self.model_choice: Optional[str] = None
        self.model_name: Optional[str] = None
        self.compute_type: Optional[str] = None
        # Timing and final output location.
        self.started_at: float = time.time()
        self.finished_at: Optional[float] = None
        self.srt_path: Optional[Path] = None
# Registry of all jobs by id, shared between worker threads; guard every
# read/write with JOBS_LOCK. NOTE(review): presumably populated by a request
# handler outside this view — confirm all writers take the lock.
JOBS: Dict[str, Job] = {}
JOBS_LOCK = threading.Lock()
def run_transcription(job_id: str,
                      wav_path: Path,
                      language: Optional[str],
                      task: str,
                      model_choice: str,
                      out_dir: Path,
                      style: str):
    """Worker entry point: transcribe *wav_path* and write an SRT into *out_dir*.

    Updates the shared Job record as it goes (status, progress, metadata).
    *task* is passed straight to faster-whisper ("transcribe"/"translate" —
    not validated here); *language* of None lets the model auto-detect.
    *style* == "vertical" enables word timestamps and short vertical blocks;
    any other value produces plain per-segment SRT. Never raises: failures
    are recorded on the job as status "error".
    """
    # Look up the job under the lock, but do NOT hold the lock for the whole
    # run — progress updates below re-acquire it (threading.Lock is not reentrant).
    with JOBS_LOCK:
        job = JOBS[job_id]
    try:
        job.status = "loading_model"
        model, meta = get_model(model_choice)
        job.model_choice = meta["model_choice"]
        job.model_name = meta["model_name"]
        job.compute_type = meta["compute_type"]
        # Word timestamps are only needed to build vertical blocks.
        use_word_ts = (style == "vertical")
        segments, info = model.transcribe(
            str(wav_path),
            task=task,
            language=language,
            word_timestamps=use_word_ts
        )
        job.language = info.language
        job.duration = float(info.duration or 0.0)
        job.status = "running"
        # `segments` is lazy: consuming it here performs the transcription,
        # so progress can be reported from each segment's end time.
        seg_list = []
        last_end = 0.0
        for seg in segments:
            seg_list.append(seg)
            last_end = max(last_end, float(getattr(seg, "end", 0.0) or 0.0))
            if job.duration > 0:
                with JOBS_LOCK:
                    # Cap below 1.0 so only the final write reports completion.
                    job.progress = min(last_end / job.duration, 0.999)
        srt_path = out_dir / f"{job.original_name}.srt"
        if style == "vertical":
            # Flatten word-level timestamps; skip empty tokens and words
            # the model could not time.
            words = []
            for seg in seg_list:
                if getattr(seg, "words", None):
                    for w in seg.words:
                        token = (w.word or "").strip()
                        if not token or w.start is None or w.end is None:
                            continue
                        words.append({"text": token, "start": float(w.start), "end": float(w.end)})
            if words:
                blocks = build_vertical_blocks(words)
                write_srt_from_blocks(blocks, srt_path)
            else:
                # No usable word timestamps: fall back to per-segment SRT.
                write_srt_from_segments(seg_list, srt_path)
        else:
            write_srt_from_segments(seg_list, srt_path)
        with JOBS_LOCK:
            job.srt_path = srt_path
            job.progress = 1.0
            job.status = "done"
            job.finished_at = time.time()
    except Exception as e:
        # Top-level worker boundary: record the failure on the job instead
        # of letting the thread die silently.
        with JOBS_LOCK:
            job.status = "error"
            job.error_msg = str(e)
            job.finished_at = time.time()