asigalov61 commited on
Commit
97d0dfa
·
verified ·
1 Parent(s): 32b4a26

Upload Orpheus_Auto_Continuations_Generator.ipynb

Browse files
inference_code/Orpheus_Auto_Continuations_Generator.ipynb ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "4821055b-c45c-4a9f-8196-2a9d09df6c39",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Orpheus Auto-Continuations Generator (ver. 1.0)\n",
9
+ "\n",
10
+ "***\n",
11
+ "\n",
12
+ "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n",
13
+ "\n",
14
+ "***\n",
15
+ "\n",
16
+ "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. https://www.nscai.gov/\n",
17
+ "\n",
18
+ "***\n",
19
+ "\n",
20
+ "#### Project Los Angeles\n",
21
+ "\n",
22
+ "#### Tegridy Code 2025\n",
23
+ "\n",
24
+ "***"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "id": "a6e2249a-6b57-4193-830d-7772c29b6f38",
30
+ "metadata": {},
31
+ "source": [
32
+ "# Setup environment"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "1de7766b-1df0-4281-9322-650068da2a2d",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "7e9de3f7-4a3d-41d0-a6b6-1bc8fc98fa6e",
49
+ "metadata": {
50
+ "scrolled": true
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "!pip install huggingface_hub\n",
55
+ "!pip install hf-transfer\n",
56
+ "\n",
57
+ "!pip install ipywidgets\n",
58
+ "!pip install tqdm\n",
59
+ "\n",
60
+ "!pip install einx\n",
61
+ "!pip install einops\n",
62
+ "!pip install torch-summary\n",
63
+ "!pip install scikit-learn\n",
64
+ "!pip install matplotlib"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "id": "68799e16-da90-4f1b-97c8-813bd5df665e",
70
+ "metadata": {},
71
+ "source": [
72
+ "# Import modules"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "6073f1b3-edca-49b1-bfed-2029a9efda35",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "# Load modules and make data dir\n",
83
+ "\n",
84
+ "print('Loading modules...')\n",
85
+ "\n",
86
+ "import os\n",
87
+ "\n",
88
+ "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n",
89
+ "\n",
90
+ "import pickle\n",
91
+ "import random\n",
92
+ "import tqdm\n",
93
+ "\n",
94
+ "!set USE_FLASH_ATTENTION=1\n",
95
+ "os.environ['USE_FLASH_ATTENTION'] = '1'\n",
96
+ "\n",
97
+ "import torch\n",
98
+ "import numpy as np\n",
99
+ "\n",
100
+ "from torchsummary import summary\n",
101
+ "from sklearn.metrics.pairwise import cosine_similarity\n",
102
+ "\n",
103
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/\n",
104
+ "\n",
105
+ "import TMIDIX\n",
106
+ "\n",
107
+ "%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer\n",
108
+ "\n",
109
+ "from x_transformer_2_3_1 import *\n",
110
+ "\n",
111
+ "torch.set_float32_matmul_precision('high')\n",
112
+ "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n",
113
+ "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n",
114
+ "torch.backends.cuda.enable_flash_sdp(True)\n",
115
+ "torch.backends.cuda.enable_cudnn_sdp(False)\n",
116
+ "\n",
117
+ "!set USE_FLASH_ATTENTION=1\n",
118
+ "\n",
119
+ "%cd /home/ubuntu/\n",
120
+ "\n",
121
+ "import random\n",
122
+ "\n",
123
+ "from huggingface_hub import hf_hub_download\n",
124
+ "\n",
125
+ "print('Done')\n",
126
+ "\n",
127
+ "print('Torch version:', torch.__version__)"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "markdown",
132
+ "id": "f94d805e-ac8a-400b-9e9c-a6ff572c4b80",
133
+ "metadata": {},
134
+ "source": [
135
+ "# Download Orpheus model and Orpheus embeddings dataset"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "id": "2f8d75d4-982b-4a60-a234-afc71aa6dd84",
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "print('=' * 70)\n",
146
+ "print('Donwloading Orpheus Music Transformer model...')\n",
147
+ "print('=' * 70)\n",
148
+ "\n",
149
+ "model_file = hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',\n",
150
+ " filename='Orpheus_Music_Transformer_Trained_Model_128497_steps_0.6934_loss_0.7927_acc.pth',\n",
151
+ " local_dir='/home/ubuntu/Models/',\n",
152
+ " )\n",
153
+ "\n",
154
+ "\n",
155
+ "print('=' * 70)\n",
156
+ "print('Donwloading Orpheus embeddings dataset...')\n",
157
+ "print('=' * 70)\n",
158
+ "\n",
159
+ "emb_file = hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',\n",
160
+ " filename='orpheus_data/1765807_Orpheus_Training_Data_Reference_MP_Embeddings_CC_BY_NC_SA.npy',\n",
161
+ " local_dir='/home/ubuntu/Models/',\n",
162
+ " )\n",
163
+ "\n",
164
+ "print('=' * 70)\n",
165
+ "print('Done!')\n",
166
+ "print('=' * 70)"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "id": "30147799-9bc5-4acd-8352-d5fe309bd844",
172
+ "metadata": {},
173
+ "source": [
174
+ "# Load model and embeddings"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "d95d650e-a6b6-4fca-bc3d-75bd1c042c06",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "#=================================================================\n",
185
+ "\n",
186
+ "def get_embeddings(inputs):\n",
187
+ " \n",
188
+ " with ctx:\n",
189
+ " with torch.no_grad():\n",
190
+ " out = model(inputs, return_outputs=True)\n",
191
+ " \n",
192
+ " cache = out[3]\n",
193
+ "\n",
194
+ " hidden = cache.layer_hiddens[-1]\n",
195
+ " \n",
196
+ " mean_pool = torch.mean(hidden, dim=1)\n",
197
+ " \n",
198
+ " return mean_pool.cpu().detach().numpy()\n",
199
+ "\n",
200
+ "#=================================================================\n",
201
+ "\n",
202
+ "exists_ratio = lambda sub, main, ratio: sum(x in set(main) for x in sub) / len(sub) >= ratio\n",
203
+ "\n",
204
+ "#=================================================================\n",
205
+ "\n",
206
+ "print('=' * 70)\n",
207
+ "print('Loading Orpheus Music Transformer model...')\n",
208
+ "print('=' * 70)\n",
209
+ "\n",
210
+ "SEQ_LEN = 8192\n",
211
+ "PAD_IDX = 18819\n",
212
+ "\n",
213
+ "model = TransformerWrapper(\n",
214
+ " num_tokens = PAD_IDX+1,\n",
215
+ " max_seq_len = SEQ_LEN,\n",
216
+ " attn_layers = Decoder(dim = 2048,\n",
217
+ " depth = 8,\n",
218
+ " heads = 32,\n",
219
+ " rotary_pos_emb = True,\n",
220
+ " attn_flash = True\n",
221
+ " )\n",
222
+ " )\n",
223
+ "\n",
224
+ "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n",
225
+ "\n",
226
+ "print('=' * 70)\n",
227
+ "print('Loading model checkpoint...')\n",
228
+ "\n",
229
+ "model.load_state_dict(torch.load(model_file, weights_only=True))\n",
230
+ "\n",
231
+ "print('=' * 70)\n",
232
+ "\n",
233
+ "model.cuda()\n",
234
+ "model.eval()\n",
235
+ "\n",
236
+ "print('Done!')\n",
237
+ "\n",
238
+ "summary(model)\n",
239
+ "\n",
240
+ "dtype = torch.bfloat16\n",
241
+ "\n",
242
+ "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)\n",
243
+ "\n",
244
+ "#=================================================================\n",
245
+ "\n",
246
+ "print('=' * 70)\n",
247
+ "print('Loading Orpheus embeddings dataset...')\n",
248
+ "print('=' * 70)\n",
249
+ "\n",
250
+ "embeddings = np.load(emb_file)\n",
251
+ "\n",
252
+ "print('=' * 70)\n",
253
+ "print('Done!')\n",
254
+ "print('=' * 70)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "markdown",
259
+ "id": "e3a4cd32-4680-4d35-b46e-0022369715b7",
260
+ "metadata": {},
261
+ "source": [
262
+ "# Create IO dirs"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": null,
268
+ "id": "d27279a0-2892-4aca-a074-1ebe3e82bb94",
269
+ "metadata": {},
270
+ "outputs": [],
271
+ "source": [
272
+ "print('=' * 70) \n",
273
+ "print('Creating IO dirs...')\n",
274
+ "\n",
275
+ "input_midis_dir = '/home/ubuntu/Input MIDIs/'\n",
276
+ "output_midis_dir = '/home/ubuntu/Output MIDIs/'\n",
277
+ "\n",
278
+ "midi_files_list = []\n",
279
+ "\n",
280
+ "os.makedirs(input_midis_dir, exist_ok=True)\n",
281
+ "os.makedirs(output_midis_dir, exist_ok=True)\n",
282
+ "\n",
283
+ "print('Done!')\n",
284
+ "print('=' * 70) "
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "markdown",
289
+ "id": "03e505df-42bc-4246-8b11-1e98e7b2515a",
290
+ "metadata": {},
291
+ "source": [
292
+ "# Create MIDIs files list"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": null,
298
+ "id": "829f6093-67f5-4655-83c7-8af36ef60079",
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "print('=' * 70)\n",
303
+ "print('Creating MIDI files list...')\n",
304
+ "print('=' * 70) \n",
305
+ "\n",
306
+ "midi_files_list = TMIDIX.create_files_list([input_midis_dir])\n",
307
+ "print('=' * 70) "
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "markdown",
312
+ "id": "e2e041a5-5564-41d1-a2b2-0cc92dc713ae",
313
+ "metadata": {},
314
+ "source": [
315
+ "# Generate"
316
+ ]
317
+ },
318
+ {
319
+ "cell_type": "code",
320
+ "execution_count": null,
321
+ "id": "cb9a7864-028b-468f-9fda-72abe84d6edc",
322
+ "metadata": {
323
+ "scrolled": true
324
+ },
325
+ "outputs": [],
326
+ "source": [
327
+ "print('=' * 70) \n",
328
+ "print('Orpheus Auto-Continuations Generator')\n",
329
+ "print('=' * 70)\n",
330
+ "\n",
331
+ "#=========================================================================\n",
332
+ "# Generation options\n",
333
+ "#=========================================================================\n",
334
+ "\n",
335
+ "# Primary generation options\n",
336
+ "num_prime_tokens = 1024\n",
337
+ "num_songs_per_midi = 4\n",
338
+ "num_gen_chunks = 12\n",
339
+ "max_num_tries = 4\n",
340
+ "\n",
341
+ "# Model sampling options\n",
342
+ "num_gen_tokens = 512\n",
343
+ "batch_size = 32\n",
344
+ "temperature = 1.0\n",
345
+ "top_p_value = 0.96\n",
346
+ "num_mem_tokens = 7168 # up to 12 chunks\n",
347
+ "\n",
348
+ "# Advanced options\n",
349
+ "max_tok_rep_ratio = 0.95\n",
350
+ "num_rep_window_toks = 1024\n",
351
+ "num_emb_tokens = 1024\n",
352
+ "\n",
353
+ "# Aux options\n",
354
+ "score_var = 0.05\n",
355
+ "batch_size_step = 4\n",
356
+ "\n",
357
+ "#=========================================================================\n",
358
+ "\n",
359
+ "if not midi_files_list:\n",
360
+ " \n",
361
+ " print('=' * 70)\n",
362
+ " print('Generating prime tokens...')\n",
363
+ " print('=' * 70)\n",
364
+ "\n",
365
+ " x = torch.LongTensor([[18816, 0]] * batch_size).cuda()\n",
366
+ "\n",
367
+ " with ctx:\n",
368
+ " out = model.generate(x,\n",
369
+ " num_prime_tokens,\n",
370
+ " temperature=temperature,\n",
371
+ " filter_logits_fn=top_p,\n",
372
+ " filter_kwargs={'thres': top_p_value},\n",
373
+ " return_prime=True,\n",
374
+ " verbose=True)\n",
375
+ "\n",
376
+ " y = out.tolist()\n",
377
+ " \n",
378
+ " inp = torch.LongTensor(y).cuda()\n",
379
+ " \n",
380
+ " embs = get_embeddings(inp)\n",
381
+ " \n",
382
+ " scores = cosine_similarity(embeddings, embs).max(axis=0)\n",
383
+ "\n",
384
+ " scores = [o for o in scores if o != max(scores)]\n",
385
+ "\n",
386
+ " max_score = max(scores)\n",
387
+ "\n",
388
+ " max_score_idx = scores.index(max_score)\n",
389
+ " melody_chords = y[max_score_idx]\n",
390
+ "\n",
391
+ " midi_fname = 'Improvisation'\n",
392
+ " midi_files_list.append(midi_fname)\n",
393
+ " \n",
394
+ " print('=' * 70)\n",
395
+ " print('Done!')\n",
396
+ " print('=' * 70)\n",
397
+ " print('Generating songs for \"Improvisation\"')\n",
398
+ " print('=' * 70)\n",
399
+ " \n",
400
+ "#=========================================================================\n",
401
+ "\n",
402
+ "for midi_file in midi_files_list:\n",
403
+ "\n",
404
+ " if midi_file != 'Improvisation':\n",
405
+ "\n",
406
+ " midi_fname = os.path.splitext(os.path.basename(midi_file))[0]\n",
407
+ " \n",
408
+ " print('=' * 70)\n",
409
+ " print('Generating songs for MIDI file \"' + midi_fname + '\"')\n",
410
+ " print('-' * 70) \n",
411
+ " \n",
412
+ " #==============================================================================\n",
413
+ " \n",
414
+ " raw_score = TMIDIX.midi2single_track_ms_score(midi_file)\n",
415
+ " \n",
416
+ " escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)\n",
417
+ " \n",
418
+ " escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True)\n",
419
+ " \n",
420
+ " escore_notes = TMIDIX.remove_duplicate_pitches_from_escore_notes(escore_notes)\n",
421
+ " \n",
422
+ " escore_notes = TMIDIX.fix_escore_notes_durations(escore_notes, min_notes_gap=0)\n",
423
+ " \n",
424
+ " dscore = TMIDIX.delta_score_notes(escore_notes)\n",
425
+ " \n",
426
+ " dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])\n",
427
+ " \n",
428
+ " melody_chords = [18816]\n",
429
+ " \n",
430
+ " #=======================================================\n",
431
+ " # MAIN PROCESSING CYCLE\n",
432
+ " #=======================================================\n",
433
+ " \n",
434
+ " for i, c in enumerate(dcscore):\n",
435
+ " \n",
436
+ " delta_time = c[0][0]\n",
437
+ " \n",
438
+ " melody_chords.append(delta_time)\n",
439
+ " \n",
440
+ " for e in c:\n",
441
+ " \n",
442
+ " #=======================================================\n",
443
+ " \n",
444
+ " # Durations\n",
445
+ " dur = max(1, min(255, e[1]))\n",
446
+ " \n",
447
+ " # Patches\n",
448
+ " pat = max(0, min(128, e[5]))\n",
449
+ " \n",
450
+ " # Pitches\n",
451
+ " ptc = max(1, min(127, e[3]))\n",
452
+ " \n",
453
+ " # Velocities\n",
454
+ " # Calculating octo-velocity\n",
455
+ " \n",
456
+ " vel = max(8, min(127, e[4]))\n",
457
+ " velocity = round(vel / 15)-1\n",
458
+ " \n",
459
+ " #=======================================================\n",
460
+ " # FINAL NOTE SEQ\n",
461
+ " #=======================================================\n",
462
+ " \n",
463
+ " # Writing final note\n",
464
+ " pat_ptc = (128 * pat) + ptc \n",
465
+ " dur_vel = (8 * dur) + velocity\n",
466
+ " \n",
467
+ " melody_chords.extend([pat_ptc+256, dur_vel+16768]) # 18816\n",
468
+ "\n",
469
+ " #==============================================================================\n",
470
+ "\n",
471
+ " print('Total number of input tokens:', len(melody_chords))\n",
472
+ " print('=' * 70)\n",
473
+ "\n",
474
+ " #==============================================================================\n",
475
+ "\n",
476
+ " song_number = 0\n",
477
+ " \n",
478
+ " while song_number < num_songs_per_midi:\n",
479
+ "\n",
480
+ " print('Generating song #', song_number+1, '/', num_songs_per_midi)\n",
481
+ " print('=' * 70)\n",
482
+ " \n",
483
+ " song = melody_chords[:num_prime_tokens][-num_mem_tokens:]\n",
484
+ " \n",
485
+ " inp = torch.LongTensor([song]).cuda()\n",
486
+ " \n",
487
+ " embs = get_embeddings(inp)\n",
488
+ " \n",
489
+ " start_score = cosine_similarity(embeddings, embs).max(axis=0)[0]\n",
490
+ "\n",
491
+ " b_size = batch_size\n",
492
+ " stop = False\n",
493
+ " \n",
494
+ " for i in tqdm.tqdm(range(num_gen_chunks)):\n",
495
+ " \n",
496
+ " max_score = -1\n",
497
+ " num_tries = 0\n",
498
+ " \n",
499
+ " if i > 7:\n",
500
+ " bsize = b_size - batch_size_step\n",
501
+ " \n",
502
+ " while max_score < start_score - score_var:\n",
503
+ " \n",
504
+ " output = []\n",
505
+ " output_scores = []\n",
506
+ " \n",
507
+ " x = torch.LongTensor([song[-num_mem_tokens:]] * b_size).cuda()\n",
508
+ " \n",
509
+ " with ctx:\n",
510
+ " out = model.generate(x,\n",
511
+ " num_gen_tokens,\n",
512
+ " temperature=temperature,\n",
513
+ " filter_logits_fn=top_p,\n",
514
+ " filter_kwargs={'thres': top_p_value},\n",
515
+ " return_prime=True,\n",
516
+ " verbose=False)\n",
517
+ " \n",
518
+ " y = out.tolist()\n",
519
+ " \n",
520
+ " for yy in y:\n",
521
+ " if 18817 not in yy and 18818 not in yy and not exists_ratio(yy[-num_gen_tokens:], \n",
522
+ " song[-num_rep_window_toks:], \n",
523
+ " max_tok_rep_ratio\n",
524
+ " ):\n",
525
+ " output.append(yy[-num_emb_tokens:])\n",
526
+ " \n",
527
+ " if output:\n",
528
+ " \n",
529
+ " inp = torch.LongTensor(output).cuda()\n",
530
+ " \n",
531
+ " embs = get_embeddings(inp)\n",
532
+ " \n",
533
+ " scores = cosine_similarity(embeddings, embs).max(axis=0)\n",
534
+ " output_scores.extend(scores)\n",
535
+ " \n",
536
+ " scores = [o for o in output_scores if o != max(output_scores)]\n",
537
+ " \n",
538
+ " if not scores:\n",
539
+ " max_score = -1\n",
540
+ " num_tries += 1\n",
541
+ " \n",
542
+ " if num_tries == max_num_tries:\n",
543
+ " stop = True\n",
544
+ " break\n",
545
+ " \n",
546
+ " if i > max_num_tries:\n",
547
+ " song = song[:-num_gen_tokens]\n",
548
+ " \n",
549
+ " else:\n",
550
+ " max_score = max(scores)\n",
551
+ " \n",
552
+ " else:\n",
553
+ " num_tries += 1\n",
554
+ " \n",
555
+ " if num_tries == max_num_tries:\n",
556
+ " stop = True\n",
557
+ " break\n",
558
+ " \n",
559
+ " if i > max_num_tries:\n",
560
+ " song = song[:-num_gen_tokens]\n",
561
+ " \n",
562
+ " if stop:\n",
563
+ " break\n",
564
+ " \n",
565
+ " max_score_idx = output_scores.index(max_score)\n",
566
+ " max_score_chunk = output[max_score_idx]\n",
567
+ " \n",
568
+ " song.extend(max_score_chunk[-num_gen_tokens:])\n",
569
+ "\n",
570
+ " #==============================================================================\n",
571
+ " \n",
572
+ " if i > num_gen_chunks // 2:\n",
573
+ "\n",
574
+ " print('=' * 70)\n",
575
+ " print('Saving song...')\n",
576
+ " print('=' * 70)\n",
577
+ " \n",
578
+ " print('Sample INTs', song[:15])\n",
579
+ " \n",
580
+ " song_f = []\n",
581
+ " \n",
582
+ " time = 0\n",
583
+ " dur = 1\n",
584
+ " vel = 90\n",
585
+ " pitch = 60\n",
586
+ " channel = 0\n",
587
+ " patch = 0\n",
588
+ " \n",
589
+ " patches = [-1] * 16\n",
590
+ " \n",
591
+ " channels = [0] * 16\n",
592
+ " channels[9] = 1\n",
593
+ " \n",
594
+ " for ss in song:\n",
595
+ " \n",
596
+ " if 0 <= ss < 256:\n",
597
+ " \n",
598
+ " time += ss * 16\n",
599
+ " \n",
600
+ " if 256 <= ss < 16768:\n",
601
+ " \n",
602
+ " patch = (ss-256) // 128\n",
603
+ " \n",
604
+ " if patch < 128:\n",
605
+ " \n",
606
+ " if patch not in patches:\n",
607
+ " if 0 in channels:\n",
608
+ " cha = channels.index(0)\n",
609
+ " channels[cha] = 1\n",
610
+ " else:\n",
611
+ " cha = 15\n",
612
+ " \n",
613
+ " patches[cha] = patch\n",
614
+ " channel = patches.index(patch)\n",
615
+ " else:\n",
616
+ " channel = patches.index(patch)\n",
617
+ " \n",
618
+ " if patch == 128:\n",
619
+ " channel = 9\n",
620
+ " \n",
621
+ " pitch = (ss-256) % 128\n",
622
+ " \n",
623
+ " \n",
624
+ " if 16768 <= ss < 18816:\n",
625
+ " \n",
626
+ " dur = ((ss-16768) // 8) * 16\n",
627
+ " vel = (((ss-16768) % 8)+1) * 15\n",
628
+ " \n",
629
+ " song_f.append(['note', time, dur, channel, pitch, vel, patch])\n",
630
+ " \n",
631
+ " patches = [0 if x==-1 else x for x in patches]\n",
632
+ "\n",
633
+ " output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)\n",
634
+ "\n",
635
+ " output_dir = os.path.join(output_midis_dir, midi_fname)\n",
636
+ "\n",
637
+ " os.makedirs(output_dir, exist_ok=True)\n",
638
+ " \n",
639
+ " detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,\n",
640
+ " output_signature = 'Orpheus Music Transformer',\n",
641
+ " output_file_name = output_dir + '/Orpheus-Music-Transformer-Composition_'+str(song_number+1).zfill(3),\n",
642
+ " track_name='Project Los Angeles',\n",
643
+ " list_of_MIDI_patches=patches\n",
644
+ " )\n",
645
+ "\n",
646
+ " song_number += 1\n",
647
+ " \n",
648
+ " print('=' * 70)\n",
649
+ " print('Done!')\n",
650
+ " print('=' * 70)"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "markdown",
655
+ "id": "f892ac8a-9f5f-462d-b3b9-d4be1f78b31d",
656
+ "metadata": {},
657
+ "source": [
658
+ "# Congrats! You did it ! :)"
659
+ ]
660
+ }
661
+ ],
662
+ "metadata": {
663
+ "kernelspec": {
664
+ "display_name": "Python 3 (ipykernel)",
665
+ "language": "python",
666
+ "name": "python3"
667
+ },
668
+ "language_info": {
669
+ "codemirror_mode": {
670
+ "name": "ipython",
671
+ "version": 3
672
+ },
673
+ "file_extension": ".py",
674
+ "mimetype": "text/x-python",
675
+ "name": "python",
676
+ "nbconvert_exporter": "python",
677
+ "pygments_lexer": "ipython3",
678
+ "version": "3.10.12"
679
+ }
680
+ },
681
+ "nbformat": 4,
682
+ "nbformat_minor": 5
683
+ }