|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import sys |
|
|
import gc |
|
|
import re |
|
|
import json |
|
|
import time |
|
|
import mmap |
|
|
import math |
|
|
import torch |
|
|
import random |
|
|
import logging |
|
|
import warnings |
|
|
import traceback |
|
|
import subprocess |
|
|
import numpy as np |
|
|
import torchaudio |
|
|
import gradio as gr |
|
|
import gradio_client.utils |
|
|
import threading |
|
|
import configparser |
|
|
from pydub import AudioSegment |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, Dict, Any, List |
|
|
from torch.cuda.amp import autocast |
|
|
from logging.handlers import RotatingFileHandler |
|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import FileResponse |
|
|
from pydantic import BaseModel |
|
|
import uvicorn |
|
|
|
|
|
from colorama import init as colorama_init, Fore |
|
|
|
|
|
# Release tag shown in the boot banner and reported by the API
# (/health, /config, FastAPI title/version and the X-Release header).
RELEASE = "v1.3.3"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Keep a handle on the upstream helper so the wrapper can delegate to it.
_original_get_type = gradio_client.utils.get_type


def _patched_get_type(schema):
    """Return "boolean" for a bare bool JSON schema; delegate everything else.

    JSON Schema permits ``true``/``false`` as a complete schema. This wrapper
    answers those itself so the upstream ``get_type`` only ever receives
    non-bool schemas (presumably because it mishandles bools — confirm against
    the installed gradio_client version).
    """
    is_bool_schema = isinstance(schema, bool)
    return "boolean" if is_bool_schema else _original_get_type(schema)


# Monkey-patch gradio_client globally for the rest of the process.
gradio_client.utils.get_type = _patched_get_type
|
|
|
|
|
# Suppress every Python warning process-wide.
# NOTE(review): this also hides warnings that may point at real problems.
warnings.filterwarnings("ignore")

# Hint for PyTorch's CUDA caching allocator; set before any torch.cuda call in
# this file so it is in place when CUDA is first initialised.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Prefer deterministic cuDNN kernels over autotuned ones (reproducibility over speed).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
|
|
|
|
|
# All runtime paths are anchored at the directory containing this file.
BASE_DIR = Path(__file__).parent.resolve()
LOG_DIR = BASE_DIR / "logs"
MP3_DIR = BASE_DIR / "mp3"  # rendered MP3s are written here
LOG_DIR.mkdir(parents=True, exist_ok=True)
MP3_DIR.mkdir(parents=True, exist_ok=True)

LOG_FILE = LOG_DIR / "ghostai_musicgen.log"
logger = logging.getLogger("ghostai-musicgen")
logger.setLevel(logging.DEBUG)

# RotatingFileHandler never rolls over when backupCount is 0 (the file just keeps
# growing), so the 5 MB cap only takes effect with at least one backup file.
# One backup bounds the log to roughly 2 x 5 MB.
file_handler = RotatingFileHandler(LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=1, encoding="utf-8")
file_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))

# Console output is message-only; timestamps/levels go to the file.
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(logging.Formatter("%(message)s"))

logger.addHandler(file_handler)
logger.addHandler(console_handler)
|
|
|
|
|
colorama_init()
# Banner separator is an em dash; the previous literal was mis-encoded.
print(f"{Fore.CYAN}GhostAI Music Generator {Fore.MAGENTA}{RELEASE}{Fore.RESET} — {Fore.GREEN}Booting...{Fore.RESET}")

# A CUDA GPU is mandatory: the model is loaded on it and generation uses fp16 autocast.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    print(f"{Fore.RED}CUDA not available. Exiting.{Fore.RESET}")
    logger.error("CUDA is required. Exiting.")
    sys.exit(1)

gpu_name = torch.cuda.get_device_name(0)
print(f"{Fore.YELLOW}GPU:{Fore.RESET} {gpu_name}")
print(f"{Fore.YELLOW}Precision:{Fore.RESET} fp16 (model) / fp32 (CPU audio ops)")
|
|
|
|
|
# Resource files expected next to this script.
CSS_FILE = BASE_DIR / "styles.css"
# Style presets parsed by StylesConfig (hot-reloaded on mtime change).
PROMPTS_INI = BASE_DIR / "prompts.ini"
EXAMPLES_MD = BASE_DIR / "examples.md"
# Persisted user settings read by load_settings() / written by save_settings().
SETTINGS_FILE = BASE_DIR / "settings.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fallback value for every user-tunable setting. load_settings() fills keys
# missing from settings.json with these, clear_inputs() resets the UI to them,
# and /render falls back to them field by field.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "cfg_scale": 5.8,           # passed to MusicGen as cfg_coef
    "top_k": 250,               # sampling: top-k
    "top_p": 0.95,              # sampling: nucleus probability
    "temperature": 0.90,        # sampling temperature
    "total_duration": 60,       # seconds; generate_music clamps to 30..120
    "bpm": 120,
    "drum_beat": "none",
    "synthesizer": "none",
    "rhythmic_steps": "none",
    "bass_style": "none",
    "guitar_style": "none",
    "target_volume": -23.0,     # RMS target in dB handed to rms_normalize
    "preset": "default",        # accepted but not used by generate_music
    "max_steps": 1500,          # accepted but not used by generate_music
    "bitrate": "192k",          # MP3 export bitrate
    "output_sample_rate": "48000",  # string; generate_music int()-parses it
    "bit_depth": "16",          # string; "24" selects 24-bit intermediates
    "instrumental_prompt": "",
    "style": "custom"
}
|
|
|
|
|
def load_settings() -> Dict[str, Any]:
    """Read SETTINGS_FILE and back-fill any missing keys from DEFAULT_SETTINGS.

    Returns a fresh copy of the defaults when the file does not exist or
    cannot be read/parsed; read failures are logged, never raised.
    """
    if not SETTINGS_FILE.exists():
        return DEFAULT_SETTINGS.copy()
    try:
        loaded = json.loads(SETTINGS_FILE.read_text())
        for key, default in DEFAULT_SETTINGS.items():
            loaded.setdefault(key, default)
        logger.info("Settings loaded.")
    except Exception as e:
        logger.error(f"Settings read failed: {e}")
        return DEFAULT_SETTINGS.copy()
    return loaded
|
|
|
|
|
def save_settings(s: Dict[str, Any]) -> None:
    """Persist settings to SETTINGS_FILE as indented JSON; failures are logged, not raised.

    The JSON is written to a sibling ``.tmp`` file and then moved over
    SETTINGS_FILE with ``os.replace``. A crash mid-write therefore cannot leave a
    truncated settings.json; load_settings would otherwise silently fall back to
    the defaults.
    """
    tmp_path = SETTINGS_FILE.with_name(SETTINGS_FILE.name + ".tmp")
    try:
        payload = json.dumps(s, indent=2)
        tmp_path.write_text(payload, encoding="utf-8")
        os.replace(tmp_path, SETTINGS_FILE)
        logger.info("Settings saved.")
    except Exception as e:
        logger.error(f"Settings write failed: {e}")
        # Best-effort cleanup of a partially written temp file.
        try:
            tmp_path.unlink()
        except OSError:
            pass
|
|
|
|
|
# Live settings dict, loaded once at import. The /settings endpoint mutates it
# in place, and /render and /config read from it.
CURRENT_SETTINGS = load_settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_memory() -> Optional[float]:
    """Release unreferenced Python objects and cached CUDA memory.

    Returns:
        MiB currently allocated by tensors on the current CUDA device, or None
        if any step failed (the error and traceback are logged).
    """
    try:
        # Collect first: tensors kept alive only by reference cycles are freed
        # here. Emptying the CUDA cache afterwards lets their blocks actually be
        # returned instead of staying in PyTorch's caching allocator.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()
        vram_mb = torch.cuda.memory_allocated() / 1024**2
        logger.debug(f"Memory cleaned. VRAM={vram_mb:.2f} MB")
        return vram_mb
    except Exception as e:
        logger.error(f"clean_memory failed: {e}")
        logger.error(traceback.format_exc())
        return None
|
|
|
|
|
def check_vram():
    """Log GPU memory usage reported by ``nvidia-smi`` and return free VRAM in MiB.

    Reads the first GPU row nvidia-smi reports. When fewer than 5000 MiB are
    free, the per-process GPU memory table is logged too, to help find the
    consumer.

    Returns:
        Free MiB as an int, or None if nvidia-smi is missing, fails, times out
        or produces output that cannot be parsed.
    """
    # Bound every nvidia-smi call so a wedged driver cannot hang startup.
    timeout_s = 10
    try:
        r = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=timeout_s,
        )
        rows = [line for line in r.stdout.splitlines() if line.strip()]
        if r.returncode != 0 or not rows:
            logger.warning(f"nvidia-smi returned no usable data (exit code {r.returncode})")
            return None
        # NOTE(review): nvidia-smi's GPU order may differ from CUDA's device order.
        # Unpacking raises ValueError (handled below) if fewer than two numbers are present.
        used_mb, total_mb = map(int, re.findall(r"\d+", rows[0])[:2])
        free_mb = total_mb - used_mb
        logger.info(f"VRAM: used {used_mb} MiB | free {free_mb} MiB | total {total_mb} MiB")
        if free_mb < 5000:
            procs = subprocess.run(
                ["nvidia-smi", "--query-compute-apps=pid,used_memory", "--format=csv"],
                capture_output=True, text=True, timeout=timeout_s,
            )
            logger.info(f"GPU processes:\n{procs.stdout}")
        return free_mb
    except Exception as e:
        logger.error(f"check_vram failed: {e}")
        return None
|
|
|
|
|
def check_disk_space(path=".") -> bool:
    """Return True if the filesystem holding *path* has at least 1 GiB available.

    Uses ``shutil.disk_usage``, which works on POSIX and Windows (``os.statvfs``
    does not exist on Windows, so the old check always failed there). On POSIX
    the reported free space equals ``f_bavail * f_frsize``, i.e. the same value
    as before. Low space is logged as a warning; any error is logged and
    treated as "not enough space".
    """
    import shutil

    try:
        free_gb = shutil.disk_usage(path).free / (1024**3)
        if free_gb < 1.0:
            logger.warning(f"Low disk space: {free_gb:.2f} GB")
        return free_gb >= 1.0
    except Exception as e:
        logger.error(f"Disk space check failed: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_stereo(seg: AudioSegment, sample_rate=48000, sample_width=2) -> AudioSegment:
    """Return *seg* converted to 2 channels at *sample_rate*.

    ``sample_width`` is accepted for call-site compatibility but not used; the
    sample width is left untouched. If a conversion fails, the error is logged
    and whatever was converted so far is returned.
    """
    try:
        needs_channels = seg.channels != 2
        if needs_channels:
            seg = seg.set_channels(2)
        needs_rate = seg.frame_rate != sample_rate
        if needs_rate:
            seg = seg.set_frame_rate(sample_rate)
    except Exception as e:
        logger.error(f"ensure_stereo failed: {e}")
    return seg
|
|
|
|
|
def calculate_rms(seg: AudioSegment) -> float:
    """Root-mean-square of all raw integer samples (channels interleaved), in sample units.

    Returns 0.0 if the samples cannot be read.
    """
    try:
        values = np.asarray(seg.get_array_of_samples(), dtype=np.float32)
        mean_square = np.mean(np.square(values))
        return float(np.sqrt(mean_square))
    except Exception:
        return 0.0
|
|
|
|
|
def hard_limit(seg: AudioSegment, limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    """Clip every sample to ``limit_db`` dBFS and return a stereo segment at ``sample_rate``.

    Full scale and the output integer dtype come from the segment's actual
    ``sample_width`` (1, 2 or 4 bytes). pydub stores 24-bit input as 4-byte
    samples, so testing ``sample_width == 3`` never matched 24-bit audio.
    Writing int16 data into a 4-byte segment corrupted it.

    Args:
        seg: Input audio.
        limit_db: Peak ceiling in dB relative to full scale. Values above
            0 dBFS are capped at full scale so the integer cast cannot wrap.
        sample_rate: Frame rate of the returned segment.

    Returns:
        The limited segment. On error or an unsupported sample width, the
        (stereo-converted) input is returned.
    """
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(seg.sample_width)
        if dtype is None:
            logger.warning(f"hard_limit: unsupported sample width {seg.sample_width}; left unchanged")
            return seg
        full_scale = float(np.iinfo(dtype).max)
        limit = min(10 ** (limit_db / 20.0), 1.0) * full_scale
        # float64 keeps int32-range samples and the clip bound exact before the cast.
        samples = np.array(seg.get_array_of_samples(), dtype=np.float64)
        samples = np.clip(samples, -limit, limit).astype(dtype)
        if len(samples) % 2 != 0:
            # Keep whole stereo frames.
            samples = samples[:-1]
        return AudioSegment(
            samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=seg.sample_width,
            channels=2
        )
    except Exception as e:
        logger.error(f"hard_limit failed: {e}")
        return seg
|
|
|
|
|
def rms_normalize(seg: AudioSegment, target_rms_db=-23.0, peak_limit_db=-3.0, sample_rate=48000) -> AudioSegment:
    """Scale *seg* so its RMS hits ``target_rms_db`` (dBFS), then peak-limit it.

    Full scale is derived from the real ``sample_width``: 32767 for 16-bit and
    2**31 - 1 for the 4-byte samples pydub uses for 24-bit audio. The old
    ``sample_width == 3`` test never matched, so 24-bit audio was normalised
    against a 16-bit reference and came out about 96 dB too quiet.

    Args:
        seg: Input audio.
        target_rms_db: Desired RMS level in dB relative to full scale.
        peak_limit_db: Ceiling passed on to hard_limit.
        sample_rate: Frame rate of the returned segment.

    Returns:
        The normalised, limited stereo segment. On error, the input
        (possibly already stereo-converted) is returned.
    """
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        full_scale = float(2 ** (8 * seg.sample_width - 1) - 1)
        target_rms = 10 ** (target_rms_db / 20) * full_scale
        current = calculate_rms(seg)
        if current > 0:
            gain = target_rms / current
            # Floor the linear gain so log10 never sees 0.
            seg = seg.apply_gain(20 * np.log10(max(gain, 1e-6)))
        return hard_limit(seg, peak_limit_db, sample_rate)
    except Exception as e:
        logger.error(f"rms_normalize failed: {e}")
        return seg
|
|
|
|
|
def balance_stereo(seg: AudioSegment, noise_threshold=-40, sample_rate=48000) -> AudioSegment:
    """Equalise left/right RMS levels and return a stereo segment at ``sample_rate``.

    Samples whose ``20*log10(|sample|)`` is at or below ``noise_threshold`` are
    zeroed and ignored for the RMS estimate. NOTE(review): that level is taken
    from the raw integer value, not dBFS, so with a negative threshold only
    already-zero samples are affected.

    Each channel is scaled toward the average of the two channel RMS values.
    Boosting the quieter channel can push peaks past full scale, so the result
    is clipped to the integer range before the cast (the old cast wrapped
    around, producing loud clicks). The dtype follows the real ``sample_width``
    (pydub uses 4-byte samples for 24-bit audio).

    Returns:
        The balanced segment. On error or an unsupported sample width, the
        (stereo-converted) input is returned.
    """
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(seg.sample_width)
        if dtype is None:
            logger.warning(f"balance_stereo: unsupported sample width {seg.sample_width}; left unchanged")
            return seg
        if seg.channels != 2:
            return seg
        stereo = np.array(seg.get_array_of_samples(), dtype=np.float64).reshape(-1, 2)
        level_db = 20 * np.log10(np.abs(stereo) + 1e-10)
        stereo = stereo * (level_db > noise_threshold)
        left, right = stereo[:, 0], stereo[:, 1]
        l_rms = np.sqrt(np.mean(left[left != 0] ** 2)) if np.any(left != 0) else 0
        r_rms = np.sqrt(np.mean(right[right != 0] ** 2)) if np.any(right != 0) else 0
        if l_rms > 0 and r_rms > 0:
            avg = (l_rms + r_rms) / 2
            stereo[:, 0] *= (avg / l_rms)
            stereo[:, 1] *= (avg / r_rms)
        info = np.iinfo(dtype)
        out = np.clip(stereo.reshape(-1), info.min, info.max).astype(dtype)
        if len(out) % 2 != 0:
            out = out[:-1]
        return AudioSegment(out.tobytes(), frame_rate=sample_rate, sample_width=seg.sample_width, channels=2)
    except Exception as e:
        logger.error(f"balance_stereo failed: {e}")
        return seg
|
|
|
|
|
def apply_noise_gate(seg: AudioSegment, threshold_db=-80, sample_rate=48000) -> AudioSegment:
    """Zero every sample whose level is at or below ``threshold_db``; return stereo at ``sample_rate``.

    NOTE(review): the level is ``20*log10(|sample|)`` of the raw integer value,
    not dBFS. Every non-zero integer sample is >= 0 dB on that scale, so with a
    negative threshold only samples that are already zero get gated, and the
    gate is effectively a pass-through. It is left that way on purpose: a
    genuine per-sample dBFS gate would distort the waveform around every zero
    crossing.

    The integer samples are gated directly (no float round-trip), and the
    output dtype follows the real ``sample_width``. pydub uses 4-byte samples
    for 24-bit audio. The former second gating pass was dropped because gating
    is idempotent.

    Returns:
        The gated segment. On error or an unsupported sample width, the
        (stereo-converted) input is returned.
    """
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        dtype = {1: np.int8, 2: np.int16, 4: np.int32}.get(seg.sample_width)
        if dtype is None:
            logger.warning(f"apply_noise_gate: unsupported sample width {seg.sample_width}; left unchanged")
            return seg
        if seg.channels != 2:
            return seg
        ints = np.asarray(seg.get_array_of_samples()).reshape(-1, 2)
        # Convert before abs(): abs() of the most negative integer overflows.
        level_db = 20 * np.log10(np.abs(ints.astype(np.float64)) + 1e-10)
        out = np.where(level_db > threshold_db, ints, 0).astype(dtype).reshape(-1)
        if len(out) % 2 != 0:
            out = out[:-1]
        return AudioSegment(out.tobytes(), frame_rate=sample_rate, sample_width=seg.sample_width, channels=2)
    except Exception as e:
        logger.error(f"apply_noise_gate failed: {e}")
        return seg
|
|
|
|
|
def apply_eq(seg: AudioSegment, sample_rate=48000) -> AudioSegment:
    """Band-limit *seg* to roughly 20 Hz – 8 kHz and attenuate it by 16 dB in three steps.

    The segment is made stereo at ``sample_rate`` first. On error the
    partially processed segment is returned and the failure is logged.
    """
    try:
        seg = ensure_stereo(seg, sample_rate, seg.sample_width)
        seg = seg.high_pass_filter(20)
        seg = seg.low_pass_filter(8000)
        # Applied as separate gain steps, exactly as before (each step re-quantises).
        for cut_db in (3, 3, 10):
            seg = seg - cut_db
        return seg
    except Exception as e:
        logger.error(f"apply_eq failed: {e}")
        return seg
|
|
|
|
|
def apply_fade(seg: AudioSegment, fade_in=500, fade_out=800) -> AudioSegment:
    """Apply a ``fade_in`` ms fade-in and a ``fade_out`` ms fade-out to a stereo copy of *seg*.

    The segment's own frame rate is kept. On error the (stereo-converted)
    segment is returned unfaded.
    """
    try:
        seg = ensure_stereo(seg, seg.frame_rate, seg.sample_width)
        faded = seg.fade_in(fade_in)
        faded = faded.fade_out(fade_out)
    except Exception as e:
        logger.error(f"apply_fade failed: {e}")
        return seg
    return faded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SafeFormatDict(dict):
    """Mapping for ``str.format_map`` in which unknown placeholders render as ``""``."""

    def __missing__(self, key):
        # dict.__getitem__ calls this for absent keys; yield an empty string
        # so undefined template fields simply vanish from the prompt.
        return str()
|
|
|
|
|
class StylesConfig:
    """Hot-reloading view of the style presets defined in ``prompts.ini``.

    Each INI section is one style. Values of the keys in ``_LIST_KEYS`` are
    split on commas into lists, and a random element is picked from them
    whenever a prompt or a set of UI defaults is built. The file is re-parsed
    whenever its modification time changes.
    """

    # Keys whose INI values are comma-separated lists.
    _LIST_KEYS = frozenset({
        "drum_beat", "synthesizer", "rhythmic_steps", "bass_style", "guitar_style",
        "variations", "mood", "genre", "key", "scale", "feel", "instrument",
        "lead", "pad", "arp", "drums", "bass", "guitar", "strings", "brass", "woodwinds",
        "structure",
    })

    def __init__(self, path: Path):
        self.path = path
        self.cfg = configparser.ConfigParser(interpolation=None)
        self.mtime = 0.0
        self.styles: Dict[str, Dict[str, Any]] = {}
        self._load()

    def _load(self):
        """(Re)parse the INI file into ``self.styles``.

        A fresh ConfigParser is used on every load, so sections deleted from
        the file disappear after a reload. Reusing one parser merged old and
        new sections. If parsing fails, the previously loaded styles are kept
        and the error is logged. The file's mtime is still recorded so the
        broken file is not re-parsed on every call.
        """
        if not self.path.exists():
            logger.error(f"prompts.ini not found: {self.path}")
            self.cfg = configparser.ConfigParser(interpolation=None)
            self.styles = {}
            self.mtime = 0.0
            return

        # Stat before reading: if the file changes mid-read, the next check reloads again.
        mtime = self.path.stat().st_mtime
        cfg = configparser.ConfigParser(interpolation=None)
        try:
            cfg.read(self.path, encoding="utf-8")
        except (configparser.Error, UnicodeDecodeError) as e:
            logger.error(f"prompts.ini parse failed, keeping previous styles: {e}")
            self.mtime = mtime
            return

        styles: Dict[str, Dict[str, Any]] = {}
        for sec in cfg.sections():
            d: Dict[str, Any] = dict(cfg.items(sec))
            for key in self._LIST_KEYS:
                if isinstance(d.get(key), str):
                    d[key] = [part.strip() for part in d[key].split(",") if part.strip()]
            styles[sec] = d

        self.cfg = cfg
        self.styles = styles
        self.mtime = mtime
        logger.info(f"Loaded {len(self.styles)} styles from prompts.ini")

    def maybe_reload(self):
        """Reload the INI file if its modification time differs from the last load."""
        try:
            mt = self.path.stat().st_mtime
        except OSError:
            # File missing or unreadable: keep whatever is loaded.
            return
        if mt != self.mtime:
            self._load()

    def list_styles(self) -> List[str]:
        """Return the style (section) names, reloading the file first if it changed."""
        self.maybe_reload()
        return list(self.styles.keys())

    def _pick_from_list(self, vals: Any) -> str:
        """Random element of a list ("" for an empty list); any other value as ``str`` ("" for falsy)."""
        if isinstance(vals, list):
            return random.choice(vals) if vals else ""
        return str(vals or "")

    def build_prompt(
        self,
        style: str,
        bpm: int,
        chunk_num: int = 1,
        drum_beat: str = "none",
        synthesizer: str = "none",
        rhythmic_steps: str = "none",
        bass_style: str = "none",
        guitar_style: str = "none"
    ) -> str:
        """Render the prompt text for *style*; return "" if the style is unknown.

        - ``bpm == 120`` means "pick one": a random value in the style's
          ``bpm_min``..``bpm_max`` (defaults 100..140) is used instead.
        - An instrument argument other than "none"/"" wins; otherwise one is
          picked at random from the style's list for that key.
        - Chunk 1 picks its variation from the first half of ``variations``;
          later chunks pick from the whole list.
        - The style's ``prompt_template`` (or a built-in default) is filled via
          SafeFormatDict, so unknown placeholders become "", and runs of
          whitespace are collapsed.
        """
        self.maybe_reload()
        if style not in self.styles:
            return ""
        s = self.styles[style]

        bpm_min = int(s.get("bpm_min", "100"))
        bpm_max = int(s.get("bpm_max", "140"))
        final_bpm = bpm if bpm != 120 else random.randint(bpm_min, bpm_max)

        def choose(field_name: str, incoming: str) -> str:
            # Explicit caller choice overrides the style's random pick.
            if incoming and incoming != "none":
                return incoming
            return self._pick_from_list(s.get(field_name, [])) or ""

        d = choose("drum_beat", drum_beat)
        syn = choose("synthesizer", synthesizer)
        r = choose("rhythmic_steps", rhythmic_steps)
        b = choose("bass_style", bass_style)
        g = choose("guitar_style", guitar_style)

        var_list = s.get("variations", [])
        variation = ""
        if isinstance(var_list, list) and var_list:
            if chunk_num == 1:
                variation = random.choice(var_list[: max(1, len(var_list) // 2)])
            else:
                variation = random.choice(var_list)

        # Every style key is available to the template; list values are resolved to one pick.
        fields: Dict[str, Any] = {}
        for k, v in s.items():
            fields[k] = self._pick_from_list(v) if isinstance(v, list) else v

        if "structure" in s:
            fields["section"] = self._pick_from_list(s["structure"])

        fields.update({
            "bpm": final_bpm,
            "chunk": chunk_num,
            "drum": f" {d}" if d else "",
            "synth": f" {syn}" if syn else "",
            "rhythm": f" {r}" if r else "",
            "bass": f" {b}" if b else "",
            "guitar": f" {g}" if g else "",
            "variation": variation
        })

        tpl = s.get(
            "prompt_template",
            "Instrumental track at {bpm} BPM {variation}. {mood} {section} {drum}{bass}{guitar}{synth}{rhythm}"
        )

        prompt = tpl.format_map(SafeFormatDict(fields))
        prompt = re.sub(r"\s{2,}", " ", prompt).strip()
        return prompt

    def style_defaults_for_ui(self, style: str) -> Dict[str, Any]:
        """Return a random BPM plus one random pick per instrument list for *style*.

        Empty picks (or an unknown style) yield "none", matching the UI
        dropdowns' "no selection" value.
        """
        self.maybe_reload()
        s = self.styles.get(style, {})
        bpm_min = int(s.get("bpm_min", "100"))
        bpm_max = int(s.get("bpm_max", "140"))
        # `or "none"` already maps empty picks to "none"; no second pass needed.
        return {
            "bpm": random.randint(bpm_min, bpm_max),
            "drum_beat": self._pick_from_list(s.get("drum_beat", [])) or "none",
            "synthesizer": self._pick_from_list(s.get("synthesizer", [])) or "none",
            "rhythmic_steps": self._pick_from_list(s.get("rhythmic_steps", [])) or "none",
            "bass_style": self._pick_from_list(s.get("bass_style", [])) or "none",
            "guitar_style": self._pick_from_list(s.get("guitar_style", [])) or "none",
        }
|
|
|
|
|
# Process-wide style registry; parses prompts.ini at import and reloads it
# whenever the file's mtime changes.
STYLES = StylesConfig(PROMPTS_INI)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# audiocraft is a hard dependency. Only a genuine ImportError gets the install
# hint; any other import-time failure propagates with its own traceback instead
# of being mislabelled as "not installed".
try:
    from audiocraft.models import MusicGen
except ImportError:
    logger.error("audiocraft is required. pip install audiocraft")
    raise
|
|
|
|
|
def load_model():
    """Load the local MusicGen-large checkpoint onto DEVICE; exit if it is missing.

    Logs a warning when nvidia-smi reports under 5000 MiB free, clears GPU
    caches, loads ``BASE_DIR/models/musicgen-large`` and sets a default
    generation length of 30 s without two-step CFG.

    NOTE(review): ``autocast`` only changes the dtype of ops run inside it; it
    presumably does not convert the loaded weights to fp16 — confirm whether
    fp16 weights are intended.
    """
    free_mb = check_vram()
    low_vram = free_mb is not None and free_mb < 5000
    if low_vram:
        logger.warning("Low free VRAM; consider closing other apps.")
    clean_memory()

    model_dir = str(BASE_DIR / "models" / "musicgen-large")
    if not os.path.exists(model_dir):
        logger.error(f"Model path missing: {model_dir}")
        sys.exit(1)

    logger.info("Loading MusicGen (large)...")
    with autocast(dtype=torch.float16):
        mg = MusicGen.get_pretrained(model_dir, device=DEVICE)
    mg.set_generation_params(duration=30, two_step_cfg=False)
    logger.info("MusicGen loaded.")
    return mg
|
|
|
|
|
# Single shared MusicGen instance, loaded at import time (exits the process if
# the local checkpoint is missing). generate_music reconfigures it per request.
musicgen_model = load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _export_torch_to_segment(audio_tensor: torch.Tensor, sample_rate: int, bit_depth_int: int) -> Optional[AudioSegment]:
    """Convert an audio tensor to a pydub AudioSegment via a temporary WAV file.

    Args:
        audio_tensor: Audio in the layout torchaudio.save expects (channels-first
            by default — assumed (channels, samples); confirm at call sites).
        sample_rate: Sample rate written to the WAV header.
        bit_depth_int: PCM bits per sample for the WAV (e.g. 16 or 24).

    Returns:
        The decoded segment, or None on failure (error and traceback logged).

    The temp file lives in the system temp directory under a unique name, so
    concurrent calls cannot collide, and it is always removed. The previous
    unused mmap, which leaked when decoding failed, is gone.
    """
    import tempfile

    tmp: Optional[str] = None
    try:
        fd, tmp = tempfile.mkstemp(prefix="ghostai_", suffix=".wav")
        os.close(fd)
        torchaudio.save(tmp, audio_tensor, sample_rate, bits_per_sample=bit_depth_int)
        return AudioSegment.from_wav(tmp)
    except Exception as e:
        logger.error(f"_export_torch_to_segment failed: {e}")
        logger.error(traceback.format_exc())
        return None
    finally:
        if tmp is not None:
            try:
                os.remove(tmp)
            except OSError:
                pass
|
|
|
|
|
def _crossfade(seg_a: AudioSegment, seg_b: AudioSegment, overlap_ms: int, sr: int, bit_depth_int: int) -> AudioSegment:
    """Join *seg_a* and *seg_b* with a crossfade over ``overlap_ms`` milliseconds.

    The last ``overlap_ms`` of seg_a and the first ``overlap_ms`` of seg_b are
    mixed with complementary raised-cosine gains: fade-in ``sin²`` rising 0→1
    and fade-out ``1 - sin²``. The gains sum to 1 at every sample. The previous
    version used a full Hann bell (0→1→0) for both curves, which dipped the
    join to silence at both edges of the overlap. The mixed region is saved as
    float at ``bit_depth_int`` bits, so torchaudio does the PCM scaling. A manual
    ``2**23`` scale into int32 was wrong for 24-bit output.

    The result is shorter than ``len(seg_a) + len(seg_b)`` by ``overlap_ms``.
    Plain concatenation is returned when the overlap does not fit or any step
    fails (the error is logged). All temp files are created in a private temp
    directory that is always removed.
    """
    import shutil
    import tempfile

    try:
        seg_a = ensure_stereo(seg_a, sr, seg_a.sample_width)
        seg_b = ensure_stereo(seg_b, sr, seg_b.sample_width)
        if overlap_ms <= 0 or len(seg_a) < overlap_ms or len(seg_b) < overlap_ms:
            return seg_a + seg_b

        tmp_dir = tempfile.mkdtemp(prefix="ghostai_xfade_")
        prev_wav = os.path.join(tmp_dir, "prev.wav")
        curr_wav = os.path.join(tmp_dir, "curr.wav")
        blend_wav = os.path.join(tmp_dir, "blend.wav")
        try:
            seg_a[-overlap_ms:].export(prev_wav, format="wav")
            seg_b[:overlap_ms].export(curr_wav, format="wav")
            a_audio, sra = torchaudio.load(prev_wav)
            b_audio, srb = torchaudio.load(curr_wav)
            if sra != sr:
                a_audio = torchaudio.functional.resample(a_audio, sra, sr, lowpass_filter_width=64)
            if srb != sr:
                b_audio = torchaudio.functional.resample(b_audio, srb, sr, lowpass_filter_width=64)

            n = min(a_audio.shape[1], b_audio.shape[1])
            n = n - (n % 2)
            if n <= 0:
                return seg_a + seg_b
            a = a_audio[:, :n]
            b = b_audio[:, :n]

            t = torch.linspace(0.0, 1.0, n)
            fade_in = torch.sin(0.5 * math.pi * t) ** 2
            fade_out = 1.0 - fade_in
            blended = (a * fade_out + b * fade_in).to(torch.float32).clamp(-1.0, 1.0)

            torchaudio.save(blend_wav, blended, sr, bits_per_sample=bit_depth_int)
            blend_seg = AudioSegment.from_wav(blend_wav)
            blend_seg = ensure_stereo(blend_seg, sr, blend_seg.sample_width)
            return seg_a[:-overlap_ms] + blend_seg + seg_b[overlap_ms:]
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
    except Exception as e:
        logger.error(f"_crossfade failed: {e}")
        return seg_a + seg_b
|
|
|
|
|
def _slugify_style(style_key: Optional[str]) -> str: |
|
|
if not style_key: |
|
|
return "ghostai" |
|
|
slug = style_key.lower().strip() |
|
|
slug = re.sub(r"\s+", "_", slug) |
|
|
slug = re.sub(r"[^a-z0-9_\-]+", "-", slug) |
|
|
slug = re.sub(r"-{2,}", "-", slug).strip("-") |
|
|
return slug or "ghostai" |
|
|
|
|
|
def generate_music(
    instrumental_prompt: str,
    cfg_scale: float,
    top_k: int,
    top_p: float,
    temperature: float,
    total_duration: int,
    bpm: int,
    drum_beat: str,
    synthesizer: str,
    rhythmic_steps: str,
    bass_style: str,
    guitar_style: str,
    target_volume: float,
    preset: str,
    max_steps: str,
    vram_status_text: str,
    bitrate: str,
    output_sample_rate: str,
    bit_depth: str,
    style_key: Optional[str] = None
) -> Tuple[Optional[str], str, str]:
    """Render an instrumental track with MusicGen and export it as an MP3 in MP3_DIR.

    The track is generated in 30 s chunks. Every chunk after the first is a
    continuation conditioned on the last 0.2 s of the previous chunk. Each chunk
    is resampled to 48 kHz and post-processed (gate, stereo balance, RMS
    normalisation, EQ). The chunks are then crossfaded and mastered before MP3
    export.

    Args:
        instrumental_prompt: Text description for MusicGen; must be non-blank.
        cfg_scale, top_k, top_p, temperature: Sampling parameters forwarded to
            ``set_generation_params`` (cfg_scale becomes ``cfg_coef``).
        total_duration: Requested length in seconds, clamped to 30..120.
        bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style,
        preset, max_steps: Accepted for UI/API compatibility but not used here;
            the prompt text is expected to describe them already.
        target_volume: RMS target in dB handed to rms_normalize.
        vram_status_text: Previous status text, echoed back on early exits.
        bitrate: MP3 bitrate (e.g. "192k"); falls back to 128k if export fails.
        output_sample_rate: Output sample rate as a string (e.g. "48000").
        bit_depth: "24" selects 24-bit intermediates, anything else 16-bit.
        style_key: Optional style name used for the output filename.

    Returns:
        ``(mp3_path or None, status message, VRAM status text)``.
    """
    import tempfile

    if not (instrumental_prompt or "").strip():
        return None, "⚠️ Enter a prompt.", vram_status_text

    try:
        out_sr = int(output_sample_rate)
    except (TypeError, ValueError):
        return None, "❌ Invalid sample rate.", vram_status_text

    try:
        bd = int(bit_depth)
    except (TypeError, ValueError):
        return None, "❌ Invalid bit depth.", vram_status_text
    sample_width = 3 if bd == 24 else 2

    # Check the filesystem the MP3 is written to, not merely the working directory.
    if not check_disk_space(str(MP3_DIR)):
        return None, "⚠️ Low disk space (<1GB).", vram_status_text

    CHUNK_SEC = 30        # seconds requested from MusicGen per call
    PROCESS_SR = 48000    # internal post-processing sample rate
    OVERLAP_SEC = 0.20    # continuation-prompt length and crossfade length
    total_duration = max(30, min(int(total_duration), 120))
    num_chunks = math.ceil(total_duration / CHUNK_SEC)
    tail_samples = int(PROCESS_SR * OVERLAP_SEC)
    # Resample from the model's native rate rather than assuming 32 kHz.
    model_sr = int(getattr(musicgen_model, "sample_rate", 32000))

    # One fresh seed per request, applied to every RNG generation may touch.
    seed = random.randint(0, 2**31 - 1)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    logger.info(f"Seed: {seed}")

    musicgen_model.set_generation_params(
        duration=CHUNK_SEC,
        use_sampling=True,
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temperature),
        cfg_coef=float(cfg_scale),
        two_step_cfg=False,
    )

    vram_status_text = f"Start VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
    segments: List[AudioSegment] = []
    start_time = time.time()

    for idx in range(num_chunks):
        chunk_idx = idx + 1
        if idx < num_chunks - 1:
            dur = CHUNK_SEC
        else:
            dur = total_duration - CHUNK_SEC * (num_chunks - 1) or CHUNK_SEC
        logger.info(f"Generating chunk {chunk_idx}/{num_chunks} ({dur}s)")

        # --- generation -------------------------------------------------------
        try:
            with torch.no_grad():
                with autocast(dtype=torch.float16):
                    clean_memory()
                    if idx == 0:
                        audio = musicgen_model.generate([instrumental_prompt], progress=True)[0].cpu()
                    else:
                        # Condition this chunk on the tail of the previous one.
                        prev_seg = apply_noise_gate(segments[-1], threshold_db=-80, sample_rate=PROCESS_SR)
                        prev_seg = balance_stereo(prev_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
                        fd, tmp_prev = tempfile.mkstemp(prefix="ghostai_prev_", suffix=".wav")
                        os.close(fd)
                        try:
                            prev_seg.export(tmp_prev, format="wav")
                            prev_audio, prev_sr = torchaudio.load(tmp_prev)
                            if prev_sr != PROCESS_SR:
                                prev_audio = torchaudio.functional.resample(
                                    prev_audio, prev_sr, PROCESS_SR, lowpass_filter_width=64
                                )
                            if prev_audio.shape[0] != 2:
                                prev_audio = prev_audio.repeat(2, 1)
                            # Move only the tail to the GPU, not the whole previous chunk.
                            tail = prev_audio[:, -tail_samples:].to(DEVICE)
                            audio = musicgen_model.generate_continuation(
                                prompt=tail,
                                prompt_sample_rate=PROCESS_SR,
                                descriptions=[instrumental_prompt],
                                progress=True
                            )[0].cpu()
                            del prev_audio, tail
                        finally:
                            try:
                                os.remove(tmp_prev)
                            except OSError:
                                pass
            clean_memory()
        except Exception as e:
            logger.error(f"Chunk {chunk_idx} generation failed: {e}")
            logger.error(traceback.format_exc())
            return None, f"❌ Generate failed at chunk {chunk_idx}.", vram_status_text

        # --- per-chunk post-processing ---------------------------------------
        try:
            if audio.shape[0] != 2:
                audio = audio.repeat(2, 1)
            audio = audio.to(dtype=torch.float32)
            audio = torchaudio.functional.resample(audio, model_sr, PROCESS_SR, lowpass_filter_width=64)
            seg = _export_torch_to_segment(audio, PROCESS_SR, bd)
            if seg is None:
                return None, f"❌ Convert failed chunk {chunk_idx}.", vram_status_text
            seg = ensure_stereo(seg, PROCESS_SR, sample_width)
            # Headroom before stereo balancing boosts the quieter channel;
            # rms_normalize re-levels the chunk afterwards.
            seg = seg - 15
            seg = apply_noise_gate(seg, threshold_db=-80, sample_rate=PROCESS_SR)
            seg = balance_stereo(seg, noise_threshold=-40, sample_rate=PROCESS_SR)
            seg = rms_normalize(seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
            seg = apply_eq(seg, sample_rate=PROCESS_SR)
            seg = seg[:dur * 1000]
            segments.append(seg)
            del audio
            clean_memory()
            vram_status_text = f"VRAM after chunk {chunk_idx}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
        except Exception as e:
            logger.error(f"Post-process failed chunk {chunk_idx}: {e}")
            logger.error(traceback.format_exc())
            return None, f"❌ Post-process failed chunk {chunk_idx}.", vram_status_text

    if not segments:
        return None, "❌ No audio generated.", vram_status_text

    # --- assemble and master ------------------------------------------------------
    logger.info("Combining chunks...")
    final_seg = segments[0]
    overlap_ms = int(OVERLAP_SEC * 1000)
    for nxt in segments[1:]:
        final_seg = _crossfade(final_seg, nxt, overlap_ms, PROCESS_SR, bd)

    final_seg = final_seg[:total_duration * 1000]
    final_seg = apply_noise_gate(final_seg, threshold_db=-80, sample_rate=PROCESS_SR)
    final_seg = balance_stereo(final_seg, noise_threshold=-40, sample_rate=PROCESS_SR)
    final_seg = rms_normalize(final_seg, target_rms_db=target_volume, peak_limit_db=-3.0, sample_rate=PROCESS_SR)
    final_seg = apply_eq(final_seg, sample_rate=PROCESS_SR)
    final_seg = apply_fade(final_seg, 500, 800)
    final_seg = final_seg - 10
    final_seg = final_seg.set_frame_rate(out_sr)

    # --- export -----------------------------------------------------------------
    style_slug = _slugify_style(style_key)
    fname = f"{style_slug}_{int(time.time())}.mp3"
    mp3_path = str(MP3_DIR / fname)
    try:
        clean_memory()
        final_seg.export(
            mp3_path,
            format="mp3",
            bitrate=bitrate,
            tags={"title": f"GhostAI Instrumental — {style_slug}", "artist": "GhostAI"}
        )
    except Exception as e:
        logger.error(f"MP3 export failed: {e}")
        # Retry once with a conservative bitrate under a distinct filename.
        fb = str(MP3_DIR / f"{style_slug}_fb_{int(time.time())}.mp3")
        try:
            final_seg.export(fb, format="mp3", bitrate="128k")
            mp3_path = fb
        except Exception as ee:
            return None, f"❌ Export failed: {ee}", vram_status_text

    elapsed = time.time() - start_time
    vram_status_text = f"Final VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
    logger.info(f"Done in {elapsed:.2f}s -> {mp3_path}")
    return mp3_path, "✅ Generated.", vram_status_text
|
|
|
|
|
def generate_music_wrapper(*args):
    """Forward positional args to generate_music, always freeing GPU memory afterwards.

    Exceptions from generate_music still propagate, after the cleanup.
    """
    result = None
    try:
        result = generate_music(*args)
    finally:
        clean_memory()
    return result
|
|
|
|
|
def clear_inputs():
    """Return the default value of every UI input, in the UI's input order."""
    field_order = (
        "instrumental_prompt", "cfg_scale", "top_k", "top_p", "temperature",
        "total_duration", "bpm", "drum_beat", "synthesizer", "rhythmic_steps",
        "bass_style", "guitar_style", "target_volume", "preset", "max_steps",
        "bitrate", "output_sample_rate", "bit_depth", "style",
    )
    defaults = DEFAULT_SETTINGS.copy()
    return tuple(defaults[name] for name in field_order)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Single-job guard for the render endpoint.
BUSY_LOCK = threading.Lock()   # protects BUSY_FLAG and CURRENT_JOB (not re-entrant)
BUSY_FLAG = False              # True while a render job is running
# Marker file holding the running job id; presumably for external monitoring — confirm.
BUSY_FILE = "/tmp/musicgen_busy.lock"
# "id": job id string or None; "start": time.time() when the job started, or None.
CURRENT_JOB: Dict[str, Any] = {"id": None, "start": None}
|
|
|
|
|
def set_busy(val: bool, job_id: Optional[str] = None):
    """Mark the server busy (``val=True``) or idle (``val=False``) under BUSY_LOCK.

    Busy: records the job id (default ``job_<unix seconds>``) and start time,
    and writes the id to BUSY_FILE. Idle: clears both and removes BUSY_FILE.
    The marker file is best-effort. Failures are logged at debug level instead
    of being silently swallowed, and never raised.
    """
    global BUSY_FLAG, CURRENT_JOB
    with BUSY_LOCK:
        BUSY_FLAG = val
        if val:
            CURRENT_JOB["id"] = job_id or f"job_{int(time.time())}"
            CURRENT_JOB["start"] = time.time()
            try:
                Path(BUSY_FILE).write_text(CURRENT_JOB["id"])
            except Exception as e:
                logger.debug(f"Could not write busy marker {BUSY_FILE}: {e}")
        else:
            CURRENT_JOB["id"] = None
            CURRENT_JOB["start"] = None
            try:
                os.remove(BUSY_FILE)
            except FileNotFoundError:
                pass
            except Exception as e:
                logger.debug(f"Could not remove busy marker {BUSY_FILE}: {e}")
|
|
|
|
|
def is_busy() -> bool:
    """Return whether a render job is in progress (read under BUSY_LOCK)."""
    with BUSY_LOCK:
        flag = BUSY_FLAG
    return flag
|
|
|
|
|
def job_elapsed() -> float:
    """Seconds since the current job started, or 0.0 when idle (read under BUSY_LOCK)."""
    with BUSY_LOCK:
        started = CURRENT_JOB["start"]
        return 0.0 if started is None else time.time() - started
|
|
|
|
|
class RenderRequest(BaseModel):
    """Body of POST /render.

    Only ``instrumental_prompt`` is required. Every other field left as None
    falls back to the server's current settings.
    """

    instrumental_prompt: str
    cfg_scale: Optional[float] = None       # classifier-free guidance coefficient
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    total_duration: Optional[int] = None    # seconds; clamped to 30..120 by generate_music
    bpm: Optional[int] = None
    drum_beat: Optional[str] = None
    synthesizer: Optional[str] = None
    rhythmic_steps: Optional[str] = None
    bass_style: Optional[str] = None
    guitar_style: Optional[str] = None
    target_volume: Optional[float] = None   # RMS target in dB
    preset: Optional[str] = None
    max_steps: Optional[int] = None
    bitrate: Optional[str] = None           # e.g. "192k"
    output_sample_rate: Optional[str] = None  # string, e.g. "48000"
    bit_depth: Optional[str] = None         # "16" or "24"
    style: Optional[str] = None             # used for the output filename slug
|
|
|
|
|
# HTTP API: health, status, styles, prompt building, settings and rendering.
fastapp = FastAPI(title=f"GhostAI Music Server {RELEASE}", version=RELEASE)
# NOTE(review): wildcard origins together with allow_credentials=True allow any
# site to send credentialed cross-origin requests (Starlette reportedly echoes
# the request Origin in that case). No cookie/auth use is visible in this API,
# so allow_credentials=False is probably sufficient — confirm before changing.
fastapp.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
)
|
|
|
|
|
@fastapp.get("/health")
def health():
    """Liveness probe: ok flag, current Unix time in whole seconds, and the release tag."""
    now = int(time.time())
    return dict(ok=True, ts=now, release=RELEASE)
|
|
|
|
|
@fastapp.get("/status")
def status():
    """Report the busy flag, current job id, its start time and elapsed seconds.

    All fields come from a single snapshot taken under BUSY_LOCK. Previously
    ``job_id``/``since`` were read without the lock, so a job starting or
    finishing mid-request could yield e.g. ``busy=False`` with a job id.
    """
    with BUSY_LOCK:
        busy = BUSY_FLAG
        job_id = CURRENT_JOB["id"]
        since = CURRENT_JOB["start"]
    elapsed = 0.0 if since is None else time.time() - since
    return {"busy": busy, "job_id": job_id, "since": since, "elapsed": elapsed}
|
|
|
|
|
@fastapp.get("/styles")
def styles():
    """List the style names currently defined in prompts.ini."""
    names = STYLES.list_styles()
    return {"styles": names}
|
|
|
|
|
@fastapp.get("/prompt/{style}")
def prompt(style: str, bpm: int = 120, chunk: int = 1,
           drum_beat: str = "none", synthesizer: str = "none", rhythmic_steps: str = "none",
           bass_style: str = "none", guitar_style: str = "none"):
    """Build a prompt for *style*; 404 if the style is unknown (or renders empty)."""
    text = STYLES.build_prompt(
        style,
        bpm,
        chunk_num=chunk,
        drum_beat=drum_beat,
        synthesizer=synthesizer,
        rhythmic_steps=rhythmic_steps,
        bass_style=bass_style,
        guitar_style=guitar_style,
    )
    if text:
        return {"style": style, "prompt": text}
    raise HTTPException(status_code=404, detail="Style not found")
|
|
|
|
|
# Register one GET shortcut per style that declares `api_name` in prompts.ini.
# Routes are created once at import; styles added by a later prompts.ini
# reload get no shortcut until restart. A path that collides with an earlier
# route is shadowed by that route.
for sec, cfg in list(STYLES.styles.items()):
    api_name = (cfg.get("api_name") or "").strip()
    if not api_name:
        continue
    # Starlette rejects route paths that do not start with "/", which aborted
    # startup; accept "foo" in prompts.ini as shorthand for "/foo".
    route_path = api_name if api_name.startswith("/") else "/" + api_name

    def make_route(sname, route_path_):
        # Factory so each handler's closure captures its own style name.
        @fastapp.get(route_path_, name=f"style_prompt_{sname}")
        def _(bpm: int = 120, chunk: int = 1,
              drum_beat: str = "none", synthesizer: str = "none", rhythmic_steps: str = "none",
              bass_style: str = "none", guitar_style: str = "none"):
            txt = STYLES.build_prompt(sname, bpm, chunk, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style)
            if not txt:
                raise HTTPException(status_code=404, detail="Style not found")
            return {"style": sname, "prompt": txt}

    make_route(sec, route_path)
|
|
|
|
|
@fastapp.get("/config")
def get_config():
    """Return the live settings dict (as "defaults") together with the release tag."""
    payload = {"defaults": CURRENT_SETTINGS}
    payload["release"] = RELEASE
    return payload
|
|
|
|
|
@fastapp.post("/settings")
def set_settings(payload: Dict[str, Any]):
    """Merge *payload* into the current settings, persist them, and update the live dict.

    Responds 400 with the error text if anything raises. Note that
    save_settings logs write failures rather than raising them.
    """
    try:
        merged = {**CURRENT_SETTINGS, **(payload or {})}
        save_settings(merged)
        CURRENT_SETTINGS.update(merged)
        return {"ok": True, "saved": merged}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ascii_header(s: str) -> str: |
|
|
return re.sub(r'[^\x20-\x7E]', '', str(s or '')) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Serializes /render calls: FastAPI runs plain ``def`` endpoints in a thread pool,
# so two concurrent requests could otherwise both pass the is_busy() check before
# either one marks the server busy.
_RENDER_LOCK = threading.Lock()


def _render_job(req, job_id):
    """Run one /render job and return a FileResponse for the produced MP3.

    Starts from a copy of ``CURRENT_SETTINGS``; every request field that is not
    ``None`` overrides the matching key. Raises HTTPException(500) when
    generation raises or produces no file.
    """
    s = CURRENT_SETTINGS.copy()
    # Pydantic v2 renamed .dict() to .model_dump(); use whichever the model provides.
    dump = getattr(req, "model_dump", None) or req.dict
    for k, v in dump().items():
        if v is not None:
            s[k] = v

    try:
        mp3_path, msg, vram = generate_music(
            s.get("instrumental_prompt", req.instrumental_prompt),
            float(s.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])),
            int(s.get("top_k", DEFAULT_SETTINGS["top_k"])),
            float(s.get("top_p", DEFAULT_SETTINGS["top_p"])),
            float(s.get("temperature", DEFAULT_SETTINGS["temperature"])),
            int(s.get("total_duration", DEFAULT_SETTINGS["total_duration"])),
            int(s.get("bpm", DEFAULT_SETTINGS["bpm"])),
            str(s.get("drum_beat", DEFAULT_SETTINGS["drum_beat"])),
            str(s.get("synthesizer", DEFAULT_SETTINGS["synthesizer"])),
            str(s.get("rhythmic_steps", DEFAULT_SETTINGS["rhythmic_steps"])),
            str(s.get("bass_style", DEFAULT_SETTINGS["bass_style"])),
            str(s.get("guitar_style", DEFAULT_SETTINGS["guitar_style"])),
            float(s.get("target_volume", DEFAULT_SETTINGS["target_volume"])),
            str(s.get("preset", DEFAULT_SETTINGS["preset"])),
            str(s.get("max_steps", DEFAULT_SETTINGS["max_steps"])),
            "",  # VRAM status text (empty for API calls)
            str(s.get("bitrate", DEFAULT_SETTINGS["bitrate"])),
            str(s.get("output_sample_rate", DEFAULT_SETTINGS["output_sample_rate"])),
            str(s.get("bit_depth", DEFAULT_SETTINGS["bit_depth"])),
            str(s.get("style", "custom")),
        )
    except Exception as e:
        # Record the failure in the application log and report a sanitized 500.
        logger.error(f"{job_id} failed: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=_ascii_header(f"Render failed: {e}")) from e

    if not mp3_path or not os.path.exists(mp3_path):
        raise HTTPException(status_code=500, detail=_ascii_header(msg or "No file produced"))

    filename = os.path.basename(mp3_path)
    # Header values must be plain ASCII, so every dynamic value is sanitized.
    headers = {
        "X-Job-ID": _ascii_header(job_id),
        "X-Status": _ascii_header(msg),
        "X-VRAM": _ascii_header(vram),
        "X-Release": _ascii_header(RELEASE),
    }
    return FileResponse(
        path=mp3_path,
        media_type="audio/mpeg",
        filename=_ascii_header(filename),
        headers=headers,
    )


@fastapp.post("/render")
def render(req: RenderRequest):
    """Render a track synchronously and return it as an MP3 download.

    Responds 409 ("Server busy") when another render is in progress, and 500
    when generation fails or produces no file. On success the response carries
    ``X-Job-ID``, ``X-Status``, ``X-VRAM`` and ``X-Release`` headers.
    """
    if not _RENDER_LOCK.acquire(blocking=False):
        raise HTTPException(status_code=409, detail="Server busy")
    try:
        # NOTE(review): this still races with jobs started outside /render (e.g. the
        # Gradio UI) unless set_busy() itself performs an atomic test-and-set. Confirm.
        if is_busy():
            raise HTTPException(status_code=409, detail="Server busy")
        job_id = f"render_{int(time.time())}"
        set_busy(True, job_id)
        try:
            return _render_job(req, job_id)
        finally:
            set_busy(False, None)
    finally:
        _RENDER_LOCK.release()
|
|
|
|
|
def _start_fastapi():
    """Serve ``fastapp`` with uvicorn on 0.0.0.0:8555; blocks until the server exits.

    Used as the target of the daemon ``api_thread`` below. Any failure is written
    to the application log. A thread's default excepthook only prints exceptions
    to stderr and silently ignores ``SystemExit``.
    """
    try:
        uvicorn.run(fastapp, host="0.0.0.0", port=8555, log_level="info")
    except (Exception, SystemExit) as e:
        # NOTE(review): uvicorn appears to report startup failures (e.g. port 8555
        # already in use) via sys.exit(). Inside this worker thread that would end
        # the API without a trace in LOG_FILE.
        logger.error(f"FastAPI server on port 8555 stopped: {e!r}")
        logger.error(traceback.format_exc())


# Daemon thread: it does not keep the interpreter alive once the Gradio UI exits.
api_thread = threading.Thread(target=_start_fastapi, daemon=True)
api_thread.start()
logger.info(f"FastAPI server started on http://0.0.0.0:8555 [{RELEASE}]")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_css() -> str:
    """Return the stylesheet handed to ``gr.Blocks``.

    Uses the contents of ``CSS_FILE`` when that file exists. When it does not
    exist, or reading it fails (the error is logged), returns a minimal built-in
    dark theme, so the UI is never left without styling.
    """
    fallback_css = """
:root { color-scheme: dark; }
body, .gradio-container { background: #0E1014 !important; color: #FFFFFF !important; }
* { color: #FFFFFF !important; }
input, textarea, select {
  background: #151922 !important; color: #FFFFFF !important;
  border: 1px solid #2A3142 !important; border-radius: 10px !important;
}
.ga-header { display:flex; gap:12px; align-items:center; }
.ga-header .logo { font-size: 28px; }
"""
    try:
        if CSS_FILE.exists():
            return CSS_FILE.read_text(encoding="utf-8")
    except Exception as e:
        # Fall through to the built-in theme rather than serving an empty stylesheet.
        logger.error(f"Failed to read CSS: {e}")
    return fallback_css
|
|
|
|
|
def read_examples() -> str:
    """Return the Markdown shown on the "Info & Examples" tab.

    Reads ``EXAMPLES_MD`` as UTF-8. On any failure, returns a short placeholder
    that asks for the file to be provided.
    """
    placeholder = "# GhostAI Examples\n\n_Provide examples.md next to app.py_"
    try:
        text = EXAMPLES_MD.read_text(encoding="utf-8")
    except Exception:
        text = placeholder
    return text
|
|
|
|
|
# Alias (not a copy) of the live settings dict. The widgets below read their
# initial values from it when the layout is built.
loaded = CURRENT_SETTINGS

# Settings keys in the order the Load/Reset callbacks return them. This must
# match ``settings_components`` inside the Blocks layout below.
_UI_SETTING_KEYS = (
    "instrumental_prompt", "cfg_scale", "top_k", "top_p", "temperature",
    "total_duration", "bpm", "drum_beat", "synthesizer", "rhythmic_steps",
    "bass_style", "guitar_style", "target_volume", "preset", "max_steps",
    "bitrate", "output_sample_rate", "bit_depth", "style",
)


def _settings_to_ui_values(s):
    """Return ``s``'s values in ``_UI_SETTING_KEYS`` order.

    Keys missing from ``s`` fall back to ``DEFAULT_SETTINGS``, and ``style``
    finally falls back to "custom".
    """
    return [
        s.get(key, DEFAULT_SETTINGS.get(key, "custom" if key == "style" else None))
        for key in _UI_SETTING_KEYS
    ]


def row_of_buttons(entries):
    """Create one equal-height row of style buttons.

    ``entries`` is a sequence of ``(style_key, label)`` pairs. Returns
    ``[(style_key, button), ...]``. Must be called inside a Blocks context.
    """
    with gr.Row(equal_height=True):
        buttons = []
        for key, label in entries:
            btn = gr.Button(label, variant="secondary", scale=1, min_width=0)
            buttons.append((key, btn))
    return buttons


def set_prompt_and_settings_from_style(style_key, current_bpm, current_drum, current_synth, current_steps, current_bass, current_guitar):
    """Apply a style button: build its prompt and return the widget values it implies.

    Values come from ``STYLES.style_defaults_for_ui(style_key)``. Instrument
    fields the style does not define reset to "none", and a missing BPM keeps
    ``current_bpm`` (or 120). The ``current_*`` instrument arguments are only
    accepted so the click handler's inputs line up; they are not used.

    Returns (prompt, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style,
    guitar_style, selected_style).
    """
    defaults = STYLES.style_defaults_for_ui(style_key)
    new_bpm = int(defaults.get("bpm", current_bpm or 120))
    instruments = [
        str(defaults.get(field, "none"))
        for field in ("drum_beat", "synthesizer", "rhythmic_steps", "bass_style", "guitar_style")
    ]
    prompt_txt = STYLES.build_prompt(style_key, new_bpm, 1, *instruments)
    if not prompt_txt:
        prompt_txt = f"{style_key}: update prompts.ini"
    return (prompt_txt, new_bpm, *instruments, style_key)


def _save_action(
    instrumental_prompt_v, cfg_v, top_k_v, top_p_v, temp_v, dur_v, bpm_v,
    drum_v, synth_v, steps_v, bass_v, guitar_v, vol_v, preset_v, maxsteps_v, br_v, sr_v, bd_v, style_v
):
    """Persist the current widget values and apply them to CURRENT_SETTINGS; return a status line."""
    s = {
        "instrumental_prompt": instrumental_prompt_v,
        "cfg_scale": float(cfg_v),
        "top_k": int(top_k_v),
        "top_p": float(top_p_v),
        "temperature": float(temp_v),
        "total_duration": int(dur_v),
        "bpm": int(bpm_v),
        "drum_beat": str(drum_v),
        "synthesizer": str(synth_v),
        "rhythmic_steps": str(steps_v),
        "bass_style": str(bass_v),
        "guitar_style": str(guitar_v),
        "target_volume": float(vol_v),
        "preset": str(preset_v),
        "max_steps": int(maxsteps_v),
        "bitrate": str(br_v),
        "output_sample_rate": str(sr_v),
        "bit_depth": str(bd_v),
        "style": str(style_v or "custom"),
    }
    save_settings(s)
    CURRENT_SETTINGS.update(s)
    return "✅ Settings saved."


def _load_action():
    """Reload saved settings, apply them, and return widget values plus a status line."""
    s = load_settings()
    CURRENT_SETTINGS.update(s)
    return (*_settings_to_ui_values(s), "✅ Settings loaded.")


def _reset_action():
    """Save and apply DEFAULT_SETTINGS; return widget values plus a status line."""
    s = DEFAULT_SETTINGS.copy()
    save_settings(s)
    CURRENT_SETTINGS.update(s)
    return (*_settings_to_ui_values(s), "✅ Defaults restored.")


def _get_log(max_chars: int = 40000):
    """Return the last ``max_chars`` characters of LOG_FILE, or an error message.

    Only the tail of the file is read, so the cost does not grow with the log.
    NOTE(review): the RotatingFileHandler is created with backupCount=0. Per the
    logging docs that disables rollover, so the file is not capped at 5 MB.
    """
    try:
        with open(LOG_FILE, "rb") as fh:
            size = fh.seek(0, os.SEEK_END)
            # UTF-8 uses at most 4 bytes per character, so this window always
            # contains at least max_chars complete characters.
            fh.seek(max(0, size - 4 * max_chars - 4))
            raw = fh.read()
        # A cut-off leading character decodes to U+FFFD and is sliced away below.
        text = raw.decode("utf-8", errors="replace")
        # Same newline handling as text-mode reading (universal newlines).
        text = text.replace("\r\n", "\n").replace("\r", "\n")
        return text[-max_chars:]
    except Exception as e:
        return f"Log read error: {e}"


# NOTE(review): the emoji in the labels below were reconstructed from text that
# had been mis-decoded (UTF-8 bytes read as ISO-8859-7, e.g. "πΆ" for 🎶).
# Confirm the intended glyphs.
with gr.Blocks(css=read_css(), analytics_enabled=False, title=f"GhostAI Music Generator {RELEASE}") as demo:
    with gr.Tabs():
        with gr.Tab(f"🎛️ Generator — {RELEASE}"):
            # HTML kept at column 0 so Markdown does not render it as a code block.
            gr.Markdown("""
<div class="ga-header" role="banner" aria-label="GhostAI Music Generator">
<div class="logo">👻</div>
<h1>GhostAI Music Generator</h1>
<p>Unified 30s chunking · 60–120s ready · API & status</p>
</div>
""")

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Prompt")
                instrumental_prompt = gr.Textbox(
                    label="Instrumental Prompt",
                    placeholder="Type a prompt or click a style button below",
                    lines=4,
                    value=loaded.get("instrumental_prompt", ""),
                )

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Band / Style (grid 4 per row)")
                row1 = row_of_buttons([
                    ("metallica", "Metallica (Thrash) 🎸"),
                    ("nirvana", "Nirvana (Grunge) 🎤"),
                    ("pearl_jam", "Pearl Jam (Grunge) 🦪"),
                    ("soundgarden", "Soundgarden (Grunge/Alt Metal) 🌞"),
                ])
                row2 = row_of_buttons([
                    ("foo_fighters", "Foo Fighters (Alt Rock) 🤘"),
                    ("red_hot_chili_peppers", "Red Hot Chili Peppers (Funk Rock) 🌶️"),
                    ("smashing_pumpkins", "Smashing Pumpkins (Alt) 🎃"),
                    ("radiohead", "Radiohead (Experimental) 🧠"),
                ])
                row3 = row_of_buttons([
                    ("alternative_rock", "Alternative Rock (Pixies) 🎵"),
                    ("post_punk", "Post-Punk (Joy Division) 🖤"),
                    ("indie_rock", "Indie Rock (Arctic Monkeys) 🎤"),
                    ("funk_rock", "Funk Rock (RATM) 🎺"),
                ])
                row4 = row_of_buttons([
                    ("detroit_techno", "Detroit Techno 🎛️"),
                    ("deep_house", "Deep House 🏠"),
                    ("classical_star_wars", "Classical (Star Wars Suite) ✨"),
                    ("foo_pad", "—"),  # filler cell to complete the 4-wide grid; gets no handler
                ])

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Settings")
                with gr.Group():
                    with gr.Row():
                        cfg_scale = gr.Slider(1.0, 10.0, step=0.1, value=float(loaded.get("cfg_scale", DEFAULT_SETTINGS["cfg_scale"])), label="CFG Scale")
                        top_k = gr.Slider(10, 500, step=10, value=int(loaded.get("top_k", DEFAULT_SETTINGS["top_k"])), label="Top-K")
                        top_p = gr.Slider(0.0, 1.0, step=0.01, value=float(loaded.get("top_p", DEFAULT_SETTINGS["top_p"])), label="Top-P")
                        temperature = gr.Slider(0.1, 2.0, step=0.01, value=float(loaded.get("temperature", DEFAULT_SETTINGS["temperature"])), label="Temperature")
                    with gr.Row():
                        total_duration = gr.Dropdown(choices=[30, 60, 90, 120], value=int(loaded.get("total_duration", 60)), label="Song Length (seconds)")
                        bpm = gr.Slider(60, 180, step=1, value=int(loaded.get("bpm", 120)), label="Tempo (BPM)")
                        target_volume = gr.Slider(-30.0, -20.0, step=0.5, value=float(loaded.get("target_volume", -23.0)), label="Target Loudness (dBFS RMS)")
                        preset = gr.Dropdown(choices=["default", "rock", "techno", "grunge", "indie", "funk_rock"], value=str(loaded.get("preset", "default")), label="Preset")
                    with gr.Row():
                        drum_beat = gr.Dropdown(choices=["none", "standard rock", "funk groove", "techno kick", "jazz swing", "four-on-the-floor", "steady kick", "orchestral percussion", "precise drums", "heavy drums"], value=str(loaded.get("drum_beat", "none")), label="Drum Beat")
                        synthesizer = gr.Dropdown(choices=["none", "analog synth", "digital pad", "arpeggiated synth", "lush synths", "atmospheric synths", "pulsing synths", "analog pad", "warm synths"], value=str(loaded.get("synthesizer", "none")), label="Synthesizer")
                        rhythmic_steps = gr.Dropdown(choices=["none", "syncopated steps", "steady steps", "complex steps", "martial march", "staccato ostinato", "triplet swells"], value=str(loaded.get("rhythmic_steps", "none")), label="Rhythmic Steps")
                    with gr.Row():
                        bass_style = gr.Dropdown(choices=["none", "slap bass", "deep bass", "melodic bass", "groovy bass", "hypnotic bass", "driving bass", "low brass", "cellos", "double basses", "subby bass"], value=str(loaded.get("bass_style", "none")), label="Bass Style")
                        guitar_style = gr.Dropdown(choices=["none", "distorted", "clean", "jangle", "downpicked", "thrash riffing", "dreamy", "experimental", "funky"], value=str(loaded.get("guitar_style", "none")), label="Guitar Style")
                        max_steps = gr.Dropdown(choices=[1000, 1200, 1300, 1500], value=int(loaded.get("max_steps", 1500)), label="Max Steps (hint)")

                # Output-format choices live in State and are set by the buttons below.
                bitrate_state = gr.State(value=str(loaded.get("bitrate", "192k")))
                sample_rate_state = gr.State(value=str(loaded.get("output_sample_rate", "48000")))
                bit_depth_state = gr.State(value=str(loaded.get("bit_depth", "16")))
                selected_style = gr.State(value=str(loaded.get("style", "custom")))

                with gr.Row():
                    bitrate_128_btn = gr.Button("Bitrate 128k", variant="secondary")
                    bitrate_192_btn = gr.Button("Bitrate 192k", variant="secondary")
                    bitrate_320_btn = gr.Button("Bitrate 320k", variant="secondary")
                    sample_rate_22050_btn = gr.Button("SR 22.05k", variant="secondary")
                    sample_rate_44100_btn = gr.Button("SR 44.1k", variant="secondary")
                    sample_rate_48000_btn = gr.Button("SR 48k", variant="secondary")
                    bit_depth_16_btn = gr.Button("16-bit", variant="secondary")
                    bit_depth_24_btn = gr.Button("24-bit", variant="secondary")

            with gr.Row():
                gen_btn = gr.Button("Generate 🎶", variant="primary")
                clr_btn = gr.Button("Clear 🧹", variant="secondary")
                save_btn = gr.Button("Save Settings 💾", variant="secondary")
                load_btn = gr.Button("Load Settings 📂", variant="secondary")
                reset_btn = gr.Button("Reset Defaults ♻️", variant="secondary")

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Output")
                out_audio = gr.Audio(label="Generated Track", type="filepath")
                status_box = gr.Textbox(label="Status", interactive=False)
                vram_box = gr.Textbox(label="VRAM", interactive=False, value="")

            with gr.Group(elem_classes="ga-section"):
                gr.Markdown("### Logs")
                log_output = gr.Textbox(label="Current Log (rotating ≤ 5MB)", lines=14, interactive=False)
                log_btn = gr.Button("View Log 📜", variant="secondary")

        with gr.Tab("📖 Info & Examples"):
            md_box = gr.Markdown(read_examples())
            refresh_md = gr.Button("Refresh Examples.md", variant="secondary")
            refresh_md.click(read_examples, outputs=md_box)

    # Components mirroring _UI_SETTING_KEYS, shared by Clear/Save/Load/Reset.
    settings_components = [
        instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
        drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
        preset, max_steps, bitrate_state, sample_rate_state, bit_depth_state, selected_style,
    ]

    for key, btn in row1 + row2 + row3 + row4:
        if key == "foo_pad":
            continue
        btn.click(
            set_prompt_and_settings_from_style,
            inputs=[gr.State(key), bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style],
            outputs=[instrumental_prompt, bpm, drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, selected_style],
        )

    bitrate_128_btn.click(lambda: "128k", outputs=bitrate_state)
    bitrate_192_btn.click(lambda: "192k", outputs=bitrate_state)
    bitrate_320_btn.click(lambda: "320k", outputs=bitrate_state)
    sample_rate_22050_btn.click(lambda: "22050", outputs=sample_rate_state)
    sample_rate_44100_btn.click(lambda: "44100", outputs=sample_rate_state)
    sample_rate_48000_btn.click(lambda: "48000", outputs=sample_rate_state)
    bit_depth_16_btn.click(lambda: "16", outputs=bit_depth_state)
    bit_depth_24_btn.click(lambda: "24", outputs=bit_depth_state)

    gen_btn.click(
        generate_music_wrapper,
        inputs=[
            instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, bpm,
            drum_beat, synthesizer, rhythmic_steps, bass_style, guitar_style, target_volume,
            preset, max_steps, vram_box, bitrate_state, sample_rate_state, bit_depth_state, selected_style,
        ],
        outputs=[out_audio, status_box, vram_box],
    )
    clr_btn.click(clear_inputs, outputs=settings_components)
    save_btn.click(_save_action, inputs=settings_components, outputs=status_box)
    load_btn.click(_load_action, outputs=settings_components + [status_box])
    reset_btn.click(_reset_action, outputs=settings_components + [status_box])
    log_btn.click(_get_log, outputs=log_output)
|
|
|
|
|
if __name__ == "__main__":
    # Serve the Gradio UI in the foreground. The FastAPI thread is started at import time.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=9999,
        share=False,
        inbrowser=False,
        show_error=True,
    )
    print(f"{Fore.CYAN}Launching Gradio UI http://0.0.0.0:9999 [{RELEASE}]{Fore.RESET}")
    try:
        demo.launch(**launch_options)
    except Exception as exc:
        # Log the reason plus the full traceback, then exit with a failure status.
        logger.error(f"Gradio launch failed: {exc}")
        logger.error(traceback.format_exc())
        sys.exit(1)
|
|
|