Spaces: Running on Zero

Big refactor + more features

app.py CHANGED
@@ -4,11 +4,13 @@ import gc
 import math
 import time
 import uuid
 import spaces
 import random
-from …
-from …
-import …

 import gradio as gr
 import numpy as np
@@ -18,386 +20,1142 @@ from transformers import AutoModel, AutoTokenizer
 import mido
 from mido import Message, MidiFile, MidiTrack

-… (old lines 21-57 removed; their content is truncated in this view)

 @dataclass
-class …
     model_name: str
-    compute_mode: …
     base_tempo: int
     velocity_range: Tuple[int, int]
-    scale: …
     num_layers_limit: int
     seed: int

-# --- Core math helpers ---
-
-def entropy(p: np.ndarray) -> float:
-    p = p / (p.sum() + 1e-9)
-    return float(-np.sum(p * np.log2(p + 1e-9)))
-
-def quantize_time(time_val: int, grid: int = 120) -> int:
-    return int(round(time_val / grid) * grid)
-
-def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
-    octave = int(abs(val) * octave_range) * 12
-    note_idx = int(abs(val * 100) % len(scale))
-    return int(scale[note_idx] + octave)
-
-ROLE_FREQS = {
-    'melody': 2.0,
-    'bass': 0.5,
-    'harmony': 1.5,
-    'pad': 0.25,
-    'accent': 3.0,
-    'atmosphere': 0.33
-}
-
-ROLE_WEIGHTS = {
-    'melody': np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
-    'bass': np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
-    'harmony': np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
-    'pad': np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
-    'accent': np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
-    'atmosphere': np.array([0.1, 0.2, 0.1, 0.2, 0.4])
-}
-
-def create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str):
-    base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
-    temporal_factor = 0.5 + 0.5 * np.sin(2 * np.pi * ROLE_FREQS[role] * token_idx / max(1, num_tokens))
-    energy = np.linalg.norm(hidden_state)
-    energy_factor = np.tanh(energy / 10)
-    local_variance = np.var(hidden_state)
-    variance_factor = 1 - np.exp(-local_variance)
-    state_entropy = entropy(np.abs(hidden_state))
-    max_entropy = np.log2(max(2, hidden_state.shape[0]))
-    entropy_factor = state_entropy / max_entropy
-    factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
-    weights = ROLE_WEIGHTS[role]
-    combined_prob = float(np.dot(weights, factors))
-    noise_seed = layer_idx * 1000 + token_idx
-    noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
-    final_prob = (combined_prob + noise) ** 1.5
-    return float(np.clip(final_prob, 0, 1))
-
-def should_play_note_stochastic(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role: str, history: Dict[int,int]):
-    prob = create_note_probability(layer_idx, token_idx, attention_val, hidden_state, num_tokens, role)
-    if layer_idx in history:
-        last_played = history[layer_idx]
-        silence_duration = token_idx - last_played
-        prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
-    play_note = np.random.random() < prob
-    if play_note:
-        history[layer_idx] = token_idx
-    return play_note
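# Example (sketch, not part of the diff): how the removed helpers above were
# exercised; assumes numpy as np and the functions defined above, values illustrative.
hidden = np.random.randn(64)                        # one token's hidden vector
p = create_note_probability(
    layer_idx=2, token_idx=10, attention_val=0.7,
    hidden_state=hidden, num_tokens=32, role='melody',
)                                                   # a probability in [0, 1]
history = {}                                        # last-played token per layer
played = should_play_note_stochastic(
    2, 10, 0.7, hidden, 32, 'melody', history
)                                                   # True => history[2] == 10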
-
-# -------------------- Model / Latents --------------------

 @dataclass
 class Latents:
     hidden_states: List[torch.Tensor]
     attentions: List[torch.Tensor]
     num_layers: int
     num_tokens: int
-… (old lines 138-191 removed; their content is truncated in this view)
         pass

-    return Latents(hidden_states=hidden_states[:layers], attentions=attentions[:layers], num_layers=layers, num_tokens=tokens)
-
-# -------------------- MIDI Rendering --------------------
-
-def render_midi(latents: Latents, scale_notes: List[int], base_tempo: int, velocity_range: Tuple[int, int], preset_name: str, seed: int) -> Tuple[bytes, Dict]:
-    np.random.seed(seed)
-    random.seed(seed)
-
-    scale = np.array(scale_notes, dtype=int)
-    num_layers = latents.num_layers
-    num_tokens = latents.num_tokens
-    hidden_states = [hs.float().numpy() if isinstance(hs, torch.Tensor) else hs for hs in latents.hidden_states]
-    attentions = [att.float().numpy() if isinstance(att, torch.Tensor) else att for att in latents.attentions]
-
-    layer_instruments = LAYER_INSTRUMENT_PRESETS[preset_name]
-
-    mid = MidiFile()
-    tracks: List[MidiTrack] = []
-    for ch in range(num_layers):
-        track = MidiTrack()
-        mid.tracks.append(track)
-        tracks.append(track)
-        instrument = layer_instruments.get(ch, (0, 'melody'))[0]
-        track.append(Message('program_change', program=int(instrument), time=0, channel=ch))
-
-    history: Dict[int, int] = {}
-    current_time = [0] * num_layers
-    notes_count = [0] * num_layers
-
-    for token_idx in range(num_tokens):
-        if token_idx > 0 and token_idx % 4 == 0:
-            for layer_idx in range(num_layers):
-                current_time[layer_idx] += base_tempo
-
-        pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))

         for layer_idx in range(num_layers):
-            … (old lines 231-237 removed; content truncated)
-                continue
-
-            if role == 'melody':
-                note = norm_to_scale(layer_vec[0], scale, octave_range=1)
-                notes_to_play = [note]
-            elif role == 'bass':
-                note = norm_to_scale(layer_vec[0], scale, octave_range=0) - 12
-                notes_to_play = [note]
-            elif role == 'harmony':
-                notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 1)]
-            elif role == 'pad':
-                notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(3, len(layer_vec)), 2)]
-            elif role == 'accent':
-                note = norm_to_scale(layer_vec[0], scale, octave_range=2) + 12
-                notes_to_play = [note]
-            else:
-                notes_to_play = [norm_to_scale(layer_vec[i], scale, octave_range=1) for i in range(0, min(2, len(layer_vec)), 3)]
-
-            base_velocity = int(attention_strength * (velocity_range[1] - velocity_range[0]) + velocity_range[0])
-            if role == 'melody':
-                velocity = min(base_velocity + 10, 127)
-            elif role == 'bass':
-                velocity = base_velocity
-            elif role == 'accent':
-                velocity = min(base_velocity + 20, 127)
             else:
-                … (content truncated)

-            if role in ['pad', 'atmosphere']:
-                duration = base_tempo * 4
-            elif role == 'bass':
-                duration = base_tempo
-            else:
-                try:
-                    dur_factor = entropy(attn_matrix.mean(axis=0)) / (np.log2(attn_matrix.shape[-1]) + 1e-9)
-                except Exception:
-                    dur_factor = 0.5
-                duration = quantize_time(int(base_tempo * (0.5 + dur_factor * 1.5)))
-
-            for note in notes_to_play:
-                note = max(21, min(108, int(note)))
-                tracks[layer_idx].append(Message('note_on', note=note, velocity=velocity, time=current_time[layer_idx], channel=layer_idx))
-                tracks[layer_idx].append(Message('note_off', note=note, velocity=0, time=duration, channel=layer_idx))
-                current_time[layer_idx] = 0
-                notes_count[layer_idx] += 1
-
-            if token_idx == 0:
-                tracks[layer_idx].append(Message('control_change', control=10, value=pan, time=0, channel=layer_idx))
-
-    # Save to bytes
-    bio = io.BytesIO()
-    mid.save(file=bio)
-    bio.seek(0)
-
-    meta = {
-        "num_layers": num_layers,
-        "num_tokens": num_tokens,
-        "notes_per_layer": notes_count,
-        "total_notes": int(sum(notes_count)),
-        "tempo_ticks_per_beat": int(base_tempo),
-        "scale": list(map(int, scale.tolist())),
-    }
-    return bio.read(), meta
-
-# -------------------- Gradio UI --------------------
-
-DESCRIPTION = """
-# LLM Forest Orchestra — Sonify Transformer Internals
-Turn hidden states and attentions into a multi-track MIDI composition.
-
-- **Two compute modes**: *Full model* (loads a HF model and extracts latents) or *Mock latents* (quick demo with synthetic tensors — great for CPU-only Spaces).
-- Choose **scale**, **tempo**, **velocity range**, and **instrument/role preset**.
-- Exports a **MIDI** you can arrange further in your DAW.
-
-## Inspiration
-
-This project is inspired by the way **mushrooms and mycelial networks in forests**
-connect plants and trees, forming a living web of communication and resource sharing.
-These connections, can be turned into ethereal music.
-Just as signals move through these hidden connections, transformer models also
-pass hidden states and attentions across their layers. Here, those hidden
-connections are translated into **music**, analogous to the forest's secret orchestra.
-"""
-
-EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
-Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
-Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
-Each feeling resonates at its own frequency, painting music across the soul's canvas.
-"""
-
-def parse_scale(selection: str, custom: str) -> List[int]:
-    if selection == "Custom (comma-separated MIDI notes)":
-        try:
-            return [int(x.strip()) for x in custom.split(",") if x.strip()]
-        except Exception:
-            return SCALES["C pentatonic"]
-    return SCALES[selection] if SCALES[selection] else SCALES["C pentatonic"]
-
-def generate(text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed):
-    scale = parse_scale(scale_choice, custom_scale)
-    cfg = GenConfig(
-        model_name=model_name or DEFAULT_MODEL,
-        compute_mode=compute_mode,
-        base_tempo=int(base_tempo),
-        velocity_range=(int(velocity_low), int(velocity_high)),
-        scale=scale,
-        num_layers_limit=int(num_layers),
-        seed=int(seed),
-    )
-
-    # Get latents
-    latents = get_latents(text, cfg.model_name, cfg.compute_mode, cfg.num_layers_limit)
-
-    # Render MIDI
-    midi_bytes, meta = render_midi(latents, cfg.scale, cfg.base_tempo, cfg.velocity_range, preset, cfg.seed)
-
-    # Persist to a file for download
-    out_name = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
-    with open(out_name, "wb") as f:
-        f.write(midi_bytes)
-
-    # Prepare quick stats
-    stats = (
-        f"Layers: {meta['num_layers']} | Tokens: {meta['num_tokens']} | "
-        f"Total notes: {meta['total_notes']} | Scale: {meta['scale']} | "
-        f"Tempo (ticks/beat): {meta['tempo_ticks_per_beat']}"
-    )
-
-    return out_name, stats, json.dumps(meta, indent=2)
-
-with gr.Blocks(title="LLM Forest Orchestra — MIDI from Transformer Internals") as demo:
-    gr.Markdown(DESCRIPTION)
-
-    with gr.Row():
-        with gr.Column():
-            text = gr.Textbox(value=EXAMPLE_TEXT, label="Input text", lines=8)
-            model_name = gr.Textbox(value=DEFAULT_MODEL, label="HF model (base) to probe", info="Should support output_hidden_states & output_attentions")
-            compute_mode = gr.Radio(choices=["Mock latents", "Full model"], value="Full model", label="Compute mode")
-            preset = gr.Dropdown(choices=list(LAYER_INSTRUMENT_PRESETS.keys()), value="Ensemble (melody+bass+pad etc.)", label="Instrument/Role preset")
-            with gr.Row():
-                base_tempo = gr.Slider(120, 960, value=480, step=1, label="Ticks per beat (tempo grid)")
-                num_layers = gr.Slider(1, 6, value=6, step=1, label="Max layers to use")
-            with gr.Row():
-                velocity_low = gr.Slider(1, 126, value=40, step=1, label="Velocity min")
-                velocity_high = gr.Slider(2, 127, value=90, step=1, label="Velocity max")
-            with gr.Row():
-                scale_choice = gr.Dropdown(choices=list(SCALES.keys()), value="C pentatonic", label="Scale")
-                custom_scale = gr.Textbox(value="", label="Custom scale notes (e.g. 60,62,65,67)")
-            seed = gr.Number(value=42, precision=0, label="Random seed")
-
-            btn = gr.Button("Generate MIDI", variant="primary")
-
-        with gr.Column():
-            midi_file = gr.File(label="MIDI output (.mid)")
-            stats = gr.Markdown("")
-            meta_json = gr.Code(label="Meta (JSON)")
-
-    btn.click(
-        fn=generate,
-        inputs=[text, model_name, compute_mode, base_tempo, velocity_low, velocity_high, scale_choice, custom_scale, num_layers, preset, seed],
-        outputs=[midi_file, stats, meta_json]
-    )

 if __name__ == "__main__":
-    … (content truncated)
 import math
 import time
 import uuid
+import json
 import spaces
 import random
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field, asdict
+from typing import Dict, List, Tuple, Optional, Any, Union
+from enum import Enum

 import gradio as gr
 import numpy as np
 import mido
 from mido import Message, MidiFile, MidiTrack

+
+# Configuration Classes
+
+class ComputeMode(Enum):
+    """Enum for computation modes."""
+    FULL_MODEL = "Full model"
+    MOCK_LATENTS = "Mock latents"
+
+
+class MusicRole(Enum):
+    """Enum for musical roles/layers."""
+    MELODY = "melody"
+    BASS = "bass"
+    HARMONY = "harmony"
+    PAD = "pad"
+    ACCENT = "accent"
+    ATMOSPHERE = "atmosphere"
+
+
+@dataclass
+class ScaleDefinition:
+    """Represents a musical scale."""
+    name: str
+    notes: List[int]
+    description: str = ""
+
+    def __post_init__(self):
+        """Validate scale notes are within MIDI range."""
+        for note in self.notes:
+            if not 0 <= note <= 127:
+                raise ValueError(f"MIDI note {note} out of range (0-127)")
+
+
+@dataclass
+class InstrumentMapping:
+    """Maps a layer to an instrument and musical role."""
+    program: int  # MIDI program number
+    role: MusicRole
+    channel: int
+    name: str = ""
+
+    def __post_init__(self):
+        """Validate MIDI program and channel."""
+        if not 0 <= self.program <= 127:
+            raise ValueError(f"MIDI program {self.program} out of range")
+        if not 0 <= self.channel <= 15:
+            raise ValueError(f"MIDI channel {self.channel} out of range")
+

 @dataclass
+class GenerationConfig:
+    """Complete configuration for music generation."""
     model_name: str
+    compute_mode: ComputeMode
     base_tempo: int
     velocity_range: Tuple[int, int]
+    scale: ScaleDefinition
     num_layers_limit: int
     seed: int
+    instrument_preset: str
+
+    # Additional configuration options
+    quantization_grid: int = 120
+    octave_range: int = 2
+    dynamics_curve: str = "linear"  # linear, exponential, logarithmic
+
+    def validate(self):
+        """Validate configuration parameters."""
+        if not 1 <= self.base_tempo <= 2000:
+            raise ValueError("Tempo must be between 1 and 2000")
+        if not 1 <= self.velocity_range[0] < self.velocity_range[1] <= 127:
+            raise ValueError("Invalid velocity range")
+        if not 1 <= self.num_layers_limit <= 32:
+            raise ValueError("Number of layers must be between 1 and 32")
+
+    def to_dict(self) -> Dict:
+        """Convert config to dictionary for serialization."""
+        return {
+            "model_name": self.model_name,
+            "compute_mode": self.compute_mode.value,
+            "base_tempo": self.base_tempo,
+            "velocity_range": self.velocity_range,
+            "scale_name": self.scale.name,
+            "scale_notes": self.scale.notes,
+            "num_layers_limit": self.num_layers_limit,
+            "seed": self.seed,
+            "instrument_preset": self.instrument_preset,
+            "quantization_grid": self.quantization_grid,
+            "octave_range": self.octave_range,
+            "dynamics_curve": self.dynamics_curve
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict, scale_manager: "ScaleManager") -> "GenerationConfig":
+        """Create config from dictionary."""
+        scale = scale_manager.get_scale(data["scale_name"])
+        if scale is None:
+            scale = ScaleDefinition(name="Custom", notes=data["scale_notes"])
+
+        return cls(
+            model_name=data["model_name"],
+            compute_mode=ComputeMode(data["compute_mode"]),
+            base_tempo=data["base_tempo"],
+            velocity_range=tuple(data["velocity_range"]),
+            scale=scale,
+            num_layers_limit=data["num_layers_limit"],
+            seed=data["seed"],
+            instrument_preset=data["instrument_preset"],
+            quantization_grid=data.get("quantization_grid", 120),
+            octave_range=data.get("octave_range", 2),
+            dynamics_curve=data.get("dynamics_curve", "linear")
+        )
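# Example (sketch, not part of the diff): round-tripping a GenerationConfig through
# to_dict()/from_dict(). Assumes the ScaleManager defined further below in this file.
_scales = ScaleManager()
_cfg = GenerationConfig(
    model_name="unsloth/Qwen3-14B-Base",
    compute_mode=ComputeMode.MOCK_LATENTS,
    base_tempo=480,
    velocity_range=(40, 90),
    scale=_scales.get_scale("C pentatonic"),
    num_layers_limit=6,
    seed=42,
    instrument_preset="Ensemble (melody+bass+pad etc.)",
)
_cfg.validate()                                        # raises ValueError on bad ranges
_same = GenerationConfig.from_dict(_cfg.to_dict(), _scales)
assert _same.scale.name == "C pentatonic"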

 @dataclass
 class Latents:
+    """Container for model latents."""
     hidden_states: List[torch.Tensor]
     attentions: List[torch.Tensor]
     num_layers: int
     num_tokens: int
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+# Music Components
+
+class ScaleManager:
+    """Manages musical scales and modes."""
+
+    def __init__(self):
+        """Initialize with default scales."""
+        self.scales = {
+            "C pentatonic": ScaleDefinition(
+                "C pentatonic",
+                [60, 62, 65, 67, 70, 72, 74, 77],
+                "Major pentatonic scale"
+            ),
+            "C major": ScaleDefinition(
+                "C major",
+                [60, 62, 64, 65, 67, 69, 71, 72],
+                "Major scale (Ionian mode)"
+            ),
+            "A minor": ScaleDefinition(
+                "A minor",
+                [57, 59, 60, 62, 64, 65, 67, 69],
+                "Natural minor scale (Aeolian mode)"
+            ),
+            "D dorian": ScaleDefinition(
+                "D dorian",
+                [62, 64, 65, 67, 69, 71, 72, 74],
+                "Dorian mode - minor with raised 6th"
+            ),
+            "E phrygian": ScaleDefinition(
+                "E phrygian",
+                [64, 65, 67, 69, 71, 72, 74, 76],
+                "Phrygian mode - minor with lowered 2nd"
+            ),
+            "G mixolydian": ScaleDefinition(
+                "G mixolydian",
+                [67, 69, 71, 72, 74, 76, 77, 79],
+                "Mixolydian mode - major with lowered 7th"
+            ),
+            "Blues scale": ScaleDefinition(
+                "Blues scale",
+                [60, 63, 65, 66, 67, 70, 72, 75],
+                "Blues scale with blue notes"
+            ),
+            "Chromatic": ScaleDefinition(
+                "Chromatic",
+                list(range(60, 72)),
+                "All 12 semitones"
+            )
+        }
+
+    def get_scale(self, name: str) -> Optional[ScaleDefinition]:
+        """Get scale by name."""
+        return self.scales.get(name)
+
+    def add_custom_scale(self, name: str, notes: List[int], description: str = "") -> ScaleDefinition:
+        """Add a custom scale."""
+        scale = ScaleDefinition(name, notes, description)
+        self.scales[name] = scale
+        return scale
+
+    def list_scales(self) -> List[str]:
+        """Get list of available scale names."""
+        return list(self.scales.keys())
+
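# Example (sketch, not part of the diff): registering a custom scale and looking it
# up again; note values are validated by ScaleDefinition.__post_init__.
_mgr = ScaleManager()
_mgr.add_custom_scale("D minor pentatonic", [62, 65, 67, 69, 72], "Custom example")
assert "D minor pentatonic" in _mgr.list_scales()
assert _mgr.get_scale("Nope") is None                  # unknown names return None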
+
+class InstrumentPresetManager:
+    """Manages instrument presets for different musical styles."""
+
+    def __init__(self):
+        """Initialize with default presets."""
+        self.presets = {
+            "Ensemble (melody+bass+pad etc.)": [
+                InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
+                InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
+                InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
+                InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
+                InstrumentMapping(11, MusicRole.ACCENT, 4, "Vibraphone"),
+                InstrumentMapping(89, MusicRole.ATMOSPHERE, 5, "Pad Warm")
+            ],
+            "Piano Trio (melody+bass+harmony)": [
+                InstrumentMapping(0, MusicRole.MELODY, 0, "Piano"),
+                InstrumentMapping(33, MusicRole.BASS, 1, "Electric Bass"),
+                InstrumentMapping(0, MusicRole.HARMONY, 2, "Piano"),
+                InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
+                InstrumentMapping(0, MusicRole.ACCENT, 4, "Piano"),
+                InstrumentMapping(0, MusicRole.ATMOSPHERE, 5, "Piano")
+            ],
+            "Pads & Atmosphere": [
+                InstrumentMapping(48, MusicRole.PAD, 0, "String Ensemble"),
+                InstrumentMapping(48, MusicRole.PAD, 1, "String Ensemble"),
+                InstrumentMapping(89, MusicRole.ATMOSPHERE, 2, "Pad Warm"),
+                InstrumentMapping(89, MusicRole.ATMOSPHERE, 3, "Pad Warm"),
+                InstrumentMapping(46, MusicRole.HARMONY, 4, "Harp"),
+                InstrumentMapping(11, MusicRole.ACCENT, 5, "Vibraphone")
+            ],
+            "Orchestral": [
+                InstrumentMapping(40, MusicRole.MELODY, 0, "Violin"),
+                InstrumentMapping(42, MusicRole.BASS, 1, "Cello"),
+                InstrumentMapping(46, MusicRole.HARMONY, 2, "Harp"),
+                InstrumentMapping(48, MusicRole.PAD, 3, "String Ensemble"),
+                InstrumentMapping(73, MusicRole.ACCENT, 4, "Flute"),
+                InstrumentMapping(49, MusicRole.ATMOSPHERE, 5, "Slow Strings")
+            ],
+            "Electronic": [
+                InstrumentMapping(80, MusicRole.MELODY, 0, "Lead Square"),
+                InstrumentMapping(38, MusicRole.BASS, 1, "Synth Bass"),
+                InstrumentMapping(81, MusicRole.HARMONY, 2, "Lead Sawtooth"),
+                InstrumentMapping(90, MusicRole.PAD, 3, "Pad Polysynth"),
+                InstrumentMapping(82, MusicRole.ACCENT, 4, "Lead Calliope"),
+                InstrumentMapping(91, MusicRole.ATMOSPHERE, 5, "Pad Bowed")
+            ]
+        }
+
+    def get_preset(self, name: str) -> List[InstrumentMapping]:
+        """Get instrument preset by name."""
+        return self.presets.get(name, self.presets["Ensemble (melody+bass+pad etc.)"])
+
+    def list_presets(self) -> List[str]:
+        """Get list of available preset names."""
+        return list(self.presets.keys())
+
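# Example (sketch, not part of the diff): unknown preset names fall back to the
# default "Ensemble" mapping, so the renderer always receives six role assignments.
_presets = InstrumentPresetManager()
_mappings = _presets.get_preset("does-not-exist")
assert _mappings[0].role is MusicRole.MELODY
assert len(_mappings) == 6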
+# Music Generation Components
+
+class MusicMathUtils:
+    """Utility class for music-related mathematical operations."""
+
+    @staticmethod
+    def entropy(p: np.ndarray) -> float:
+        """Calculate Shannon entropy of a probability distribution."""
+        p = p / (p.sum() + 1e-9)
+        return float(-np.sum(p * np.log2(p + 1e-9)))
+
+    @staticmethod
+    def quantize_time(time_val: int, grid: int = 120) -> int:
+        """Quantize time value to grid."""
+        return int(round(time_val / grid) * grid)
+
+    @staticmethod
+    def norm_to_scale(val: float, scale: np.ndarray, octave_range: int = 2) -> int:
+        """Map normalized value to scale note with octave range."""
+        octave = int(abs(val) * octave_range) * 12
+        note_idx = int(abs(val * 100) % len(scale))
+        return int(scale[note_idx] + octave)
+
+    @staticmethod
+    def apply_dynamics_curve(value: float, curve_type: str = "linear") -> float:
+        """Apply dynamics curve to a value."""
+        value = np.clip(value, 0, 1)
+        if curve_type == "exponential":
+            return value ** 2
+        elif curve_type == "logarithmic":
+            return np.log1p(value * np.e) / np.log1p(np.e)
+        else:  # linear
+            return value
+
+
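# Example (sketch, not part of the diff): how a hidden-state value becomes a scale
# degree. With val=0.55 and an 8-note scale, abs(0.55*100)=55 gives index 55 % 8 = 7,
# and int(0.55 * 2) * 12 = 12, i.e. the top scale note shifted one octave up.
_scale = np.array([60, 62, 65, 67, 70, 72, 74, 77])
assert MusicMathUtils.norm_to_scale(0.55, _scale, octave_range=2) == 77 + 12
assert MusicMathUtils.quantize_time(250, grid=120) == 240       # snapped to the grid
assert MusicMathUtils.apply_dynamics_curve(0.5, "exponential") == 0.25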
+class NoteGenerator:
+    """Generates notes based on neural network latents."""
+
+    # Role-specific frequency multipliers
+    ROLE_FREQUENCIES = {
+        MusicRole.MELODY: 2.0,
+        MusicRole.BASS: 0.5,
+        MusicRole.HARMONY: 1.5,
+        MusicRole.PAD: 0.25,
+        MusicRole.ACCENT: 3.0,
+        MusicRole.ATMOSPHERE: 0.33
+    }
+
+    # Role-specific weight distributions
+    ROLE_WEIGHTS = {
+        MusicRole.MELODY: np.array([0.4, 0.2, 0.2, 0.1, 0.1]),
+        MusicRole.BASS: np.array([0.1, 0.4, 0.1, 0.3, 0.1]),
+        MusicRole.HARMONY: np.array([0.2, 0.2, 0.3, 0.2, 0.1]),
+        MusicRole.PAD: np.array([0.1, 0.3, 0.1, 0.1, 0.4]),
+        MusicRole.ACCENT: np.array([0.5, 0.1, 0.2, 0.1, 0.1]),
+        MusicRole.ATMOSPHERE: np.array([0.1, 0.2, 0.1, 0.2, 0.4])
+    }
+
+    def __init__(self, config: GenerationConfig):
+        """Initialize with generation configuration."""
+        self.config = config
+        self.math_utils = MusicMathUtils()
+        self.history: Dict[int, int] = {}
+
+    def create_note_probability(
+        self,
+        layer_idx: int,
+        token_idx: int,
+        attention_val: float,
+        hidden_state: np.ndarray,
+        num_tokens: int,
+        role: MusicRole
+    ) -> float:
+        """Calculate probability of playing a note based on multiple factors."""
+        # Base probability from attention
+        base_prob = 1 / (1 + np.exp(-10 * (attention_val - 0.5)))
+
+        # Temporal factor based on role frequency
+        temporal_factor = 0.5 + 0.5 * np.sin(
+            2 * np.pi * self.ROLE_FREQUENCIES[role] * token_idx / max(1, num_tokens)
+        )
+
+        # Energy factor from hidden state norm
+        energy = np.linalg.norm(hidden_state)
+        energy_factor = np.tanh(energy / 10)
+
+        # Variance factor
+        local_variance = np.var(hidden_state)
+        variance_factor = 1 - np.exp(-local_variance)
+
+        # Entropy factor
+        state_entropy = self.math_utils.entropy(np.abs(hidden_state))
+        max_entropy = np.log2(max(2, hidden_state.shape[0]))
+        entropy_factor = state_entropy / max_entropy
+
+        # Combine factors with role-specific weights
+        factors = np.array([base_prob, temporal_factor, energy_factor, variance_factor, entropy_factor])
+        weights = self.ROLE_WEIGHTS[role]
+        combined_prob = float(np.dot(weights, factors))
+
+        # Add deterministic noise for variation
+        noise_seed = layer_idx * 1000 + token_idx
+        noise = 0.1 * (np.sin(noise_seed * 0.1) + np.cos(noise_seed * 0.23)) / 2
+
+        # Apply dynamics curve
+        final_prob = (combined_prob + noise) ** 1.5
+        final_prob = self.math_utils.apply_dynamics_curve(final_prob, self.config.dynamics_curve)
+
+        return float(np.clip(final_prob, 0, 1))
+
+    def should_play_note(
+        self,
+        layer_idx: int,
+        token_idx: int,
+        attention_val: float,
+        hidden_state: np.ndarray,
+        num_tokens: int,
+        role: MusicRole
+    ) -> bool:
+        """Determine if a note should be played."""
+        prob = self.create_note_probability(
+            layer_idx, token_idx, attention_val, hidden_state, num_tokens, role
+        )
+
+        # Adjust probability based on silence duration
+        if layer_idx in self.history:
+            last_played = self.history[layer_idx]
+            silence_duration = token_idx - last_played
+            prob *= (1 + np.tanh(silence_duration / 5) * 0.5)
+
+        # Stochastic decision
+        play_note = np.random.random() < prob
+
+        if play_note:
+            self.history[layer_idx] = token_idx
+
+        return play_note
+
+    def generate_notes_for_role(
+        self,
+        role: MusicRole,
+        hidden_state: np.ndarray,
+        scale: np.ndarray
+    ) -> List[int]:
+        """Generate notes based on role and hidden state."""
+        if role == MusicRole.MELODY:
+            note = self.math_utils.norm_to_scale(
+                hidden_state[0], scale, octave_range=1
+            )
+            return [note]
+
+        elif role == MusicRole.BASS:
+            note = self.math_utils.norm_to_scale(
+                hidden_state[0], scale, octave_range=0
+            ) - 12
+            return [note]
+
+        elif role == MusicRole.HARMONY:
+            return [
+                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
+                for i in range(0, min(2, len(hidden_state)), 1)
+            ]
+
+        elif role == MusicRole.PAD:
+            return [
+                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
+                for i in range(0, min(3, len(hidden_state)), 2)
+            ]
+
+        elif role == MusicRole.ACCENT:
+            note = self.math_utils.norm_to_scale(
+                hidden_state[0], scale, octave_range=2
+            ) + 12
+            return [note]
+
+        else:  # ATMOSPHERE
+            return [
+                self.math_utils.norm_to_scale(hidden_state[i], scale, octave_range=1)
+                for i in range(0, min(2, len(hidden_state)), 3)
+            ]
+
+    def calculate_velocity(
+        self,
+        role: MusicRole,
+        attention_strength: float
+    ) -> int:
+        """Calculate note velocity based on role and attention."""
+        base_velocity = int(
+            attention_strength * (self.config.velocity_range[1] - self.config.velocity_range[0])
+            + self.config.velocity_range[0]
+        )
+
+        # Role-specific adjustments
+        if role == MusicRole.MELODY:
+            velocity = min(base_velocity + 10, 127)
+        elif role == MusicRole.ACCENT:
+            velocity = min(base_velocity + 20, 127)
+        elif role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
+            velocity = max(base_velocity - 10, 20)
+        else:
+            velocity = base_velocity
+
+        return velocity
+
+    def calculate_duration(
+        self,
+        role: MusicRole,
+        attention_matrix: np.ndarray
+    ) -> int:
+        """Calculate note duration based on role and attention."""
+        if role in [MusicRole.PAD, MusicRole.ATMOSPHERE]:
+            duration = self.config.base_tempo * 4
+        elif role == MusicRole.BASS:
+            duration = self.config.base_tempo
+        else:
+            try:
+                dur_factor = self.math_utils.entropy(attention_matrix.mean(axis=0)) / (
+                    np.log2(attention_matrix.shape[-1]) + 1e-9
+                )
+            except Exception:
+                dur_factor = 0.5
+            duration = self.math_utils.quantize_time(
+                int(self.config.base_tempo * (0.5 + dur_factor * 1.5)),
+                self.config.quantization_grid
+            )
+
+        return duration
+
+
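# Example (sketch, not part of the diff): probing NoteGenerator directly with a
# synthetic hidden state; reuses _cfg from the GenerationConfig sketch above.
_gen = NoteGenerator(_cfg)
_vec = np.random.randn(128).astype(np.float32)          # one token's hidden vector
_p = _gen.create_note_probability(
    layer_idx=0, token_idx=4, attention_val=0.6,
    hidden_state=_vec, num_tokens=32, role=MusicRole.MELODY,
)
assert 0.0 <= _p <= 1.0
_notes = _gen.generate_notes_for_role(MusicRole.HARMONY, _vec, np.array(_cfg.scale.notes))
assert len(_notes) == 2                                  # harmony samples the first two dims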
+# Model Interaction
+
+class LatentExtractor(ABC):
+    """Abstract base class for latent extraction strategies."""
+
+    @abstractmethod
+    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
+        """Extract latents from text."""
         pass


+class MockLatentExtractor(LatentExtractor):
+    """Generate mock latents for testing without loading models."""
+
+    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
+        """Generate synthetic latents based on text."""
+        # Simulate token count based on text length
+        tokens = max(16, min(128, len(text.split()) * 4))
+        layers = min(config.num_layers_limit, 6)
+
+        # Generate deterministic but varied latents based on text
+        np.random.seed(hash(text) % 2**32)
+
+        hidden_states = [
+            torch.randn(1, tokens, 128) for _ in range(layers)
+        ]
+        attentions = [
+            torch.rand(1, 8, tokens, tokens) for _ in range(layers)
+        ]
+
+        metadata = {
+            "mode": "mock",
+            "text_length": len(text),
+            "generated_tokens": tokens,
+            "generated_layers": layers
+        }
+
+        return Latents(
+            hidden_states=hidden_states,
+            attentions=attentions,
+            num_layers=layers,
+            num_tokens=tokens,
+            metadata=metadata
+        )
+
+
+class ModelLatentExtractor(LatentExtractor):
+    """Extract real latents from transformer models."""
+
+    @spaces.GPU(duration=45)
+    def extract(self, text: str, config: GenerationConfig, progress=None) -> Latents:
+        """Extract latents from a real transformer model."""
+        model_name = config.model_name
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        # Configure model loading
+        load_kwargs = {
+            "output_hidden_states": True,
+            "output_attentions": True,
+            "device_map": "cuda" if torch.cuda.is_available() else "cpu",
+        }
+
+        # Set appropriate dtype
+        try:
+            load_kwargs["torch_dtype"] = (
+                torch.bfloat16 if torch.cuda.is_available() else torch.float32
+            )
+        except Exception:
+            pass
+
+        # Load model
+        model = AutoModel.from_pretrained(model_name, **load_kwargs)
+
+        # Tokenize input
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+
+        # Get model outputs
+        with torch.no_grad():
+            outputs = model(**inputs)
+            hidden_states = list(outputs.hidden_states)
+            attentions = list(outputs.attentions)
+
+        # Move to CPU to free VRAM
+        hidden_states = [hs.to("cpu") for hs in hidden_states]
+        attentions = [att.to("cpu") for att in attentions]
+
+        # Limit layers
+        layers = min(config.num_layers_limit, len(hidden_states))
+        tokens = hidden_states[0].shape[1]
+
+        # Clean up
+        try:
+            del model
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            gc.collect()
+        except Exception:
+            pass
+
+        metadata = {
+            "mode": "full_model",
+            "model_name": model_name,
+            "actual_layers": len(hidden_states),
+            "used_layers": layers,
+            "tokens": tokens
+        }
+
+        return Latents(
+            hidden_states=hidden_states[:layers],
+            attentions=attentions[:layers],
+            num_layers=layers,
+            num_tokens=tokens,
+            metadata=metadata
+        )
+
+
+class LatentExtractorFactory:
+    """Factory for creating appropriate latent extractors."""
+
+    @staticmethod
+    def create(compute_mode: ComputeMode) -> LatentExtractor:
+        """Create a latent extractor based on compute mode."""
+        if compute_mode == ComputeMode.MOCK_LATENTS:
+            return MockLatentExtractor()
+        else:
+            return ModelLatentExtractor()
+
+
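# Example (sketch, not part of the diff): the factory picks the extractor from the
# compute mode; mock mode needs no GPU and returns synthetic tensors immediately.
# Reuses _cfg from the GenerationConfig sketch above.
_extractor = LatentExtractorFactory.create(ComputeMode.MOCK_LATENTS)
_lat = _extractor.extract("a short test sentence", _cfg)
assert _lat.num_layers <= 6 and _lat.metadata["mode"] == "mock"
assert _lat.hidden_states[0].shape == (1, _lat.num_tokens, 128)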
+# MIDI Generation
+
+class MIDIRenderer:
+    """Renders MIDI files from latents."""
+
+    def __init__(self, config: GenerationConfig, instrument_manager: InstrumentPresetManager):
+        """Initialize MIDI renderer."""
+        self.config = config
+        self.instrument_manager = instrument_manager
+        self.note_generator = NoteGenerator(config)
+        self.math_utils = MusicMathUtils()
+
+    def render(self, latents: Latents) -> Tuple[bytes, Dict[str, Any]]:
+        """Render MIDI from latents."""
+        # Set random seeds for reproducibility
+        np.random.seed(self.config.seed)
+        random.seed(self.config.seed)
+        torch.manual_seed(self.config.seed)
+
+        # Prepare data
+        scale = np.array(self.config.scale.notes, dtype=int)
+        num_layers = latents.num_layers
+        num_tokens = latents.num_tokens
+
+        # Convert tensors to numpy
+        hidden_states = [
+            hs.float().numpy() if isinstance(hs, torch.Tensor) else hs
+            for hs in latents.hidden_states
+        ]
+        attentions = [
+            att.float().numpy() if isinstance(att, torch.Tensor) else att
+            for att in latents.attentions
+        ]
+
+        # Get instrument mappings
+        instrument_mappings = self.instrument_manager.get_preset(self.config.instrument_preset)
+
+        # Create MIDI file and tracks
+        midi_file = MidiFile()
+        tracks = self._create_tracks(midi_file, num_layers, instrument_mappings)
+
+        # Generate notes
+        stats = self._generate_notes(
+            tracks, hidden_states, attentions,
+            scale, num_tokens, instrument_mappings
+        )
+
+        # Convert to bytes
+        bio = io.BytesIO()
+        midi_file.save(file=bio)
+        bio.seek(0)
+
+        # Prepare metadata
+        metadata = {
+            "config": self.config.to_dict(),
+            "latents_info": latents.metadata,
+            "stats": stats,
+            "timestamp": time.time()
+        }
+
+        return bio.read(), metadata
+
+    def _create_tracks(
+        self,
+        midi_file: MidiFile,
+        num_layers: int,
+        instrument_mappings: List[InstrumentMapping]
+    ) -> List[MidiTrack]:
+        """Create MIDI tracks with instrument assignments."""
+        tracks = []
+
         for layer_idx in range(num_layers):
+            track = MidiTrack()
+            midi_file.tracks.append(track)
+            tracks.append(track)
+
+            # Get instrument mapping for this layer
+            if layer_idx < len(instrument_mappings):
+                mapping = instrument_mappings[layer_idx]
             else:
+                # Default to piano if not enough mappings
+                mapping = InstrumentMapping(0, MusicRole.MELODY, layer_idx % 16)
+
+            # Set instrument
+            track.append(Message(
+                "program_change",
+                program=mapping.program,
+                time=0,
+                channel=mapping.channel
+            ))
+
+            # Add track name
+            if mapping.name:
+                track.append(mido.MetaMessage(
+                    "track_name",
+                    name=f"{mapping.name} - {mapping.role.value}",
+                    time=0
+                ))
+
+        return tracks
+
+    def _generate_notes(
+        self,
+        tracks: List[MidiTrack],
+        hidden_states: List[np.ndarray],
+        attentions: List[np.ndarray],
+        scale: np.ndarray,
+        num_tokens: int,
+        instrument_mappings: List[InstrumentMapping]
+    ) -> Dict[str, Any]:
+        """Generate notes for all tracks."""
+        current_time = [0] * len(tracks)
+        notes_count = [0] * len(tracks)
+
+        for token_idx in range(num_tokens):
+            # Update time periodically
+            if token_idx > 0 and token_idx % 4 == 0:
+                for layer_idx in range(len(tracks)):
+                    current_time[layer_idx] += self.config.base_tempo
+
+            # Calculate panning
+            pan = 64 + int(32 * np.sin(token_idx * math.pi / max(1, num_tokens)))
+
+            # Generate notes for each layer
+            for layer_idx in range(len(tracks)):
+                if layer_idx >= len(instrument_mappings):
+                    continue
+
+                mapping = instrument_mappings[layer_idx]
+
+                # Get attention and hidden state
+                attn_matrix = attentions[min(layer_idx, len(attentions) - 1)][0, :, token_idx, :]
+                attention_strength = float(np.mean(attn_matrix))
+                layer_vec = hidden_states[layer_idx][0, token_idx]
+
+                # Check if note should be played
+                if not self.note_generator.should_play_note(
+                    layer_idx, token_idx, attention_strength,
+                    layer_vec, num_tokens, mapping.role
+                ):
+                    continue
+
+                # Generate notes
+                notes_to_play = self.note_generator.generate_notes_for_role(
+                    mapping.role, layer_vec, scale
+                )
+
+                # Calculate velocity and duration
+                velocity = self.note_generator.calculate_velocity(
+                    mapping.role, attention_strength
+                )
+                duration = self.note_generator.calculate_duration(
+                    mapping.role, attn_matrix
+                )
+
+                # Add notes to track
+                for note in notes_to_play:
+                    note = max(21, min(108, int(note)))  # Clamp to piano range
+
+                    tracks[layer_idx].append(Message(
+                        "note_on",
+                        note=note,
+                        velocity=velocity,
+                        time=current_time[layer_idx],
+                        channel=mapping.channel
+                    ))
+
+                    tracks[layer_idx].append(Message(
+                        "note_off",
+                        note=note,
+                        velocity=0,
+                        time=duration,
+                        channel=mapping.channel
+                    ))
+
+                    current_time[layer_idx] = 0
+                    notes_count[layer_idx] += 1
+
+                # Set panning on first token
+                if token_idx == 0:
+                    tracks[layer_idx].append(Message(
+                        "control_change",
+                        control=10,
+                        value=pan,
+                        time=0,
+                        channel=mapping.channel
+                    ))
+
+        return {
+            "num_layers": len(tracks),
+            "num_tokens": num_tokens,
+            "notes_per_layer": notes_count,
+            "total_notes": int(sum(notes_count)),
+            "tempo_ticks_per_beat": int(self.config.base_tempo),
+            "scale": list(map(int, scale.tolist())),
+        }
+
+
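# Example (sketch, not part of the diff): wiring the renderer to the mock latents
# from the sketch above and checking that valid MIDI bytes come back ("MThd" is the
# standard MIDI file header).
_renderer = MIDIRenderer(_cfg, InstrumentPresetManager())
_midi_bytes, _meta = _renderer.render(_lat)
assert _midi_bytes.startswith(b"MThd")
assert _meta["stats"]["num_layers"] == _lat.num_layers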
| 830 |
+
# Main Orchestrator
|
| 831 |
+
|
+class LLMForestOrchestra:
+    """Main orchestrator class that coordinates the entire pipeline."""
+
+    DEFAULT_MODEL = "unsloth/Qwen3-14B-Base"
+
+    def __init__(self):
+        """Initialize the orchestra."""
+        self.scale_manager = ScaleManager()
+        self.instrument_manager = InstrumentPresetManager()
+        self.saved_configs: Dict[str, GenerationConfig] = {}
+
+    def generate(
+        self,
+        text: str,
+        model_name: str,
+        compute_mode: str,
+        base_tempo: int,
+        velocity_range: Tuple[int, int],
+        scale_name: str,
+        custom_scale_notes: Optional[List[int]],
+        num_layers: int,
+        instrument_preset: str,
+        seed: int,
+        quantization_grid: int = 120,
+        octave_range: int = 2,
+        dynamics_curve: str = "linear"
+    ) -> Tuple[str, Dict[str, Any]]:
+        """Generate MIDI from text input."""
+        # Get or create scale
+        if scale_name == "Custom":
+            if not custom_scale_notes:
+                raise ValueError("Custom scale requires note list")
+            scale = ScaleDefinition("Custom", custom_scale_notes)
+        else:
+            scale = self.scale_manager.get_scale(scale_name)
+            if scale is None:
+                raise ValueError(f"Unknown scale: {scale_name}")
+
+        # Create configuration
+        config = GenerationConfig(
+            model_name=model_name or self.DEFAULT_MODEL,
+            compute_mode=ComputeMode(compute_mode),
+            base_tempo=base_tempo,
+            velocity_range=velocity_range,
+            scale=scale,
+            num_layers_limit=num_layers,
+            seed=seed,
+            instrument_preset=instrument_preset,
+            quantization_grid=quantization_grid,
+            octave_range=octave_range,
+            dynamics_curve=dynamics_curve
+        )
+
+        # Validate configuration
+        config.validate()
+
+        # Extract latents
+        extractor = LatentExtractorFactory.create(config.compute_mode)
+        latents = extractor.extract(text, config)
+
+        # Render MIDI
+        renderer = MIDIRenderer(config, self.instrument_manager)
+        midi_bytes, metadata = renderer.render(latents)
+
+        # Save MIDI file
+        filename = f"llm_forest_orchestra_{uuid.uuid4().hex[:8]}.mid"
+        with open(filename, "wb") as f:
+            f.write(midi_bytes)
+
+        return filename, metadata
+
+    def save_config(self, name: str, config: GenerationConfig):
+        """Save a configuration for later use."""
+        self.saved_configs[name] = config
+
+    def load_config(self, name: str) -> Optional[GenerationConfig]:
+        """Load a saved configuration."""
+        return self.saved_configs.get(name)
+
+    def export_config(self, config: GenerationConfig, filepath: str):
+        """Export configuration to JSON file."""
+        with open(filepath, "w") as f:
+            json.dump(config.to_dict(), f, indent=2)
+
+    def import_config(self, filepath: str) -> GenerationConfig:
+        """Import configuration from JSON file."""
+        with open(filepath, "r") as f:
+            data = json.load(f)
+        return GenerationConfig.from_dict(data, self.scale_manager)
+
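For readers who want to drive the pipeline without the Gradio UI, here is a minimal usage sketch of the class above. It assumes the helper classes defined earlier in app.py are in scope; the argument values are illustrative (they mirror the UI defaults), not recommendations.

# Hedged usage sketch (not part of the diff above): drive the pipeline without the UI.
# Assumes the classes defined earlier in app.py are in scope.
orchestra = LLMForestOrchestra()
midi_path, metadata = orchestra.generate(
    text="Wind moves through the canopy in slow green waves.",
    model_name=orchestra.DEFAULT_MODEL,
    compute_mode="Mock latents",  # CPU-friendly path; "Full model" needs a GPU
    base_tempo=480,
    velocity_range=(40, 90),
    scale_name="C pentatonic",
    custom_scale_notes=None,
    num_layers=6,
    instrument_preset="Ensemble (melody+bass+pad etc.)",
    seed=42,
)
print(midi_path, metadata.get("stats", {}))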
+# Gradio UI
+
+class GradioInterface:
+    """Manages the Gradio user interface."""
+
+    DESCRIPTION = """
+# 🌲 LLM Forest Orchestra — Sonify Transformer Internals
+
+Transform the hidden states and attention patterns of language models into multi-layered musical compositions.
+
+## 🍄 Inspiration
+
+This project is inspired by the way **mushrooms and mycelial networks in forests**
+connect plants and trees, forming a living web of communication and resource sharing.
+These connections can be turned into ethereal music.
+Just as signals move through these hidden connections, transformer models also
+pass hidden states and attentions across their layers. Here, those hidden
+connections are translated into **music**, analogous to the forest's secret orchestra.
+
+## Features
+- **Two compute modes**: Full model (GPU) or Mock latents (CPU-friendly)
+- **Multiple musical scales**: From pentatonic to chromatic
+- **Instrument presets**: Orchestral, electronic, ensemble, and more
+- **Advanced controls**: Dynamics curves, quantization, velocity ranges
+- **Export**: Standard MIDI files for further editing in your DAW
+"""
+
+    EXAMPLE_TEXT = """Joy cascades in golden waterfalls, crashing into pools of melancholy blue.
+Anger burns red through veins of marble, while serenity floats on clouds of softest grey.
+Love pulses in waves of crimson and rose, intertwining with longing's purple haze.
+Each feeling resonates at its own frequency, painting music across the soul's canvas."""
+
+    def __init__(self, orchestra: LLMForestOrchestra):
+        """Initialize the interface."""
+        self.orchestra = orchestra
+
+    def create_interface(self) -> gr.Blocks:
+        """Create the Gradio interface."""
+        with gr.Blocks(title="LLM Forest Orchestra", theme=gr.themes.Soft()) as demo:
+            gr.Markdown(self.DESCRIPTION)
+
+            with gr.Tabs():
+                with gr.TabItem("🎵 Generate Music"):
+                    self._create_generation_tab()
+
+        return demo
+
+    def _create_generation_tab(self):
+        """Create the main generation tab."""
+        with gr.Row():
+            with gr.Column(scale=1):
+                text_input = gr.Textbox(
+                    value=self.EXAMPLE_TEXT,
+                    label="Input Text",
+                    lines=8,
+                    placeholder="Enter text to sonify..."
+                )
+
+                model_name = gr.Textbox(
+                    value=self.orchestra.DEFAULT_MODEL,
+                    label="Hugging Face Model",
+                    info="Model must support output_hidden_states and output_attentions"
+                )
+
+                compute_mode = gr.Radio(
+                    choices=["Full model", "Mock latents"],
+                    value="Mock latents",
+                    label="Compute Mode",
+                    info="Mock latents for quick CPU-only demo"
+                )
+
+                with gr.Row():
+                    instrument_preset = gr.Dropdown(
+                        choices=self.orchestra.instrument_manager.list_presets(),
+                        value="Ensemble (melody+bass+pad etc.)",
+                        label="Instrument Preset"
+                    )
+
+                    scale_choice = gr.Dropdown(
+                        choices=self.orchestra.scale_manager.list_scales() + ["Custom"],
+                        value="C pentatonic",
+                        label="Musical Scale"
+                    )
+
+                custom_scale = gr.Textbox(
+                    value="",
+                    label="Custom Scale Notes",
+                    placeholder="60,62,65,67,70",
+                    visible=False
+                )
+
+                with gr.Row():
+                    base_tempo = gr.Slider(
+                        120, 960,
+                        value=480,
+                        step=1,
+                        label="Tempo (ticks per beat)"
+                    )
+
+                    num_layers = gr.Slider(
+                        1, 6,
+                        value=6,
+                        step=1,
+                        label="Max Layers"
+                    )
+
+                with gr.Row():
+                    velocity_low = gr.Slider(
+                        1, 126,
+                        value=40,
+                        step=1,
+                        label="Min Velocity"
+                    )
+
+                    velocity_high = gr.Slider(
+                        2, 127,
+                        value=90,
+                        step=1,
+                        label="Max Velocity"
+                    )
+
+                seed = gr.Number(
+                    value=42,
+                    precision=0,
+                    label="Random Seed"
+                )
+
+                generate_btn = gr.Button(
+                    "🎼 Generate MIDI",
+                    variant="primary",
+                    size="lg"
+                )
+
+            with gr.Column(scale=1):
+                midi_output = gr.File(
+                    label="Generated MIDI File",
+                    file_types=[".mid", ".midi"]
+                )
+
+                stats_display = gr.Markdown(label="Quick Stats")
+
+                metadata_json = gr.Code(
+                    label="Metadata (JSON)",
+                    language="json"
+                )
+
+        with gr.Row():
+            play_instructions = gr.Markdown(
+                """
+### 🎧 How to Play
+1. Download the MIDI file
+2. Open in any DAW or MIDI player
+3. Adjust instruments and effects as desired
+4. Export to audio format
+"""
+            )
+
+        # Set up interactions
+        def update_custom_scale_visibility(choice):
+            return gr.update(visible=(choice == "Custom"))
+
+        scale_choice.change(
+            update_custom_scale_visibility,
+            inputs=[scale_choice],
+            outputs=[custom_scale]
+        )
+
+        def generate_wrapper(
+            text, model_name, compute_mode, base_tempo,
+            velocity_low, velocity_high, scale_choice,
+            custom_scale, num_layers, instrument_preset, seed
+        ):
+            """Wrapper for generation with error handling."""
+            try:
+                # Parse custom scale if needed
+                custom_notes = None
+                if scale_choice == "Custom" and custom_scale:
+                    custom_notes = [int(x.strip()) for x in custom_scale.split(",")]
+
+                # Generate
+                filename, metadata = self.orchestra.generate(
+                    text=text,
+                    model_name=model_name,
+                    compute_mode=compute_mode,
+                    base_tempo=int(base_tempo),
+                    velocity_range=(int(velocity_low), int(velocity_high)),
+                    scale_name=scale_choice,
+                    custom_scale_notes=custom_notes,
+                    num_layers=int(num_layers),
+                    instrument_preset=instrument_preset,
+                    seed=int(seed)
+                )
+
+                # Format stats
+                stats = metadata.get("stats", {})
+                stats_text = f"""
+### Generation Statistics
+- **Layers Used**: {stats.get('num_layers', 'N/A')}
+- **Tokens Processed**: {stats.get('num_tokens', 'N/A')}
+- **Total Notes**: {stats.get('total_notes', 'N/A')}
+- **Notes per Layer**: {stats.get('notes_per_layer', [])}
+- **Scale**: {stats.get('scale', [])}
+- **Tempo**: {stats.get('tempo_ticks_per_beat', 'N/A')} ticks/beat
+"""
+
+                return filename, stats_text, json.dumps(metadata, indent=2)
+
+            except Exception as e:
+                error_msg = f"### ❌ Error\n{str(e)}"
+                return None, error_msg, json.dumps({"error": str(e)}, indent=2)
+
+        generate_btn.click(
+            fn=generate_wrapper,
+            inputs=[
+                text_input, model_name, compute_mode, base_tempo,
+                velocity_low, velocity_high, scale_choice,
+                custom_scale, num_layers, instrument_preset, seed
+            ],
+            outputs=[midi_output, stats_display, metadata_json]
+        )
+
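Note that generate_wrapper only reads a handful of keys from metadata["stats"]. For orientation, here is a hedged sketch of the dict shape it expects: the key names come from the f-string above, but the values are purely illustrative, and the real dict is produced by MIDIRenderer.render elsewhere in this file.

# Illustrative shape only; actual values come from MIDIRenderer.render.
example_stats = {
    "num_layers": 6,
    "num_tokens": 74,
    "total_notes": 312,
    "notes_per_layer": [61, 55, 48, 52, 50, 46],
    "scale": [60, 62, 64, 67, 69],       # C pentatonic as MIDI note numbers
    "tempo_ticks_per_beat": 480,
}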
+# Main Entry Point
+
+def main():
+    """Main entry point for the application."""
+    # Initialize orchestra
+    orchestra = LLMForestOrchestra()
+
+    # Create interface
+    interface = GradioInterface(orchestra)
+    demo = interface.create_interface()
+
+    # Launch
+    demo.launch()
+
 
 if __name__ == "__main__":
+    main()
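Once a file has been generated, it can be inspected or post-processed with mido, which the app already imports. A quick sketch follows; the filename is illustrative and stands in for whatever generate() returned.

# Quick inspection of a generated file with mido (filename is illustrative).
import mido

mid = mido.MidiFile("llm_forest_orchestra_1a2b3c4d.mid")
print(f"ticks_per_beat={mid.ticks_per_beat}, tracks={len(mid.tracks)}")
for i, track in enumerate(mid.tracks):
    # Count actual note onsets (note_on with velocity 0 is a note-off in MIDI).
    note_ons = sum(1 for msg in track if msg.type == "note_on" and msg.velocity > 0)
    print(f"track {i}: {len(track)} messages, {note_ons} note-on events")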