{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "VGrGd6__l5ch" }, "source": [ "# Orpheus Drums Transformer (ver. 1.0)\n", "\n", "***\n", "\n", "Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools\n", "\n", "***\n", "\n", "WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please excercise great humility, care, and respect. https://www.nscai.gov/\n", "\n", "***\n", "\n", "#### Project Los Angeles\n", "\n", "#### Tegridy Code 2025\n", "\n", "***" ] }, { "cell_type": "markdown", "metadata": { "id": "shLrgoXdl5cj" }, "source": [ "# GPU check" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X3rABEpKCO02" }, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": { "id": "0RcVC4btl5ck" }, "source": [ "# Setup environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "viHgEaNACPTs" }, "outputs": [], "source": [ "!git clone --depth 1 https://github.com/asigalov61/tegridy-tools" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vK40g6V_BTNj" }, "outputs": [], "source": [ "!pip install huggingface_hub\n", "!pip install hf-transfer\n", "!pip install ipywidgets\n", "!pip install tqdm\n", "\n", "!pip install einx\n", "!pip install einops\n", "!pip install torch-summary" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DzCOZU_gBiQV" }, "outputs": [], "source": [ "# Load modules and make data dir\n", "\n", "print('Loading modules...')\n", "\n", "import os\n", "\n", "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\"\n", "\n", "import pickle\n", "import random\n", "import secrets\n", "import tqdm\n", "import math\n", "\n", "import gc\n", "\n", "!set USE_FLASH_ATTENTION=1\n", "os.environ['USE_FLASH_ATTENTION'] = '1'\n", "\n", "import torch\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from torchsummary import summary\n", "from sklearn import metrics\n", "\n", "%cd /home/ubuntu/tegridy-tools/tegridy-tools/\n", "\n", "import TMIDIX\n", "\n", "%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer\n", "\n", "from x_transformer_2_3_1 import *\n", "\n", "torch.set_float32_matmul_precision('high')\n", "torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul\n", "torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn\n", "torch.backends.cuda.enable_flash_sdp(True)\n", "torch.backends.cuda.enable_cudnn_sdp(False)\n", "\n", "!set USE_FLASH_ATTENTION=1\n", "\n", "%cd /home/ubuntu/\n", "\n", "import random\n", "\n", "from huggingface_hub import hf_hub_download\n", "\n", "print('Done')\n", "\n", "print('Torch version:', torch.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "feXay_Ed7mG5" }, "source": [ "# Download model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SA8qQSzbWslM" }, "outputs": [], "source": [ "hf_hub_download(repo_id='asigalov61/Orpheus-Music-Transformer',\n", " filename='Orpheus_Music_Transformer_Trained_Model_96332_steps_0.82_loss_0.748_acc.pth',\n", " local_dir='/home/ubuntu/Models/',\n", " repo_type='model'\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Load model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gSvqSRLaWslM" }, "outputs": [], "source": [ "SEQ_LEN = 8192\n", "PAD_IDX = 18819\n", "\n", "model = TransformerWrapper(num_tokens = PAD_IDX+1,\n", " max_seq_len = SEQ_LEN,\n", " attn_layers = Decoder(dim = 2048,\n", " depth = 8,\n", " heads = 32,\n", " 
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Load model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "gSvqSRLaWslM" }, "outputs": [], "source": [ "SEQ_LEN = 8192\n", "PAD_IDX = 18819\n", "\n", "model = TransformerWrapper(num_tokens = PAD_IDX+1,\n", "                           max_seq_len = SEQ_LEN,\n", "                           attn_layers = Decoder(dim = 2048,\n", "                                                 depth = 8,\n", "                                                 heads = 32,\n", "                                                 rotary_pos_emb = True,\n", "                                                 attn_flash = True\n", "                                                 )\n", "                           )\n", "\n", "model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)\n", "\n", "print('=' * 70)\n", "print('Loading model checkpoint...')\n", "\n", "model_path = 'Models/Orpheus_Music_Transformer_Trained_Model_96332_steps_0.82_loss_0.748_acc.pth'\n", "\n", "model.load_state_dict(torch.load(model_path))\n", "\n", "print('=' * 70)\n", "\n", "model.cuda()\n", "model.eval()\n", "\n", "print('Done!')\n", "\n", "summary(model)\n", "\n", "dtype = torch.bfloat16\n", "\n", "ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Load MIDI" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "enHpaHxaWslM" }, "outputs": [], "source": [ "midi_file = 'tegridy-tools/tegridy-tools/seed2.mid'\n", "\n", "print('=' * 70)\n", "print('Loading MIDI...')\n", "\n", "raw_score = TMIDIX.midi2single_track_ms_score(midi_file)\n", "\n", "escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)\n", "\n", "if escore_notes:\n", "\n", "    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True)\n", "\n", "    # Remove any existing drums (channel 9) so that new drums can be generated\n", "    escore_notes = TMIDIX.recalculate_score_timings([e for e in escore_notes if e[3] != 9])\n", "\n", "    dscore = TMIDIX.delta_score_notes(escore_notes)\n", "\n", "    dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])\n", "\n", "    melody_chords = [18816] # Start (SOS) token\n", "\n", "    chords = []\n", "\n", "    #=======================================================\n", "    # MAIN PROCESSING CYCLE\n", "    #=======================================================\n", "\n", "    for i, c in enumerate(dcscore):\n", "\n", "        delta_time = c[0][0]\n", "\n", "        melody_chords.append(delta_time)\n", "\n", "        cho = []\n", "\n", "        cho.append(delta_time)\n", "\n", "        for e in c:\n", "\n", "            #=======================================================\n", "\n", "            # Durations\n", "            dur = max(1, min(255, e[1]))\n", "\n", "            # Patches\n", "            pat = max(0, min(128, e[5]))\n", "\n", "            # Pitches\n", "            ptc = max(1, min(127, e[3]))\n", "\n", "            # Velocities\n", "            # Calculating octo-velocity (8 velocity bins)\n", "            vel = max(8, min(127, e[4]))\n", "            velocity = round(vel / 15)-1\n", "\n", "            #=======================================================\n", "            # FINAL NOTE SEQ\n", "            #=======================================================\n", "\n", "            # Writing final note\n", "            pat_ptc = (128 * pat) + ptc\n", "            dur_vel = (8 * dur) + velocity\n", "\n", "            melody_chords.extend([pat_ptc+256, dur_vel+16768]) # All note tokens are < 18816\n", "            cho.extend([pat_ptc+256, dur_vel+16768])\n", "\n", "        chords.append(cho)\n", "\n", "    print('Done!')\n", "    print('=' * 70)\n", "    print('Score has', len(melody_chords), 'tokens')\n", "    print('Score has', len(chords), 'chords')\n", "    print('=' * 70)\n", "\n", "else:\n", "    print('Error! Check MIDI file!')" ] },
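{ "cell_type": "markdown", "metadata": {}, "source": [ "# Token encoding: a worked example\n", "\n", "An illustrative cell, not part of the original pipeline, showing how a single note maps to the token scheme used above: tokens 0-255 are delta times, 256-16767 encode (patch, pitch) pairs, 16768-18815 encode (duration, velocity) pairs, and 18816+ are special tokens (18816 = start, 18818 = end-of-sequence, 18819 = padding, per the model setup)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A worked example of the token scheme used above (illustration only)\n", "\n", "patch, pitch = 0, 60 # example values: piano, middle C\n", "dur, vel = 32, 90 # duration in 16 ms steps, MIDI velocity\n", "\n", "velocity = round(max(8, min(127, vel)) / 15) - 1 # 8 velocity bins\n", "pat_ptc_tok = (128 * patch) + pitch + 256\n", "dur_vel_tok = (8 * max(1, min(255, dur))) + velocity + 16768\n", "\n", "# Decoding, mirroring the 'Save to MIDI' cell below\n", "assert (pat_ptc_tok - 256) // 128 == patch\n", "assert (pat_ptc_tok - 256) % 128 == pitch\n", "assert (dur_vel_tok - 16768) // 8 == dur\n", "assert (dur_vel_tok - 16768) % 8 == velocity\n", "\n", "print('pat_ptc token:', pat_ptc_tok, '| dur_vel token:', dur_vel_tok)" ] },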
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Texture chords" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "w6Z3HJ313EL_" }, "outputs": [], "source": [ "model_temperature = 1.0\n", "model_sampling_top_p = 0.96\n", "\n", "#==================================================================\n", "\n", "print('=' * 70)\n", "print('Sample score tokens', melody_chords[:10])\n", "\n", "#==================================================================\n", "\n", "def gen_drums(seq):\n", "\n", "    # Drum pitch tokens (16641-16767) and dur/vel tokens (16768-18815)\n", "    # both sit above 16640, so keep sampling while the model continues\n", "    # the drum part; stop on anything else (incl. the EOS token 18818)\n", "\n", "    y = 16641\n", "    num_gen_drums = 0\n", "\n", "    while 16640 < y < 18816:\n", "\n", "        x = torch.LongTensor(seq).cuda()\n", "\n", "        with ctx:\n", "            out = model.generate(x,\n", "                                 1,\n", "                                 temperature=model_temperature,\n", "                                 filter_logits_fn=top_p,\n", "                                 filter_kwargs={'thres': model_sampling_top_p},\n", "                                 return_prime=False,\n", "                                 eos_token=18818,\n", "                                 verbose=False)\n", "\n", "        y = out.tolist()[0]\n", "\n", "        if 16640 < y < 18816:\n", "            seq.append(y)\n", "            num_gen_drums += 1\n", "\n", "        if num_gen_drums == 10: # cap at 10 tokens (5 drum notes) per insertion point\n", "            break\n", "\n", "    return seq\n", "\n", "#==================================================================\n", "\n", "print('=' * 70)\n", "print('Generating...')\n", "print('=' * 70)\n", "\n", "final_song = [18816]\n", "\n", "for i in tqdm.tqdm(range(len(chords))):\n", "\n", "    final_song.extend(chords[i])\n", "\n", "    if i == 0:\n", "        final_song.append((128*128)+38+256) # Seed drum patch/pitch token (snare)\n", "        final_song.append((8*8)+5+16768) # Seed drum dur/vel token\n", "\n", "    # Every eighth chord, force a drum onset if the last note is not a drum\n", "    if (final_song[-2] < 16640 and i % 8 == 0):\n", "        final_song.append((128*128)+38+256) # Drum patch/pitch token (snare)\n", "\n", "    final_song = gen_drums(final_song)\n", "\n", "#==================================================================\n", "\n", "print('=' * 70)\n", "print('Done!')\n", "print('=' * 70)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Save to MIDI" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "tlBzqWpAnZna" }, "outputs": [], "source": [ "print('Sample INTs', final_song[:15])\n", "\n", "if len(final_song) != 0:\n", "\n", "    song_f = []\n", "\n", "    time = 0\n", "    dur = 1\n", "    vel = 90\n", "    pitch = 60\n", "    channel = 0\n", "    patch = 0\n", "\n", "    patches = [-1] * 16\n", "\n", "    channels = [0] * 16\n", "    channels[9] = 1 # reserve channel 9 for drums\n", "\n", "    for ss in final_song:\n", "\n", "        if 0 <= ss < 256: # delta time\n", "\n", "            time += ss * 16\n", "\n", "        if 256 <= ss < 16768: # patch/pitch\n", "\n", "            patch = (ss-256) // 128\n", "\n", "            if patch < 128:\n", "\n", "                if patch not in patches:\n", "                    if 0 in channels:\n", "                        cha = channels.index(0)\n", "                        channels[cha] = 1\n", "                    else:\n", "                        cha = 15\n", "\n", "                    patches[cha] = patch\n", "                    channel = patches.index(patch)\n", "                else:\n", "                    channel = patches.index(patch)\n", "\n", "            if patch == 128:\n", "                channel = 9\n", "\n", "            pitch = (ss-256) % 128\n", "\n", "        if 16768 <= ss < 18816: # duration/velocity\n", "\n", "            dur = ((ss-16768) // 8) * 16\n", "            vel = (((ss-16768) % 8)+1) * 15\n", "\n", "            song_f.append(['note', time, dur, channel, pitch, vel, patch])\n", "\n", "    patches = [0 if x==-1 else x for x in patches]\n", "\n", "    output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)\n", "\n", "    fn1 = \"Orpheus-Drums-Transformer-Composition\"\n", "\n", "    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,\n", "                                                              output_signature = 'Orpheus Drums Transformer',\n", "                                                              output_file_name = fn1,\n", "                                                              track_name='Project Los Angeles',\n", "                                                              list_of_MIDI_patches=patches\n", "                                                              )\n", "\n", "    print('Done!')" ] },
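{ "cell_type": "markdown", "metadata": {}, "source": [ "# Optional: token-type breakdown\n", "\n", "An illustrative cell, not part of the original pipeline: it tallies the generated sequence by token type, using the same range boundaries the generation loop above relies on (drum patch/pitch tokens sit above 16640)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional breakdown of the generated sequence (illustration only)\n", "\n", "num_times = sum(1 for t in final_song if t < 256) # delta times\n", "num_drums = sum(1 for t in final_song if 16640 < t < 16768) # drum (patch 128) pitches\n", "num_notes = sum(1 for t in final_song if 256 <= t <= 16640) # non-drum patch/pitch pairs\n", "num_dur_vel = sum(1 for t in final_song if 16768 <= t < 18816) # duration/velocity pairs\n", "\n", "print('Delta-time tokens:', num_times)\n", "print('Drum pitch tokens:', num_drums)\n", "print('Non-drum pitch tokens:', num_notes)\n", "print('Duration/velocity tokens:', num_dur_vel)" ] },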
"metadata": {}, "source": [ "# Plot tokens embeddings" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "al3TDlH7T8m7" }, "outputs": [], "source": [ "tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()\n", "\n", "cos_sim = metrics.pairwise_distances(\n", " tok_emb, metric='cosine'\n", ")\n", "plt.figure(figsize=(7, 7))\n", "plt.imshow(cos_sim, cmap=\"inferno\", interpolation=\"nearest\")\n", "im_ratio = cos_sim.shape[0] / cos_sim.shape[1]\n", "plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)\n", "plt.xlabel(\"Position\")\n", "plt.ylabel(\"Position\")\n", "plt.tight_layout()\n", "plt.plot()\n", "plt.savefig(\"/home/ubuntu/Orpheus-Drums-Transformer-Tokens-Embeddings-Plot.png\", bbox_inches=\"tight\")" ] }, { "cell_type": "markdown", "metadata": { "id": "z87TlDTVl5cp" }, "source": [ "# Congrats! You did it! :)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuClass": "premium", "gpuType": "T4", "private_outputs": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }