# GetSubtitlesApp / transcribe_utils.py
import time
import subprocess
import threading
import shutil
from pathlib import Path
from typing import Dict, Tuple, Optional, List
from faster_whisper import WhisperModel
import ctranslate2 as ct2
BASE_DIR = Path(__file__).resolve().parent
UPLOAD_DIR = BASE_DIR / "uploads"
DEFAULT_OUTPUT_DIR = BASE_DIR / "outputs"
UPLOAD_DIR.mkdir(exist_ok=True)
DEFAULT_OUTPUT_DIR.mkdir(exist_ok=True)
def _bundle_base() -> Path:
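    """Base directory used to resolve bundled resources (currently just BASE_DIR)."""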
return BASE_DIR
def _local_model_path(choice: str) -> Optional[Path]:
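    """Map a model choice ("fast"/"balanced"/"best") to a locally bundled model
    folder under ./models, or return None when no local copy is present."""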
mroot = _bundle_base() / "models"
key = (choice or "").lower().strip()
folder = {"fast": "small", "balanced": "medium", "best": "large-v2"}.get(key)
if not folder:
return None
p = mroot / folder
return p if p.exists() else None
def _detect_device() -> str:
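    """Return "cuda" if CTranslate2 reports at least one CUDA device, else "cpu"."""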
try:
return "cuda" if ct2.get_cuda_device_count() > 0 else "cpu"
except Exception:
return "cpu"
def _compute_candidates(profile: str, device: str) -> List[str]:
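    """Return the ordered list of CTranslate2 compute types to try for this
    profile on the given device: preferred type first, safe fallbacks after."""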
if device == "cuda":
base = {"fast": "int8_float16", "balanced": "float16", "best": "float32"} \
.get(profile, "float16")
return [base, "float16", "float32"]
else:
base = {"fast": "int8", "balanced": "int16", "best": "float32"} \
.get(profile, "int16")
return [base, "int16", "float32"]
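
# One loaded WhisperModel per choice ("fast"/"balanced"/"best"), plus metadata
# describing how each one was loaded.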
_model_cache: Dict[str, WhisperModel] = {}
_model_meta: Dict[str, dict] = {}
def get_model(model_choice: str) -> Tuple[WhisperModel, dict]:
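    """Return a (model, metadata) pair for the requested choice, loading and
    caching it on first use.

    Unknown choices fall back to "balanced". Compute types from
    _compute_candidates() are tried in order until one loads."""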
key = (model_choice or "fast").lower().strip()
if key in _model_cache:
return _model_cache[key], _model_meta[key]
    if key not in ("fast", "balanced", "best"):
        key = "balanced"
    size = {"fast": "small", "balanced": "medium", "best": "large-v2"}[key]
local = _local_model_path(key)
model_id = str(local) if local else size
device = _detect_device()
candidates = _compute_candidates(key, device)
last_err = None
for compute in candidates:
try:
print(f"[GETSUBTITLES] Loading '{model_id}' device={device} compute_type='{compute}'")
model = WhisperModel(model_id, device=device, compute_type=compute)
meta = {
"model_choice": key,
"model_name": model_id,
"compute_type": compute,
"device": device,
"source": "local" if local else "hub",
}
_model_cache[key] = model
_model_meta[key] = meta
return model, meta
except Exception as e:
print(f"[GETSUBTITLES] Failed compute_type={compute}: {e}")
last_err = e
raise RuntimeError(f"Could not load model on {device} with any compute type. Last error: {last_err}")
def _ffmpeg_path() -> str:
return shutil.which("ffmpeg") or "ffmpeg"
def to_wav16k_mono(src: Path, dst: Path):
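    """Convert any ffmpeg-readable input to a 16 kHz mono WAV suitable for Whisper."""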
cmd = [_ffmpeg_path(), "-y", "-i", str(src), "-ac", "1", "-ar", "16000", str(dst)]
try:
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except subprocess.CalledProcessError as e:
print(f"FFmpeg error: {e.stderr.decode()}")
raise
def write_srt_from_segments(segments, out_path: Path):
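    """Write a standard SRT file straight from segment-level timestamps."""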
def fmt(t):
h = int(t // 3600); t -= h*3600
m = int(t // 60); t -= m*60
s = int(t); ms = int((t - s)*1000)
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
with out_path.open("w", encoding="utf-8") as f:
for i, seg in enumerate(segments, 1):
f.write(f"{i}\n{fmt(seg.start)} --> {fmt(seg.end)}\n{seg.text.strip()}\n\n")
VERT_MAX_CHARS_PER_LINE = 38
VERT_MAX_WORDS_PER_BLOCK = 10
VERT_MAX_DURATION_S = 2.2
VERT_MIN_DURATION_S = 0.7
PUNCT_BREAK = {".", ",", "!", "?", "…", ":", ";", "—", "–"}
def _fmt_time(seconds: float) -> str:
    """Format seconds as an SRT timestamp (HH:MM:SS,mmm), rounded to the nearest ms."""
    total_ms = int(round(seconds * 1000))
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1_000)
    return f"{h:02}:{m:02}:{s:02},{ms:03}"
def _clean_spaces(s: str) -> str:
    # Word tokens are joined with spaces, which leaves a stray space before
    # punctuation; collapse it.
    for p in (",", ".", "!", "?", "…", ":", ";"):
        s = s.replace(f" {p}", p)
    return s
def _should_break(line: str, block_start: float, last_end: float, last_token: str) -> bool:
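    """End the current block if it has hit the maximum duration, or if it is
    close to the character limit and the last word ends with punctuation."""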
duration = (last_end - block_start) if (last_end is not None and block_start is not None) else 0.0
if duration >= VERT_MAX_DURATION_S:
return True
if len(line) >= VERT_MAX_CHARS_PER_LINE - 3 and last_token and last_token[-1] in PUNCT_BREAK:
return True
return False
def build_vertical_blocks(words: List[dict]) -> List[dict]:
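    """Greedily pack word timestamps into short subtitle blocks.

    Words are appended until a character, word-count or duration limit is hit;
    a block that would display for less than VERT_MIN_DURATION_S is then padded
    with following words while they still fit and no sentence boundary is crossed."""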
blocks = []
i, n = 0, len(words)
while i < n:
line = ""
block_start = words[i]["start"]
block_end = words[i]["end"]
count = 0
j = i
while j < n:
w = words[j]
token = w["text"]
extra_len = (1 if line else 0) + len(token)
if len(line) + extra_len > VERT_MAX_CHARS_PER_LINE and count > 0:
break
if count >= VERT_MAX_WORDS_PER_BLOCK:
break
line = f"{line} {token}" if line else token
block_end = w["end"]
count += 1
j += 1
if _should_break(line, block_start, block_end, token):
break
while j < n:
duration = block_end - block_start
if duration >= VERT_MIN_DURATION_S:
break
last_token = words[j-1]["text"] if j-1 >= 0 else ""
if last_token and last_token[-1] in {".", "!", "?"}:
break
next_token = words[j]["text"]
extra_len = (1 if line else 0) + len(next_token)
if len(line) + extra_len > VERT_MAX_CHARS_PER_LINE:
break
line = f"{line} {next_token}"
block_end = words[j]["end"]
count += 1
j += 1
if line.strip():
blocks.append({
"start": block_start,
"end": block_end,
"text": _clean_spaces(line.strip())
})
i = j if j > i else i + 1
return blocks
def write_srt_from_blocks(blocks: List[dict], out_path: Path):
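    """Write an SRT file from pre-built {"start", "end", "text"} blocks."""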
with out_path.open("w", encoding="utf-8") as f:
for idx, b in enumerate(blocks, start=1):
f.write(f"{idx}\n")
f.write(f"{_fmt_time(b['start'])} --> {_fmt_time(b['end'])}\n")
f.write(f"{b['text']}\n\n")
class Job:
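    """In-memory state of a single transcription job.

    Instances live in the module-level JOBS dict and are read/updated under
    JOBS_LOCK while run_transcription() makes progress."""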
def __init__(self, job_id: str, original_name: str, out_dir: Path):
self.job_id = job_id
self.original_name = original_name
self.out_dir = out_dir
self.duration: float = 0.0
self.progress: float = 0.0 # 0..1
        self.status: str = "queued"  # queued|loading_model|running|done|error
self.error_msg: Optional[str] = None
self.language: Optional[str] = None
self.model_choice: Optional[str] = None
self.model_name: Optional[str] = None
self.compute_type: Optional[str] = None
self.started_at: float = time.time()
self.finished_at: Optional[float] = None
self.srt_path: Optional[Path] = None
JOBS: Dict[str, Job] = {}
JOBS_LOCK = threading.Lock()
def run_transcription(job_id: str,
wav_path: Path,
language: Optional[str],
task: str,
model_choice: str,
out_dir: Path,
style: str):
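    """Transcribe a prepared 16 kHz WAV and write the job's .srt file.

    Streams segments from faster-whisper while updating the job's progress;
    for the "vertical" style, word timestamps are regrouped into short blocks,
    otherwise segment-level subtitles are written as-is."""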
with JOBS_LOCK:
job = JOBS[job_id]
try:
job.status = "loading_model"
model, meta = get_model(model_choice)
job.model_choice = meta["model_choice"]
job.model_name = meta["model_name"]
job.compute_type = meta["compute_type"]
use_word_ts = (style == "vertical")
segments, info = model.transcribe(
str(wav_path),
task=task,
language=language,
word_timestamps=use_word_ts
)
job.language = info.language
job.duration = float(info.duration or 0.0)
job.status = "running"
seg_list = []
last_end = 0.0
for seg in segments:
seg_list.append(seg)
last_end = max(last_end, float(getattr(seg, "end", 0.0) or 0.0))
if job.duration > 0:
with JOBS_LOCK:
job.progress = min(last_end / job.duration, 0.999)
srt_path = out_dir / f"{job.original_name}.srt"
if style == "vertical":
words = []
for seg in seg_list:
if getattr(seg, "words", None):
for w in seg.words:
token = (w.word or "").strip()
if not token or w.start is None or w.end is None:
continue
words.append({"text": token, "start": float(w.start), "end": float(w.end)})
if words:
blocks = build_vertical_blocks(words)
write_srt_from_blocks(blocks, srt_path)
else:
write_srt_from_segments(seg_list, srt_path)
else:
write_srt_from_segments(seg_list, srt_path)
with JOBS_LOCK:
job.srt_path = srt_path
job.progress = 1.0
job.status = "done"
job.finished_at = time.time()
except Exception as e:
with JOBS_LOCK:
job.status = "error"
job.error_msg = str(e)
job.finished_at = time.time()
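
# Minimal usage sketch (assumed standalone invocation; the real app drives these
# helpers from its web layer). The input path below is hypothetical.
if __name__ == "__main__":
    import uuid

    src = UPLOAD_DIR / "example.mp4"  # hypothetical input file
    job_id = uuid.uuid4().hex
    wav = UPLOAD_DIR / f"{job_id}.wav"
    to_wav16k_mono(src, wav)

    with JOBS_LOCK:
        JOBS[job_id] = Job(job_id, original_name=src.stem, out_dir=DEFAULT_OUTPUT_DIR)

    worker = threading.Thread(
        target=run_transcription,
        args=(job_id, wav, None, "transcribe", "fast", DEFAULT_OUTPUT_DIR, "vertical"),
        daemon=True,
    )
    worker.start()
    worker.join()

    job = JOBS[job_id]
    print(job.status, job.error_msg or job.srt_path)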