Umadevi0305 committed on
Commit
a6ad90f
·
verified ·
1 Parent(s): 480bf04

Create app.py

Files changed (1)
  1. app.py +1867 -0
app.py ADDED
@@ -0,0 +1,1867 @@
+
+ import gc
+ import json
+ import os
+ import platform
+ import queue
+ import random
+ import re
+ import shutil
+ import signal
+ import subprocess
+ import sys
+ import tempfile
+ import threading
+ import time
+ from glob import glob
+ from importlib.resources import files
+
+ import click
+ import gradio as gr
+ import librosa
+ import numpy as np
+ import psutil
+ import torch
+ import torchaudio
+ from cached_path import cached_path
+ from datasets import Dataset as Dataset_
+ from datasets.arrow_writer import ArrowWriter
+ from safetensors.torch import load_file, save_file
+ from scipy.io import wavfile
+
+ from f5_tts.api import F5TTS
+ from f5_tts.infer.utils_infer import transcribe
+ from f5_tts.model.utils import convert_char_to_pinyin
+
+
+ training_process = None
+ system = platform.system()
+ python_executable = sys.executable or "python"
+ tts_api = None
+ last_checkpoint = ""
+ last_device = ""
+ last_ema = None
+
+
+ path_data = str(files("f5_tts").joinpath("../../data"))
+ path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
+ file_train = str(files("f5_tts").joinpath("train/finetune_cli.py"))
+
+ device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "xpu"
+     if torch.xpu.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+
+
+ # Save settings to a JSON file
+ def save_settings(
+     project_name,
+     exp_name,
+     learning_rate,
+     batch_size_per_gpu,
+     batch_size_type,
+     max_samples,
+     grad_accumulation_steps,
+     max_grad_norm,
+     epochs,
+     num_warmup_updates,
+     save_per_updates,
+     keep_last_n_checkpoints,
+     last_per_updates,
+     finetune,
+     file_checkpoint_train,
+     tokenizer_type,
+     tokenizer_file,
+     mixed_precision,
+     logger,
+     ch_8bit_adam,
+ ):
+     path_project = os.path.join(path_project_ckpts, project_name)
+     os.makedirs(path_project, exist_ok=True)
+     file_setting = os.path.join(path_project, "setting.json")
+
+     settings = {
+         "exp_name": exp_name,
+         "learning_rate": learning_rate,
+         "batch_size_per_gpu": batch_size_per_gpu,
+         "batch_size_type": batch_size_type,
+         "max_samples": max_samples,
+         "grad_accumulation_steps": grad_accumulation_steps,
+         "max_grad_norm": max_grad_norm,
+         "epochs": epochs,
+         "num_warmup_updates": num_warmup_updates,
+         "save_per_updates": save_per_updates,
+         "keep_last_n_checkpoints": keep_last_n_checkpoints,
+         "last_per_updates": last_per_updates,
+         "finetune": finetune,
+         "file_checkpoint_train": file_checkpoint_train,
+         "tokenizer_type": tokenizer_type,
+         "tokenizer_file": tokenizer_file,
+         "mixed_precision": mixed_precision,
+         "logger": logger,
+         "bnb_optimizer": ch_8bit_adam,
+     }
+     with open(file_setting, "w") as f:
+         json.dump(settings, f, indent=4)
+     return "Settings saved!"
+
+
+ # Load settings from a JSON file
+ def load_settings(project_name):
+     project_name = project_name.replace("_pinyin", "").replace("_char", "")
+     path_project = os.path.join(path_project_ckpts, project_name)
+     file_setting = os.path.join(path_project, "setting.json")
+
+     # Default settings
+     default_settings = {
+         "exp_name": "F5TTS_v1_Base",
+         "learning_rate": 1e-5,
+         "batch_size_per_gpu": 3200,
+         "batch_size_type": "frame",
+         "max_samples": 64,
+         "grad_accumulation_steps": 1,
+         "max_grad_norm": 1.0,
+         "epochs": 100,
+         "num_warmup_updates": 100,
+         "save_per_updates": 500,
+         "keep_last_n_checkpoints": -1,
+         "last_per_updates": 100,
+         "finetune": True,
+         "file_checkpoint_train": "",
+         "tokenizer_type": "pinyin",
+         "tokenizer_file": "",
+         "mixed_precision": "fp16",
+         "logger": "none",
+         "bnb_optimizer": False,
+     }
+     if device == "mps":
+         default_settings["mixed_precision"] = "none"
+
+     # Load settings from the file if it exists
+     if os.path.isfile(file_setting):
+         with open(file_setting, "r") as f:
+             file_settings = json.load(f)
+         default_settings.update(file_settings)
+
+     # Return as a tuple in the correct order
+     return (
+         default_settings["exp_name"],
+         default_settings["learning_rate"],
+         default_settings["batch_size_per_gpu"],
+         default_settings["batch_size_type"],
+         default_settings["max_samples"],
+         default_settings["grad_accumulation_steps"],
+         default_settings["max_grad_norm"],
+         default_settings["epochs"],
+         default_settings["num_warmup_updates"],
+         default_settings["save_per_updates"],
+         default_settings["keep_last_n_checkpoints"],
+         default_settings["last_per_updates"],
+         default_settings["finetune"],
+         default_settings["file_checkpoint_train"],
+         default_settings["tokenizer_type"],
+         default_settings["tokenizer_file"],
+         default_settings["mixed_precision"],
+         default_settings["logger"],
+         default_settings["bnb_optimizer"],
+     )
+
+
+ def get_audio_duration(audio_path):
+     """Calculate the duration of an audio file in seconds."""
+     audio, sample_rate = torchaudio.load(audio_path)
+     return audio.shape[1] / sample_rate
+
+
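+ # Audio slicer adapted from GPT-SoVITS: detects silence via frame-level RMS below a
+ # dB threshold and cuts the waveform at the quietest frame of each silent region.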
+ class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
+     def __init__(
+         self,
+         sr: int,
+         threshold: float = -40.0,
+         min_length: int = 20000,  # 20 seconds
+         min_interval: int = 300,
+         hop_size: int = 20,
+         max_sil_kept: int = 2000,
+     ):
+         if not min_length >= min_interval >= hop_size:
+             raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
+         if not max_sil_kept >= hop_size:
+             raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
+         min_interval = sr * min_interval / 1000
+         self.threshold = 10 ** (threshold / 20.0)
+         self.hop_size = round(sr * hop_size / 1000)
+         self.win_size = min(round(min_interval), 4 * self.hop_size)
+         self.min_length = round(sr * min_length / 1000 / self.hop_size)
+         self.min_interval = round(min_interval / self.hop_size)
+         self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
+
+     def _apply_slice(self, waveform, begin, end):
+         if len(waveform.shape) > 1:
+             return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
+         else:
+             return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
+
+     # @timeit
+     def slice(self, waveform):
+         if len(waveform.shape) > 1:
+             samples = waveform.mean(axis=0)
+         else:
+             samples = waveform
+         if samples.shape[0] <= self.min_length:
+             # keep the [chunk, start, end] shape consistent with the sliced results below
+             return [[waveform, 0, samples.shape[0]]]
+         rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
+         sil_tags = []
+         silence_start = None
+         clip_start = 0
+         for i, rms in enumerate(rms_list):
+             # Keep looping while frame is silent.
+             if rms < self.threshold:
+                 # Record start of silent frames.
+                 if silence_start is None:
+                     silence_start = i
+                 continue
+             # Keep looping while frame is not silent and silence start has not been recorded.
+             if silence_start is None:
+                 continue
+             # Clear recorded silence start if interval is not enough or clip is too short
+             is_leading_silence = silence_start == 0 and i > self.max_sil_kept
+             need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
+             if not is_leading_silence and not need_slice_middle:
+                 silence_start = None
+                 continue
+             # Need slicing. Record the range of silent frames to be removed.
+             if i - silence_start <= self.max_sil_kept:
+                 pos = rms_list[silence_start : i + 1].argmin() + silence_start
+                 if silence_start == 0:
+                     sil_tags.append((0, pos))
+                 else:
+                     sil_tags.append((pos, pos))
+                 clip_start = pos
+             elif i - silence_start <= self.max_sil_kept * 2:
+                 pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
+                 pos += i - self.max_sil_kept
+                 pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
+                 pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
+                 if silence_start == 0:
+                     sil_tags.append((0, pos_r))
+                     clip_start = pos_r
+                 else:
+                     sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
+                     clip_start = max(pos_r, pos)
+             else:
+                 pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
+                 pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
+                 if silence_start == 0:
+                     sil_tags.append((0, pos_r))
+                 else:
+                     sil_tags.append((pos_l, pos_r))
+                 clip_start = pos_r
+             silence_start = None
+         # Deal with trailing silence.
+         total_frames = rms_list.shape[0]
+         if silence_start is not None and total_frames - silence_start >= self.min_interval:
+             silence_end = min(total_frames, silence_start + self.max_sil_kept)
+             pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
+             sil_tags.append((pos, total_frames + 1))
+         # Apply and return slices: [chunk, start, end]
+         if len(sil_tags) == 0:
+             return [[waveform, 0, int(total_frames * self.hop_size)]]
+         else:
+             chunks = []
+             if sil_tags[0][0] > 0:
+                 chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
+             for i in range(len(sil_tags) - 1):
+                 chunks.append(
+                     [
+                         self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
+                         int(sil_tags[i][1] * self.hop_size),
+                         int(sil_tags[i + 1][0] * self.hop_size),
+                     ]
+                 )
+             if sil_tags[-1][1] < total_frames:
+                 chunks.append(
+                     [
+                         self._apply_slice(waveform, sil_tags[-1][1], total_frames),
+                         int(sil_tags[-1][1] * self.hop_size),
+                         int(total_frames * self.hop_size),
+                     ]
+                 )
+             return chunks
+
+
+ # Process termination helpers
+ def terminate_process_tree(pid, including_parent=True):
+     try:
+         parent = psutil.Process(pid)
+     except psutil.NoSuchProcess:
+         # Process already terminated
+         return
+
+     children = parent.children(recursive=True)
+     for child in children:
+         try:
+             os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
+         except OSError:
+             pass
+     if including_parent:
+         try:
+             os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
+         except OSError:
+             pass
+
+
+ def terminate_process(pid):
+     if system == "Windows":
+         cmd = f"taskkill /t /f /pid {pid}"
+         os.system(cmd)
+     else:
+         terminate_process_tree(pid)
+
+
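+ # Generator driving the Train Model tab: builds the `accelerate launch` command for
+ # finetune_cli.py, persists the settings, then either blocks until training exits or
+ # (stream mode) tails stdout/stderr via reader threads, yielding progress messages
+ # parsed from the tqdm output together with start/stop button updates.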
+ def start_training(
+     dataset_name,
+     exp_name,
+     learning_rate,
+     batch_size_per_gpu,
+     batch_size_type,
+     max_samples,
+     grad_accumulation_steps,
+     max_grad_norm,
+     epochs,
+     num_warmup_updates,
+     save_per_updates,
+     keep_last_n_checkpoints,
+     last_per_updates,
+     finetune,
+     file_checkpoint_train,
+     tokenizer_type,
+     tokenizer_file,
+     mixed_precision,
+     stream,
+     logger,
+     ch_8bit_adam,
+ ):
+     global training_process, tts_api, stop_signal
+
+     if tts_api is not None:
+         del tts_api
+         gc.collect()
+         torch.cuda.empty_cache()
+         tts_api = None
+
+     path_project = os.path.join(path_data, dataset_name)
+
+     if not os.path.isdir(path_project):
+         yield (
+             f"There is no project with the name {dataset_name}",
+             gr.update(interactive=True),
+             gr.update(interactive=False),
+         )
+         return
+
+     file_raw = os.path.join(path_project, "raw.arrow")
+     if not os.path.isfile(file_raw):
+         yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
+         return
+
+     # Check if a training process is already running
+     if training_process is not None:
+         yield "Training is already running!", gr.update(interactive=False), gr.update(interactive=True)
+         return
+
+     yield "Starting training ...", gr.update(interactive=False), gr.update(interactive=False)
+
+     # Command to run the training script with the specified arguments
+
+     if tokenizer_file == "":
+         if dataset_name.endswith("_pinyin"):
+             tokenizer_type = "pinyin"
+         elif dataset_name.endswith("_char"):
+             tokenizer_type = "char"
+     else:
+         tokenizer_type = "custom"
+
+     dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")
+
+     if mixed_precision != "none":
+         fp16 = f"--mixed_precision={mixed_precision}"
+     else:
+         fp16 = ""
+
+     cmd = (
+         f'accelerate launch {fp16} "{file_train}" --exp_name {exp_name}'
+         f" --learning_rate {learning_rate}"
+         f" --batch_size_per_gpu {batch_size_per_gpu}"
+         f" --batch_size_type {batch_size_type}"
+         f" --max_samples {max_samples}"
+         f" --grad_accumulation_steps {grad_accumulation_steps}"
+         f" --max_grad_norm {max_grad_norm}"
+         f" --epochs {epochs}"
+         f" --num_warmup_updates {num_warmup_updates}"
+         f" --save_per_updates {save_per_updates}"
+         f" --keep_last_n_checkpoints {keep_last_n_checkpoints}"
+         f" --last_per_updates {last_per_updates}"
+         f" --dataset_name {dataset_name}"
+     )
+
+     if finetune:
+         cmd += " --finetune"
+
+     if file_checkpoint_train != "":
+         cmd += f' --pretrain "{file_checkpoint_train}"'
+
+     if tokenizer_file != "":
+         cmd += f" --tokenizer_path {tokenizer_file}"
+
+     cmd += f" --tokenizer {tokenizer_type}"
+
+     if logger != "none":
+         cmd += f" --logger {logger}"
+
+     cmd += " --log_samples"
+
+     if ch_8bit_adam:
+         cmd += " --bnb_optimizer"
+
+     print("run command:\n" + cmd + "\n")
+
+     save_settings(
+         dataset_name,
+         exp_name,
+         learning_rate,
+         batch_size_per_gpu,
+         batch_size_type,
+         max_samples,
+         grad_accumulation_steps,
+         max_grad_norm,
+         epochs,
+         num_warmup_updates,
+         save_per_updates,
+         keep_last_n_checkpoints,
+         last_per_updates,
+         finetune,
+         file_checkpoint_train,
+         tokenizer_type,
+         tokenizer_file,
+         mixed_precision,
+         logger,
+         ch_8bit_adam,
+     )
+
+     try:
+         if not stream:
+             # Start the training process
+             training_process = subprocess.Popen(cmd, shell=True)
+
+             time.sleep(5)
+             yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)
+
+             # Wait for the training process to finish
+             training_process.wait()
+         else:
+
+             def stream_output(pipe, output_queue):
+                 try:
+                     for line in iter(pipe.readline, ""):
+                         output_queue.put(line)
+                 except Exception as e:
+                     output_queue.put(f"Error reading pipe: {str(e)}")
+                 finally:
+                     pipe.close()
+
+             env = os.environ.copy()
+             env["PYTHONUNBUFFERED"] = "1"
+
+             training_process = subprocess.Popen(
+                 cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
+             )
+             yield "Training started ...", gr.update(interactive=False), gr.update(interactive=True)
+
+             stdout_queue = queue.Queue()
+             stderr_queue = queue.Queue()
+
+             stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue))
+             stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue))
+             stdout_thread.daemon = True
+             stderr_thread.daemon = True
+             stdout_thread.start()
+             stderr_thread.start()
+             stop_signal = False
+             while True:
+                 if stop_signal:
+                     training_process.terminate()
+                     time.sleep(0.5)
+                     if training_process.poll() is None:
+                         training_process.kill()
+                     yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False)
+                     break
+
+                 process_status = training_process.poll()
+
+                 # Handle stdout
+                 try:
+                     while True:
+                         output = stdout_queue.get_nowait()
+                         print(output, end="")
+                         match = re.search(
+                             r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), update=(\d+)", output
+                         )
+                         if match:
+                             current_epoch = match.group(1)
+                             total_epochs = match.group(2)
+                             percent_complete = match.group(3)
+                             elapsed_time = match.group(4)
+                             loss = match.group(5)
+                             current_update = match.group(6)
+                             message = (
+                                 f"Epoch: {current_epoch}/{total_epochs}, "
+                                 f"Progress: {percent_complete}%, "
+                                 f"Elapsed Time: {elapsed_time}, "
+                                 f"Loss: {loss}, "
+                                 f"Update: {current_update}"
+                             )
+                             yield message, gr.update(interactive=False), gr.update(interactive=True)
+                         elif output.strip():
+                             yield output, gr.update(interactive=False), gr.update(interactive=True)
+                 except queue.Empty:
+                     pass
+
+                 # Handle stderr
+                 try:
+                     while True:
+                         error_output = stderr_queue.get_nowait()
+                         print(error_output, end="")
+                         if error_output.strip():
+                             yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True)
+                 except queue.Empty:
+                     pass
+
+                 if process_status is not None and stdout_queue.empty() and stderr_queue.empty():
+                     if process_status != 0:
+                         yield (
+                             f"Process crashed with exit code {process_status}!",
+                             gr.update(interactive=False),
+                             gr.update(interactive=True),
+                         )
+                     else:
+                         yield (
+                             "Training complete or paused ...",
+                             gr.update(interactive=False),
+                             gr.update(interactive=True),
+                         )
+                     break
+
+                 # Small sleep to prevent CPU thrashing
+                 time.sleep(0.1)
+
+             # Clean up
+             training_process.stdout.close()
+             training_process.stderr.close()
+             training_process.wait()
+
+         time.sleep(1)
+
+         if training_process is None:
+             text_info = "Training stopped!"
+         else:
+             text_info = "Training completed!"
+
+     except Exception as e:  # Catch all exceptions
+         # Ensure that we reset the training process variable in case of an error
+         text_info = f"An error occurred: {str(e)}"
+
+     training_process = None
+
+     yield text_info, gr.update(interactive=True), gr.update(interactive=False)
+
+
+ def stop_training():
+     global training_process, stop_signal
+
+     if training_process is None:
+         return "Training is not running!", gr.update(interactive=True), gr.update(interactive=False)
+     terminate_process_tree(training_process.pid)
+     # training_process = None
+     stop_signal = True
+     return "Training stopped!", gr.update(interactive=True), gr.update(interactive=False)
+
+
+ def get_list_projects():
+     project_list = []
+     for folder in os.listdir(path_data):
+         path_folder = os.path.join(path_data, folder)
+         if not os.path.isdir(path_folder):
+             continue
+         folder = folder.lower()
+         if folder == "emilia_zh_en_pinyin":
+             continue
+         project_list.append(folder)
+
+     projects_selelect = None if not project_list else project_list[-1]
+
+     return project_list, projects_selelect
+
+
+ def create_data_project(name, tokenizer_type):
+     name += "_" + tokenizer_type
+     os.makedirs(os.path.join(path_data, name), exist_ok=True)
+     os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
+     project_list, projects_selelect = get_list_projects()
+     return gr.update(choices=project_list, value=name)
+
+
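+ # Dataset creation pipeline: load each audio file at 24 kHz mono, slice it on
+ # silence, loudness-normalize each segment, write the segments to wavs/, transcribe
+ # them, and collect the "name|text" lines into metadata.csv.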
+ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
+     path_project = os.path.join(path_data, name_project)
+     path_dataset = os.path.join(path_project, "dataset")
+     path_project_wavs = os.path.join(path_project, "wavs")
+     file_metadata = os.path.join(path_project, "metadata.csv")
+
+     if not user:
+         if audio_files is None:
+             return "You need to load an audio file."
+
+     if os.path.isdir(path_project_wavs):
+         shutil.rmtree(path_project_wavs)
+
+     if os.path.isfile(file_metadata):
+         os.remove(file_metadata)
+
+     os.makedirs(path_project_wavs, exist_ok=True)
+
+     if user:
+         file_audios = [
+             file
+             for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
+             for file in glob(os.path.join(path_dataset, format))
+         ]
+         if not file_audios:
+             return "No audio file was found in the dataset."
+     else:
+         file_audios = audio_files
+
+     alpha = 0.5
+     _max = 1.0
+     slicer = Slicer(24000)
+
+     num = 0
+     error_num = 0
+     data = ""
+     for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len(file_audios)):
+         audio, _ = librosa.load(file_audio, sr=24000, mono=True)
+
+         list_slicer = slicer.slice(audio)
+         for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
+             name_segment = f"segment_{num}"
+             file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
+
+             tmp_max = np.abs(chunk).max()
+             if tmp_max > 1:
+                 chunk /= tmp_max
+             chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
+             wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
+
+             try:
+                 text = transcribe(file_segment, language)
+                 text = text.strip()
+
+                 data += f"{name_segment}|{text}\n"
+
+                 num += 1
+             except Exception:
+                 error_num += 1
+
+     with open(file_metadata, "w", encoding="utf-8-sig") as f:
+         f.write(data)
+
+     if error_num > 0:
+         error_text = f"\nfiles with errors: {error_num}"
+     else:
+         error_text = ""
+
+     return f"Transcription complete. Samples: {num}\npath: {path_project_wavs}{error_text}"
+
+
+ def format_seconds_to_hms(seconds):
+     hours = int(seconds / 3600)
+     minutes = int((seconds % 3600) / 60)
+     seconds = seconds % 60
+     return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
+
+
+ def get_correct_audio_path(
+     audio_input,
+     base_path="wavs",
+     supported_formats=("wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"),
+ ):
+     file_audio = None
+
+     # Helper function to check if a file has a supported extension
+     def has_supported_extension(file_name):
+         return any(file_name.endswith(f".{ext}") for ext in supported_formats)
+
+     # Case 1: If it's a full path with a valid extension, use it directly
+     if os.path.isabs(audio_input) and has_supported_extension(audio_input):
+         file_audio = audio_input
+
+     # Case 2: If it has a supported extension but is not a full path
+     elif has_supported_extension(audio_input) and not os.path.isabs(audio_input):
+         file_audio = os.path.join(base_path, audio_input)
+
+     # Case 3: If only the name is given (no extension and not a full path)
+     elif not has_supported_extension(audio_input) and not os.path.isabs(audio_input):
+         for ext in supported_formats:
+             potential_file = os.path.join(base_path, f"{audio_input}.{ext}")
+             if os.path.exists(potential_file):
+                 file_audio = potential_file
+                 break
+         else:
+             # for-else: no existing file found, default to the first supported extension
+             file_audio = os.path.join(base_path, f"{audio_input}.{supported_formats[0]}")
+     return file_audio
+
+
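+ # Turns metadata.csv + wavs/ into the training artifacts: raw.arrow (audio path,
+ # text, duration per sample), duration.json, and vocab.txt (copied from the
+ # pretrained Emilia_ZH_EN_pinyin vocab unless a custom vocabulary is built).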
+ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
+     path_project = os.path.join(path_data, name_project)
+     path_project_wavs = os.path.join(path_project, "wavs")
+     file_metadata = os.path.join(path_project, "metadata.csv")
+     file_raw = os.path.join(path_project, "raw.arrow")
+     file_duration = os.path.join(path_project, "duration.json")
+     file_vocab = os.path.join(path_project, "vocab.txt")
+
+     if not os.path.isfile(file_metadata):
+         return f"The file {file_metadata} was not found!", ""
+
+     with open(file_metadata, "r", encoding="utf-8-sig") as f:
+         data = f.read()
+
+     audio_path_list = []
+     text_list = []
+     duration_list = []
+
+     count = len(data.split("\n"))
+     length = 0
+     result = []
+     error_files = []
+     text_vocab_set = set()
+     for line in progress.tqdm(data.split("\n"), total=count):
+         sp_line = line.split("|")
+         if len(sp_line) != 2:
+             continue
+         name_audio, text = sp_line[:2]
+
+         file_audio = get_correct_audio_path(name_audio, path_project_wavs)
+
+         if not os.path.isfile(file_audio):
+             error_files.append([file_audio, "error path"])
+             continue
+
+         try:
+             duration = get_audio_duration(file_audio)
+         except Exception as e:
+             error_files.append([file_audio, "duration"])
+             print(f"Error processing {file_audio}: {e}")
+             continue
+
+         if duration < 1 or duration > 30:
+             if duration > 30:
+                 error_files.append([file_audio, "duration > 30 sec"])
+             if duration < 1:
+                 error_files.append([file_audio, "duration < 1 sec"])
+             continue
+         if len(text) < 3:
+             error_files.append([file_audio, "text shorter than 3 characters"])
+             continue
+
+         text = text.strip()
+         text = convert_char_to_pinyin([text], polyphone=True)[0]
+
+         audio_path_list.append(file_audio)
+         duration_list.append(duration)
+         text_list.append(text)
+
+         result.append({"audio_path": file_audio, "text": text, "duration": duration})
+         if ch_tokenizer:
+             text_vocab_set.update(list(text))
+
+         length += duration
+
+     if not duration_list:
+         return f"Error: No audio files found in the specified path : {path_project_wavs}", ""
+
+     min_second = round(min(duration_list), 2)
+     max_second = round(max(duration_list), 2)
+
+     with ArrowWriter(path=file_raw) as writer:
+         for line in progress.tqdm(result, total=len(result), desc="prepare data"):
+             writer.write(line)
+         writer.finalize()
+
+     with open(file_duration, "w") as f:
+         json.dump({"duration": duration_list}, f, ensure_ascii=False)
+
+     new_vocab = ""
+     if not ch_tokenizer:
+         if not os.path.isfile(file_vocab):
+             file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
+             if not os.path.isfile(file_vocab_finetune):
+                 return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", ""
+             shutil.copy2(file_vocab_finetune, file_vocab)
+
+         with open(file_vocab, "r", encoding="utf-8-sig") as f:
+             vocab_char_map = {}
+             for i, char in enumerate(f):
+                 vocab_char_map[char[:-1]] = i
+         vocab_size = len(vocab_char_map)
+
+     else:
+         with open(file_vocab, "w", encoding="utf-8-sig") as f:
+             for vocab in sorted(text_vocab_set):
+                 f.write(vocab + "\n")
+                 new_vocab += vocab + "\n"
+         vocab_size = len(text_vocab_set)
+
+     if error_files:
+         error_text = "\n".join([" = ".join(item) for item in error_files])
+     else:
+         error_text = ""
+
+     return (
+         f"Prepare complete.\nsamples: {len(text_list)}\ntotal duration: {format_seconds_to_hms(length)}\n"
+         f"min sec: {min_second}\nmax sec: {max_second}\nfile_arrow: {file_raw}\nvocab size: {vocab_size}\n{error_text}",
+         new_vocab,
+     )
+
+
+ def check_user(value):
+     return gr.update(visible=not value), gr.update(visible=value)
+
+
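+ # Heuristically estimates training hyperparameters from the dataset duration
+ # statistics and the available accelerator memory; the constants are rough rules
+ # of thumb rather than exact formulas.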
+ def calculate_train(
+     name_project,
+     epochs,
+     learning_rate,
+     batch_size_per_gpu,
+     batch_size_type,
+     max_samples,
+     num_warmup_updates,
+     finetune,
+ ):
+     path_project = os.path.join(path_data, name_project)
+     file_duration = os.path.join(path_project, "duration.json")
+
+     hop_length = 256
+     sampling_rate = 24000
+
+     if not os.path.isfile(file_duration):
+         return (
+             epochs,
+             learning_rate,
+             batch_size_per_gpu,
+             max_samples,
+             num_warmup_updates,
+             "Project not found!",
+         )
+
+     with open(file_duration, "r") as file:
+         data = json.load(file)
+
+     duration_list = data["duration"]
+     max_sample_length = max(duration_list) * sampling_rate / hop_length
+     total_samples = len(duration_list)
+     total_duration = sum(duration_list)
+
+     if torch.cuda.is_available():
+         gpu_count = torch.cuda.device_count()
+         total_memory = 0
+         for i in range(gpu_count):
+             gpu_properties = torch.cuda.get_device_properties(i)
+             total_memory += gpu_properties.total_memory / (1024**3)  # in GB
+     elif torch.xpu.is_available():
+         gpu_count = torch.xpu.device_count()
+         total_memory = 0
+         for i in range(gpu_count):
+             gpu_properties = torch.xpu.get_device_properties(i)
+             total_memory += gpu_properties.total_memory / (1024**3)
+     elif torch.backends.mps.is_available():
+         gpu_count = 1
+         total_memory = psutil.virtual_memory().available / (1024**3)
+     else:
+         # CPU-only fallback so avg_gpu_memory below is always defined
+         gpu_count = 1
+         total_memory = psutil.virtual_memory().available / (1024**3)
+
+     avg_gpu_memory = total_memory / gpu_count
+
+     # rough estimate of batch size
+     if batch_size_type == "frame":
+         batch_size_per_gpu = max(int(38400 * (avg_gpu_memory - 5) / 75), int(max_sample_length))
+     elif batch_size_type == "sample":
+         batch_size_per_gpu = int(200 / (total_duration / total_samples))
+
+     if total_samples < 64:
+         max_samples = int(total_samples * 0.25)
+
+     num_warmup_updates = max(num_warmup_updates, int(total_samples * 0.05))
+
+     # take 1.2M updates as the maximum
+     max_updates = 1200000
+
+     if batch_size_type == "frame":
+         mini_batch_duration = batch_size_per_gpu * gpu_count * hop_length / sampling_rate
+         updates_per_epoch = total_duration / mini_batch_duration
+     elif batch_size_type == "sample":
+         updates_per_epoch = total_samples / batch_size_per_gpu / gpu_count
+
+     epochs = int(max_updates / updates_per_epoch)
+
+     if finetune:
+         learning_rate = 1e-5
+     else:
+         learning_rate = 7.5e-5
+
+     return (
+         epochs,
+         learning_rate,
+         batch_size_per_gpu,
+         max_samples,
+         num_warmup_updates,
+         total_samples,
+     )
+
+
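+ # Strips a training checkpoint down to the EMA (or raw) model weights so the file
+ # can be shipped for inference without the optimizer state.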
+ def prune_checkpoint(checkpoint_path: str, new_checkpoint_path: str, save_ema: bool, safetensors: bool) -> str:
+     try:
+         checkpoint = torch.load(checkpoint_path, weights_only=True)
+         print("Original Checkpoint Keys:", checkpoint.keys())
+
+         to_retain = "ema_model_state_dict" if save_ema else "model_state_dict"
+         try:
+             model_state_dict_to_retain = checkpoint[to_retain]
+         except KeyError:
+             return f"{to_retain} not found in the checkpoint."
+
+         if safetensors:
+             new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
+             save_file(model_state_dict_to_retain, new_checkpoint_path)
+         else:
+             new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
+             new_checkpoint = {"ema_model_state_dict": model_state_dict_to_retain}
+             torch.save(new_checkpoint, new_checkpoint_path)
+
+         return f"New checkpoint saved at: {new_checkpoint_path}"
+
+     except Exception as e:
+         return f"An error occurred: {e}"
+
+
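+ # Grows the text-embedding table of a pretrained checkpoint by num_new_tokens rows
+ # (randomly initialized under a fixed seed) so the model can be finetuned with an
+ # extended vocabulary.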
+ def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42):
+     seed = 666
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     if ckpt_path.endswith(".safetensors"):
+         ckpt = load_file(ckpt_path, device="cpu")
+         ckpt = {"ema_model_state_dict": ckpt}
+     elif ckpt_path.endswith(".pt"):
+         ckpt = torch.load(ckpt_path, map_location="cpu")
+
+     ema_sd = ckpt.get("ema_model_state_dict", {})
+     embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
+     old_embed_ema = ema_sd[embed_key_ema]
+
+     vocab_old = old_embed_ema.size(0)
+     embed_dim = old_embed_ema.size(1)
+     vocab_new = vocab_old + num_new_tokens
+
+     def expand_embeddings(old_embeddings):
+         new_embeddings = torch.zeros((vocab_new, embed_dim))
+         new_embeddings[:vocab_old] = old_embeddings
+         new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim))
+         return new_embeddings
+
+     ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])
+
+     if new_ckpt_path.endswith(".safetensors"):
+         save_file(ema_sd, new_ckpt_path)
+     elif new_ckpt_path.endswith(".pt"):
+         torch.save(ckpt, new_ckpt_path)
+
+     return vocab_new
+
+
+ def vocab_count(text):
+     return str(len(text.split(",")))
+
+
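+ # Appends symbols missing from the base Emilia_ZH_EN_pinyin vocab to the project's
+ # vocab.txt and writes a matching checkpoint with expanded embeddings.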
+ def vocab_extend(project_name, symbols, model_type):
+     if symbols == "":
+         return "Symbols empty!"
+
+     name_project = project_name
+     path_project = os.path.join(path_data, name_project)
+     file_vocab_project = os.path.join(path_project, "vocab.txt")
+
+     file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
+     if not os.path.isfile(file_vocab):
+         return f"The file {file_vocab} was not found!"
+
+     symbols = symbols.split(",")
+     if not symbols:
+         return "Symbols to extend not found."
+
+     with open(file_vocab, "r", encoding="utf-8-sig") as f:
+         data = f.read()
+         vocab = data.split("\n")
+     vocab_check = set(vocab)
+
+     miss_symbols = []
+     for item in symbols:
+         item = item.replace(" ", "")
+         if item in vocab_check:
+             continue
+         miss_symbols.append(item)
+
+     if not miss_symbols:
+         return "Symbols are already in the vocabulary; no need to extend."
+
+     size_vocab = len(vocab)
+     vocab.pop()  # drop the trailing empty entry left by the final newline
+     for item in miss_symbols:
+         vocab.append(item)
+
+     vocab.append("")
+
+     with open(file_vocab_project, "w", encoding="utf-8") as f:
+         f.write("\n".join(vocab))
+
+     if model_type == "F5TTS_v1_Base":
+         ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors"))
+     elif model_type == "F5TTS_Base":
+         ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+     elif model_type == "E2TTS_Base":
+         ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+
+     vocab_size_new = len(miss_symbols)
+
+     dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
+     new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
+     os.makedirs(new_ckpt_path, exist_ok=True)
+
+     # Add pretrained_ prefix to the model when copying for consistency with finetune_cli.py
+     new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_" + os.path.basename(ckpt_path))
+
+     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)
+
+     vocab_new = "\n".join(miss_symbols)
+     return (
+         f"vocab old size: {size_vocab}\nvocab new size: {size}\nvocab added: {vocab_size_new}\n"
+         f"new symbols:\n{vocab_new}"
+     )
+
+
+ def vocab_check(project_name, tokenizer_type):
+     name_project = project_name
+     path_project = os.path.join(path_data, name_project)
+
+     file_metadata = os.path.join(path_project, "metadata.csv")
+
+     file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
+     if not os.path.isfile(file_vocab):
+         return f"The file {file_vocab} was not found!", ""
+
+     with open(file_vocab, "r", encoding="utf-8-sig") as f:
+         data = f.read()
+         vocab = data.split("\n")
+     vocab = set(vocab)
+
+     if not os.path.isfile(file_metadata):
+         return f"The file {file_metadata} was not found!", ""
+
+     with open(file_metadata, "r", encoding="utf-8-sig") as f:
+         data = f.read()
+
+     miss_symbols = []
+     miss_symbols_keep = {}
+     for item in data.split("\n"):
+         sp = item.split("|")
+         if len(sp) != 2:
+             continue
+
+         text = sp[1].strip()
+         if tokenizer_type == "pinyin":
+             text = convert_char_to_pinyin([text], polyphone=True)[0]
+
+         for t in text:
+             if t not in vocab and t not in miss_symbols_keep:
+                 miss_symbols.append(t)
+                 miss_symbols_keep[t] = t
+
+     if not miss_symbols:
+         vocab_miss = ""
+         info = "You can train using your language!"
+     else:
+         vocab_miss = ",".join(miss_symbols)
+         info = f"The following {len(miss_symbols)} symbols are missing from the vocabulary:\n\n"
+
+     return info, vocab_miss
+
+
+ def get_random_sample_prepare(project_name):
+     name_project = project_name
+     path_project = os.path.join(path_data, name_project)
+     file_arrow = os.path.join(path_project, "raw.arrow")
+     if not os.path.isfile(file_arrow):
+         return "", None
+     dataset = Dataset_.from_file(file_arrow)
+     random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
+     text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
+     audio_path = random_sample["audio_path"][0]
+     return text, audio_path
+
+
+ def get_random_sample_transcribe(project_name):
+     name_project = project_name
+     path_project = os.path.join(path_data, name_project)
+     file_metadata = os.path.join(path_project, "metadata.csv")
+     if not os.path.isfile(file_metadata):
+         return "", None
+
+     data = ""
+     with open(file_metadata, "r", encoding="utf-8-sig") as f:
+         data = f.read()
+
+     list_data = []
+     for item in data.split("\n"):
+         sp = item.split("|")
+         if len(sp) != 2:
+             continue
+
+         # resolve the audio path (handles absolute paths and missing extensions)
+         file_audio = get_correct_audio_path(sp[0], os.path.join(path_project, "wavs"))
+         list_data.append([file_audio, sp[1]])
+
+     if not list_data:
+         return "", None
+
+     random_item = random.choice(list_data)
+
+     return random_item[1], random_item[0]
+
+
+ def get_random_sample_infer(project_name):
+     text, audio = get_random_sample_transcribe(project_name)
+     return (
+         text,
+         text,
+         audio,
+     )
+
+
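+ # Runs inference through the F5TTS API; the API object is cached and only rebuilt
+ # when the checkpoint, device, or EMA setting changes. Falls back to CPU while a
+ # training run is occupying the GPU.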
+ def infer(
+     project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence
+ ):
+     global last_checkpoint, last_device, tts_api, last_ema
+
+     if not os.path.isfile(file_checkpoint):
+         return None, "Checkpoint not found!"
+
+     if training_process is not None:
+         device_test = "cpu"
+     else:
+         device_test = None
+
+     if last_checkpoint != file_checkpoint or last_device != device_test or last_ema != use_ema or tts_api is None:
+         if last_checkpoint != file_checkpoint:
+             last_checkpoint = file_checkpoint
+
+         if last_device != device_test:
+             last_device = device_test
+
+         if last_ema != use_ema:
+             last_ema = use_ema
+
+         vocab_file = os.path.join(path_data, project, "vocab.txt")
+
+         tts_api = F5TTS(
+             model=exp_name, ckpt_file=file_checkpoint, vocab_file=vocab_file, device=device_test, use_ema=use_ema
+         )
+
+         print("update >> ", device_test, file_checkpoint, use_ema)
+
+     if seed == -1:  # -1 used for random
+         seed = None
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+         tts_api.infer(
+             ref_file=ref_audio,
+             ref_text=ref_text.strip(),
+             gen_text=gen_text.strip(),
+             nfe_step=nfe_step,
+             speed=speed,
+             remove_silence=remove_silence,
+             file_wave=f.name,
+             seed=seed,
+         )
+         return f.name, tts_api.device, str(tts_api.seed)
+
+
+ def check_finetune(finetune):
+     return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune)
+
+
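+ # Lists a project's .pt checkpoints ordered as: pretrained_*, numbered model_*.pt
+ # (sorted by update count), then model_last.pt.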
+ def get_checkpoints_project(project_name, is_gradio=True):
+     if project_name is None:
+         return [], ""
+     project_name = project_name.replace("_pinyin", "").replace("_char", "")
+
+     if os.path.isdir(path_project_ckpts):
+         files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
+         # Separate pretrained and regular checkpoints
+         pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
+         regular_checkpoints = [
+             f
+             for f in files_checkpoints
+             if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f)
+         ]
+         last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)]
+
+         # Sort regular checkpoints by update number
+         regular_checkpoints = sorted(
+             regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
+         )
+
+         # Combine in order: pretrained, regular, last
+         files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint
+     else:
+         files_checkpoints = []
+
+     select_checkpoint = None if not files_checkpoints else files_checkpoints[0]
+
+     if is_gradio:
+         return gr.update(choices=files_checkpoints, value=select_checkpoint)
+
+     return files_checkpoints, select_checkpoint
+
+
+ def get_audio_project(project_name, is_gradio=True):
+     if project_name is None:
+         return [], ""
+     project_name = project_name.replace("_pinyin", "").replace("_char", "")
+
+     if os.path.isdir(path_project_ckpts):
+         files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav"))
+         files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]))
+
+         files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")]
+     else:
+         files_audios = []
+
+     select_checkpoint = None if not files_audios else files_audios[0]
+
+     if is_gradio:
+         return gr.update(choices=files_audios, value=select_checkpoint)
+
+     return files_audios, select_checkpoint
+
+
+ def get_gpu_stats():
+     gpu_stats = ""
+
+     if torch.cuda.is_available():
+         gpu_count = torch.cuda.device_count()
+         for i in range(gpu_count):
+             gpu_name = torch.cuda.get_device_name(i)
+             gpu_properties = torch.cuda.get_device_properties(i)
+             total_memory = gpu_properties.total_memory / (1024**3)  # in GB
+             allocated_memory = torch.cuda.memory_allocated(i) / (1024**2)  # in MB
+             reserved_memory = torch.cuda.memory_reserved(i) / (1024**2)  # in MB
+
+             gpu_stats += (
+                 f"GPU {i} Name: {gpu_name}\n"
+                 f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
+                 f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
+                 f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
+             )
+     elif torch.xpu.is_available():
+         gpu_count = torch.xpu.device_count()
+         for i in range(gpu_count):
+             gpu_name = torch.xpu.get_device_name(i)
+             gpu_properties = torch.xpu.get_device_properties(i)
+             total_memory = gpu_properties.total_memory / (1024**3)  # in GB
+             allocated_memory = torch.xpu.memory_allocated(i) / (1024**2)  # in MB
+             reserved_memory = torch.xpu.memory_reserved(i) / (1024**2)  # in MB
+
+             gpu_stats += (
+                 f"GPU {i} Name: {gpu_name}\n"
+                 f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
+                 f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
+                 f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
+             )
+     elif torch.backends.mps.is_available():
+         gpu_count = 1
+         gpu_stats += "MPS GPU\n"
+         total_memory = psutil.virtual_memory().total / (
+             1024**3
+         )  # Total system memory (MPS doesn't have its own memory)
+         allocated_memory = 0
+         reserved_memory = 0
+
+         gpu_stats += (
+             f"Total system memory: {total_memory:.2f} GB\n"
+             f"Allocated GPU memory (MPS): {allocated_memory:.2f} MB\n"
+             f"Reserved GPU memory (MPS): {reserved_memory:.2f} MB\n"
+         )
+
+     else:
+         gpu_stats = "No GPU available"
+
+     return gpu_stats
+
+
+ def get_cpu_stats():
+     cpu_usage = psutil.cpu_percent(interval=1)
+     memory_info = psutil.virtual_memory()
+     memory_used = memory_info.used / (1024**2)
+     memory_total = memory_info.total / (1024**2)
+     memory_percent = memory_info.percent
+
+     pid = os.getpid()
+     process = psutil.Process(pid)
+     nice_value = process.nice()
+
+     cpu_stats = (
+         f"CPU Usage: {cpu_usage:.2f}%\n"
+         f"System Memory: {memory_used:.2f} MB used / {memory_total:.2f} MB total ({memory_percent}% used)\n"
+         f"Process Priority (Nice value): {nice_value}"
+     )
+
+     return cpu_stats
+
+
+ def get_combined_stats():
+     gpu_stats = get_gpu_stats()
+     cpu_stats = get_cpu_stats()
+     combined_stats = f"### GPU Stats\n{gpu_stats}\n\n### CPU Stats\n{cpu_stats}"
+     return combined_stats
+
+
+ def get_audio_select(file_sample):
+     select_audio_ref = file_sample
+     select_audio_gen = file_sample
+
+     if file_sample is not None:
+         select_audio_ref += "_ref.wav"
+         select_audio_gen += "_gen.wav"
+
+     return select_audio_ref, select_audio_gen
+
+
1362
+ with gr.Blocks() as app:
1363
+ gr.Markdown(
1364
+ """
1365
+ # F5 TTS Automatic Finetune
1366
+
1367
+ This is a local web UI for F5 TTS finetuning support. This app supports the following TTS models:
1368
+
1369
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
1370
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
1371
+
1372
+ The pretrained checkpoints support English and Chinese.
1373
+
1374
+ For tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
1375
+ """
1376
+ )
1377
+
1378
+ with gr.Row():
1379
+ projects, projects_selelect = get_list_projects()
1380
+ tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char", "custom"], value="pinyin")
1381
+ project_name = gr.Textbox(label="Project Name", value="my_speak")
1382
+ bt_create = gr.Button("Create a New Project")
1383
+
1384
+ with gr.Row():
1385
+ cm_project = gr.Dropdown(
1386
+ choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6
1387
+ )
1388
+ ch_refresh_project = gr.Button("Refresh", scale=1)
1389
+
1390
+ bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
1391
+
1392
+ with gr.Tabs():
1393
+ with gr.TabItem("Transcribe Data"):
1394
+ gr.Markdown("""```plaintext
1395
+ Skip this step if you have your dataset, metadata.csv, and a folder wavs with all the audio files.
1396
+ ```""")
1397
+
1398
+ ch_manual = gr.Checkbox(label="Audio from Path", value=False)
1399
+
1400
+ mark_info_transcribe = gr.Markdown(
1401
+ """```plaintext
1402
+ Place your 'wavs' folder and 'metadata.csv' file in the '{your_project_name}' directory.
1403
+
1404
+ my_speak/
1405
+ β”‚
1406
+ └── dataset/
1407
+ β”œβ”€β”€ audio1.wav
1408
+ └── audio2.wav
1409
+ ...
1410
+ ```""",
1411
+ visible=False,
1412
+ )
1413
+
1414
+ audio_speaker = gr.File(label="Voice", type="filepath", file_count="multiple")
1415
+ txt_lang = gr.Textbox(label="Language", value="English")
1416
+ bt_transcribe = bt_create = gr.Button("Transcribe")
1417
+ txt_info_transcribe = gr.Textbox(label="Info", value="")
1418
+ bt_transcribe.click(
1419
+ fn=transcribe_all,
1420
+ inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
1421
+ outputs=[txt_info_transcribe],
1422
+ )
1423
+ ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
1424
+
1425
+ random_sample_transcribe = gr.Button("Random Sample")
1426
+
1427
+ with gr.Row():
1428
+ random_text_transcribe = gr.Textbox(label="Text")
1429
+ random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
1430
+
1431
+ random_sample_transcribe.click(
1432
+ fn=get_random_sample_transcribe,
1433
+ inputs=[cm_project],
1434
+ outputs=[random_text_transcribe, random_audio_transcribe],
1435
+ )
1436
+
1437
+ with gr.TabItem("Vocab Check"):
1438
+ gr.Markdown("""```plaintext
1439
+ Check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. For fine-tuning a new language.
1440
+ ```""")
1441
+
1442
+ check_button = gr.Button("Check Vocab")
1443
+ txt_info_check = gr.Textbox(label="Info", value="")
1444
+
1445
+ gr.Markdown("""```plaintext
1446
+ Using the extended model, you can finetune to a new language that is missing symbols in the vocab. This creates a new model with a new vocabulary size and saves it in your ckpts/project folder.
1447
+ ```""")
1448
+
1449
+ exp_name_extend = gr.Radio(
1450
+ label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
1451
+ )
1452
+
1453
+ with gr.Row():
1454
+ txt_extend = gr.Textbox(
1455
+ label="Symbols",
1456
+ value="",
1457
+ placeholder="To add new symbols, make sure to use ',' for each symbol",
1458
+ scale=6,
1459
+ )
1460
+ txt_count_symbol = gr.Textbox(label="New Vocab Size", value="", scale=1)
1461
+
1462
+ extend_button = gr.Button("Extend")
1463
+ txt_info_extend = gr.Textbox(label="Info", value="")
1464
+
1465
+ txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
1466
+ check_button.click(
1467
+ fn=vocab_check, inputs=[cm_project, tokenizer_type], outputs=[txt_info_check, txt_extend]
1468
+ )
1469
+ extend_button.click(
1470
+ fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
1471
+ )
1472
+
1473
+ with gr.TabItem("Prepare Data"):
1474
+ gr.Markdown("""```plaintext
1475
+ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
1476
+ ```""")
1477
+
1478
+ gr.Markdown(
1479
+ """```plaintext
1480
+ Place all your "wavs" folder and your "metadata.csv" file in your project name directory.
1481
+
1482
+ Supported audio formats: "wav", "mp3", "aac", "flac", "m4a", "alac", "ogg", "aiff", "wma", "amr"
1483
+
1484
+ Example wav format:
1485
+ my_speak/
1486
+ β”‚
1487
+ β”œβ”€β”€ wavs/
1488
+ β”‚ β”œβ”€β”€ audio1.wav
1489
+ β”‚ └── audio2.wav
1490
+ | ...
1491
+ β”‚
1492
+ └── metadata.csv
1493
+
1494
+ File format metadata.csv:
1495
+
1496
+ audio1|text1 or audio1.wav|text1 or your_path/audio1.wav|text1
1497
+ audio2|text1 or audio2.wav|text1 or your_path/audio2.wav|text1
1498
+ ...
1499
+
1500
+ ```"""
1501
+ )
1502
+ ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
1503
+
1504
+ bt_prepare = bt_create = gr.Button("Prepare")
1505
+ txt_info_prepare = gr.Textbox(label="Info", value="")
1506
+ txt_vocab_prepare = gr.Textbox(label="Vocab", value="")
1507
+
1508
+ bt_prepare.click(
1509
+ fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
1510
+ )
1511
+
1512
+ random_sample_prepare = gr.Button("Random Sample")
1513
+
1514
+ with gr.Row():
1515
+ random_text_prepare = gr.Textbox(label="Tokenizer")
1516
+ random_audio_prepare = gr.Audio(label="Audio", type="filepath")
1517
+
1518
+ random_sample_prepare.click(
1519
+ fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1520
+ )
1521
+
+         with gr.TabItem("Train Model"):
+             gr.Markdown("""```plaintext
+ The auto-settings are still experimental. If unsure, set a large number of epochs, and keep only the last N checkpoints if disk space is limited.
+ If you hit an out-of-memory error, reduce the batch size per GPU.
+ ```""")
+             with gr.Row():
+                 exp_name = gr.Radio(label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"])
+                 tokenizer_file = gr.Textbox(label="Tokenizer File")
+                 file_checkpoint_train = gr.Textbox(label="Path to the Pretrained Checkpoint")
+
+             with gr.Row():
+                 ch_finetune = gr.Checkbox(label="Finetune")
+                 lb_samples = gr.Label(label="Samples")
+                 bt_calculate = gr.Button("Auto Settings")
+
+             with gr.Row():
+                 epochs = gr.Number(label="Epochs")
+                 learning_rate = gr.Number(label="Learning Rate", step=0.5e-5)
+                 max_grad_norm = gr.Number(label="Max Gradient Norm")
+                 num_warmup_updates = gr.Number(label="Warmup Updates")
+
+             with gr.Row():
+                 batch_size_type = gr.Radio(
+                     label="Batch Size Type",
+                     choices=["frame", "sample"],
+                     info="frame is calculated as seconds * sampling_rate / hop_length",
+                 )
+                 batch_size_per_gpu = gr.Number(label="Batch Size per GPU", info="N frames or N samples")
+                 grad_accumulation_steps = gr.Number(
+                     label="Gradient Accumulation Steps", info="Effective batch size is multiplied by this value"
+                 )
+                 max_samples = gr.Number(label="Max Samples", info="Maximum number of samples per single GPU batch")
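+
+                 # Hedged frame math for the "frame" batch type (assumes the
+                 # 24 kHz sample rate and hop_length of 256 used by F5-TTS):
+                 #   frames_per_second = 24000 / 256   # = 93.75
+                 #   batch_size_per_gpu = 3200 frames  # ~34 s of audio per step
+                 # Effective batch = batch_size_per_gpu * grad_accumulation_steps * num_gpus.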
1554
+
+             with gr.Row():
+                 save_per_updates = gr.Number(
+                     label="Save per Updates",
+                     info="Save intermediate checkpoints every N updates",
+                     minimum=10,
+                 )
+                 keep_last_n_checkpoints = gr.Number(
+                     label="Keep Last N Checkpoints",
+                     step=1,
+                     precision=0,
+                     info="-1 to keep all, 0 to save no intermediate checkpoints, > 0 to keep the last N",
+                     minimum=-1,
+                 )
+                 last_per_updates = gr.Number(
+                     label="Last per Updates",
+                     info="Save the latest checkpoint (suffix _last.pt) every N updates",
+                     minimum=10,
+                 )
+                 gr.Radio(label="")  # placeholder to keep the row layout aligned
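+
+             # Retention semantics spelled out (restates the info string above;
+             # the cleanup itself presumably happens in the trainer):
+             #   keep_last_n_checkpoints == -1  ->  keep every intermediate checkpoint
+             #   keep_last_n_checkpoints ==  0  ->  save no intermediate checkpoints
+             #   keep_last_n_checkpoints ==  N  ->  delete all but the N newest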
1574
+
+             with gr.Row():
+                 ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
+                 mixed_precision = gr.Radio(label="Mixed Precision", choices=["none", "fp16", "bf16"])
+                 cd_logger = gr.Radio(label="Logger", choices=["none", "wandb", "tensorboard"])
+                 with gr.Column():
+                     start_button = gr.Button("Start Training")
+                     stop_button = gr.Button("Stop Training", interactive=False)
+
+             if projects_selelect is not None:
+                 (
+                     exp_name_value,
+                     learning_rate_value,
+                     batch_size_per_gpu_value,
+                     batch_size_type_value,
+                     max_samples_value,
+                     grad_accumulation_steps_value,
+                     max_grad_norm_value,
+                     epochs_value,
+                     num_warmup_updates_value,
+                     save_per_updates_value,
+                     keep_last_n_checkpoints_value,
+                     last_per_updates_value,
+                     finetune_value,
+                     file_checkpoint_train_value,
+                     tokenizer_type_value,
+                     tokenizer_file_value,
+                     mixed_precision_value,
+                     logger_value,
+                     bnb_optimizer_value,
+                 ) = load_settings(projects_selelect)
+
+                 # Assigning values to the respective components
+                 exp_name.value = exp_name_value
+                 learning_rate.value = learning_rate_value
+                 batch_size_per_gpu.value = batch_size_per_gpu_value
+                 batch_size_type.value = batch_size_type_value
+                 max_samples.value = max_samples_value
+                 grad_accumulation_steps.value = grad_accumulation_steps_value
+                 max_grad_norm.value = max_grad_norm_value
+                 epochs.value = epochs_value
+                 num_warmup_updates.value = num_warmup_updates_value
+                 save_per_updates.value = save_per_updates_value
+                 keep_last_n_checkpoints.value = keep_last_n_checkpoints_value
+                 last_per_updates.value = last_per_updates_value
+                 ch_finetune.value = finetune_value
+                 file_checkpoint_train.value = file_checkpoint_train_value
+                 tokenizer_type.value = tokenizer_type_value
+                 tokenizer_file.value = tokenizer_file_value
+                 mixed_precision.value = mixed_precision_value
+                 cd_logger.value = logger_value
+                 ch_8bit_adam.value = bnb_optimizer_value
+
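+             # load_settings presumably round-trips a per-project settings file;
+             # a hedged sketch of its shape (keys illustrative, mirroring the
+             # components above):
+             #   {"exp_name": "F5TTS_v1_Base", "learning_rate": 1e-05,
+             #    "batch_size_per_gpu": 3200, "batch_size_type": "frame",
+             #    "mixed_precision": "fp16", "logger": "wandb", ...}
+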
1627
+             ch_stream = gr.Checkbox(label="Stream Output (experimental)", value=True)
+             txt_info_train = gr.Textbox(label="Info", value="")
+
+             list_audios, select_audio = get_audio_project(projects_selelect, False)
+
+             select_audio_ref = select_audio
+             select_audio_gen = select_audio
+
+             if select_audio is not None:
+                 select_audio_ref += "_ref.wav"
+                 select_audio_gen += "_gen.wav"
+
+             with gr.Row():
+                 ch_list_audio = gr.Dropdown(
+                     choices=list_audios,
+                     value=select_audio,
+                     label="Audios",
+                     allow_custom_value=True,
+                     scale=6,
+                     interactive=True,
+                 )
+                 bt_stream_audio = gr.Button("Refresh", scale=1)
+                 bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+                 cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+
+             with gr.Row():
+                 audio_ref_stream = gr.Audio(label="Original", type="filepath", value=select_audio_ref)
+                 audio_gen_stream = gr.Audio(label="Generated", type="filepath", value=select_audio_gen)
+
+             ch_list_audio.change(
+                 fn=get_audio_select,
+                 inputs=[ch_list_audio],
+                 outputs=[audio_ref_stream, audio_gen_stream],
+             )
1661
+
+             start_button.click(
+                 fn=start_training,
+                 inputs=[
+                     cm_project,
+                     exp_name,
+                     learning_rate,
+                     batch_size_per_gpu,
+                     batch_size_type,
+                     max_samples,
+                     grad_accumulation_steps,
+                     max_grad_norm,
+                     epochs,
+                     num_warmup_updates,
+                     save_per_updates,
+                     keep_last_n_checkpoints,
+                     last_per_updates,
+                     ch_finetune,
+                     file_checkpoint_train,
+                     tokenizer_type,
+                     tokenizer_file,
+                     mixed_precision,
+                     ch_stream,
+                     cd_logger,
+                     ch_8bit_adam,
+                 ],
+                 outputs=[txt_info_train, start_button, stop_button],
+             )
+             stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
+
+             bt_calculate.click(
+                 fn=calculate_train,
+                 inputs=[
+                     cm_project,
+                     epochs,
+                     learning_rate,
+                     batch_size_per_gpu,
+                     batch_size_type,
+                     max_samples,
+                     num_warmup_updates,
+                     ch_finetune,
+                 ],
+                 outputs=[
+                     epochs,
+                     learning_rate,
+                     batch_size_per_gpu,
+                     max_samples,
+                     num_warmup_updates,
+                     lb_samples,
+                 ],
+             )
+
+             ch_finetune.change(
+                 check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
+             )
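+
+             # check_finetune (defined earlier in this file) presumably toggles the
+             # three dependent fields; a hedged sketch of that Gradio pattern:
+             #   def check_finetune(finetune):
+             #       enable = bool(finetune)
+             #       return (
+             #           gr.update(interactive=enable),  # file_checkpoint_train
+             #           gr.update(interactive=enable),  # tokenizer_file
+             #           gr.update(interactive=enable),  # tokenizer_type
+             #       )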
1716
+
+             def setup_load_settings():
+                 output_components = [
+                     exp_name,
+                     learning_rate,
+                     batch_size_per_gpu,
+                     batch_size_type,
+                     max_samples,
+                     grad_accumulation_steps,
+                     max_grad_norm,
+                     epochs,
+                     num_warmup_updates,
+                     save_per_updates,
+                     keep_last_n_checkpoints,
+                     last_per_updates,
+                     ch_finetune,
+                     file_checkpoint_train,
+                     tokenizer_type,
+                     tokenizer_file,
+                     mixed_precision,
+                     cd_logger,
+                     ch_8bit_adam,
+                 ]
+                 return output_components
+
+             outputs = setup_load_settings()
+
+             cm_project.change(
+                 fn=load_settings,
+                 inputs=[cm_project],
+                 outputs=outputs,
+             )
+
+             ch_refresh_project.click(
+                 fn=load_settings,
+                 inputs=[cm_project],
+                 outputs=outputs,
+             )
+
1755
+ with gr.TabItem("Test Model"):
1756
+ gr.Markdown("""```plaintext
1757
+ Check the use_ema setting (True or False) for your model to see what works best for you. Set seed to -1 for random.
1758
+ ```""")
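+
+             # The "-1 means random" convention, as a hedged sketch (the real
+             # handling lives in infer(), defined earlier in this file):
+             #   import random
+             #   if seed == -1:
+             #       seed = random.randint(0, 2**31 - 1)
+             #   torch.manual_seed(seed)  # then report it back via seed_info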
1759
+             exp_name = gr.Radio(
+                 label="Model", choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], value="F5TTS_v1_Base"
+             )
+             list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
+
+             with gr.Row():
+                 nfe_step = gr.Number(label="NFE Step", value=32)
+                 speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
+                 seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
+                 remove_silence = gr.Checkbox(label="Remove Silence")
+
+             with gr.Row():
+                 ch_use_ema = gr.Checkbox(
+                     label="Use EMA", value=True, info="Turning this off at an early training stage might give better results"
+                 )
+                 cm_checkpoint = gr.Dropdown(
+                     choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True
+                 )
+                 bt_checkpoint_refresh = gr.Button("Refresh")
+
+             random_sample_infer = gr.Button("Random Sample")
+
+             ref_text = gr.Textbox(label="Reference Text")
+             ref_audio = gr.Audio(label="Reference Audio", type="filepath")
+             gen_text = gr.Textbox(label="Text to Generate")
+
+             random_sample_infer.click(
+                 fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
+             )
+
1789
+             with gr.Row():
+                 txt_info_gpu = gr.Textbox("", label="Inference on Device:")
+                 seed_info = gr.Textbox(label="Used Random Seed:")
+                 check_button_infer = gr.Button("Inference")
+
+             gen_audio = gr.Audio(label="Generated Audio", type="filepath")
+
+             check_button_infer.click(
+                 fn=infer,
+                 inputs=[
+                     cm_project,
+                     cm_checkpoint,
+                     exp_name,
+                     ref_text,
+                     ref_audio,
+                     gen_text,
+                     nfe_step,
+                     ch_use_ema,
+                     speed,
+                     seed,
+                     remove_silence,
+                 ],
+                 outputs=[gen_audio, txt_info_gpu, seed_info],
+             )
+
+             bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+             cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+
1817
+ with gr.TabItem("Prune Checkpoint"):
1818
+ gr.Markdown("""```plaintext
1819
+ Reduce the Base model size from 5GB to 1.3GB. The new checkpoint file prunes out optimizer and etc., can be used for inference or finetuning afterward, but not able to resume pretraining.
1820
+ ```""")
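+
+             # What pruning amounts to, as a hedged sketch (checkpoint key names are
+             # assumptions; the real logic is in prune_checkpoint, defined earlier):
+             #   import torch
+             #   ckpt = torch.load(path_in, map_location="cpu")
+             #   state = ckpt.get("ema_model_state_dict") if save_ema else ckpt.get("model_state_dict")
+             #   torch.save({"model_state_dict": state}, path_out)  # optimizer state dropped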
1821
+             txt_path_checkpoint = gr.Textbox(label="Path to Checkpoint:")
+             txt_path_checkpoint_small = gr.Textbox(label="Path to Output:")
+             with gr.Row():
+                 ch_save_ema = gr.Checkbox(label="Save EMA checkpoint", value=True)
+                 ch_safetensors = gr.Checkbox(label="Save in safetensors format", value=True)
+             txt_info_reduce = gr.Textbox(label="Info", value="")
+             reduce_button = gr.Button("Prune")
+             reduce_button.click(
+                 fn=prune_checkpoint,
+                 inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_save_ema, ch_safetensors],
+                 outputs=[txt_info_reduce],
+             )
1833
+
+         with gr.TabItem("System Info"):
+             output_box = gr.Textbox(label="GPU and CPU Information", lines=20)
+
+             def update_stats():
+                 return get_combined_stats()
+
+             update_button = gr.Button("Update Stats")
+             update_button.click(fn=update_stats, outputs=output_box)
+
+             def auto_update():
+                 yield gr.update(value=update_stats())
+
+             # Refresh the stats once when the page opens; Blocks.load is the
+             # page-load event (assumes the surrounding `with gr.Blocks() as app:`
+             # context implied by the `global app` in main below).
+             app.load(fn=auto_update, inputs=[], outputs=output_box)
+
1848
+
+
+ @click.command()
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
+ @click.option(
+     "--share",
+     "-s",
+     default=False,
+     is_flag=True,
+     help="Share the app via Gradio share link",
+ )
+ @click.option("--api/--no-api", default=True, help="Allow API access")
+ def main(port, host, share, api):
+     global app
+     print("Starting app...")
+     app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
+
+
+ if __name__ == "__main__":
+     main()
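+
+ # Example invocations (hedged; flags as defined above, file assumed saved as app.py):
+ #   python app.py                        # default host/port, API enabled
+ #   python app.py --port 7860 --share    # fixed port plus a public Gradio share link
+ #   python app.py --no-api               # disable API access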