# Orpheus Drums Transformer (ver. 1.0)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2025

***

# GPU check

In [None]:
# Print NVIDIA driver/GPU status; later cells call .cuda(), so a CUDA device must be available
!nvidia-smi

# Setup environment

In [None]:
# Shallow-clone (latest commit only) tegridy-tools into the current working directory.
# NOTE(review): later cells %cd into /home/ubuntu/tegridy-tools/..., so this cell
# presumably has to be run from /home/ubuntu - confirm for other environments.
!git clone --depth 1 https://github.com/asigalov61/tegridy-tools

In [None]:
# Model download (huggingface_hub; hf-transfer is switched on via an env var in the
# next cell) and progress-bar support (tqdm; ipywidgets presumably for notebook widgets)
!pip install huggingface_hub
!pip install hf-transfer
!pip install ipywidgets
!pip install tqdm

# einx / einops are presumably required by the bundled x_transformer module - confirm;
# torch-summary provides the `torchsummary` module imported in the next cell
!pip install einx
!pip install einops
!pip install torch-summary

In [None]:
# Load modules and configure the runtime (env vars, working dirs, torch backends)

print('Loading modules...')

import os

# Opt in to the hf-transfer download backend; set before huggingface_hub is imported below
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import pickle
import random
import secrets
import tqdm
import math

import gc

# NOTE(review): `!` runs the command in a separate shell process and `set VAR=value`
# is Windows cmd syntax, so this line presumably has no effect on this kernel's
# environment - the os.environ assignment right after it is the in-process setting.
!set USE_FLASH_ATTENTION=1
os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

# Switch into the cloned repo so that TMIDIX can be imported from the working directory
%cd /home/ubuntu/tegridy-tools/tegridy-tools/

import TMIDIX

# Same for the bundled x_transformer module
%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer

# Star import: presumably the source of TransformerWrapper, Decoder,
# AutoregressiveWrapper and top_p used by later cells - verify against the module
from x_transformer_2_3_1 import *

torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(False)

# NOTE(review): same presumably-ineffective shell command as above
!set USE_FLASH_ATTENTION=1

# Return to the base dir; later cells use paths relative to it
# ('Models/...', 'tegridy-tools/tegridy-tools/seed2.mid')
%cd /home/ubuntu/

# Duplicate of the `import random` above (harmless)
import random

from huggingface_hub import hf_hub_download

print('Done')

print('Torch version:', torch.__version__)

# Download model

In [None]:
# Fetch the pre-trained checkpoint file from the Hugging Face Hub model repo
# into /home/ubuntu/Models/ (the "Load model" cell reads it from there).
checkpoint_request = {
    'repo_id': 'asigalov61/Orpheus-Music-Transformer',
    'filename': 'Orpheus_Music_Transformer_Trained_Model_96332_steps_0.82_loss_0.748_acc.pth',
    'local_dir': '/home/ubuntu/Models/',
    'repo_type': 'model',
}

# Kept as the cell's last expression so the notebook still displays the returned value
hf_hub_download(**checkpoint_request)

# Load model

In [None]:
# Model hyper-parameters.
# NOTE(review): these presumably have to match the configuration the checkpoint
# was trained with, otherwise load_state_dict() below cannot map the weights.
SEQ_LEN = 8192
PAD_IDX = 18819 # Highest token id; doubles as the pad / ignore index

model = TransformerWrapper(num_tokens = PAD_IDX+1,
                           max_seq_len = SEQ_LEN,
                           attn_layers = Decoder(dim = 2048,
                                                 depth = 8,
                                                 heads = 32,
                                                 rotary_pos_emb = True,
                                                 attn_flash = True
                                                 )
                           )

model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

# Relative to /home/ubuntu/ (the working directory set by the setup cell)
model_path = 'Models/Orpheus_Music_Transformer_Trained_Model_96332_steps_0.82_loss_0.748_acc.pth'

# The checkpoint is a downloaded pickle file:
#   * weights_only=True restricts unpickling to tensors / plain containers, so a
#     tampered file cannot execute arbitrary code while being loaded;
#   * map_location='cpu' loads the tensors on the CPU regardless of the device they
#     were saved from; the model is still on the CPU here and is moved to the GPU
#     once, by model.cuda() below.
state_dict = torch.load(model_path, map_location='cpu', weights_only=True)

model.load_state_dict(state_dict)

# Drop the extra copy of the weights now that they live inside the model
del state_dict

print('=' * 70)

model.cuda()
model.eval()

print('Done!')

summary(model)

# Generation runs under bfloat16 autocast; `ctx` is entered by gen_drums()
dtype = torch.bfloat16

ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)

# Load MIDI

In [None]:
# Path is relative to /home/ubuntu/ (the working directory set by the setup cell)
midi_file = 'tegridy-tools/tegridy-tools/seed2.mid'

print('=' * 70)
print('Loading MIDI...')

raw_score = TMIDIX.midi2single_track_ms_score(midi_file)

escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True, apply_sustain=True)

if escore_notes:

    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], sort_drums_last=True)

    # Drop everything on channel 9 (the existing drum track): drums are
    # re-generated by the model in the generation cell
    escore_notes = TMIDIX.recalculate_score_timings([e for e in escore_notes if e[3] != 9])

    dscore = TMIDIX.delta_score_notes(escore_notes)

    # d[1:] strips the leading event-type field, so inside the loop below
    # e[0] = delta start time, e[1] = duration, e[3] = pitch, e[4] = velocity, e[5] = patch
    dcscore = TMIDIX.chordify_score([d[1:] for d in dscore])

    # Token layout produced here (and decoded by the "Save to MIDI" cell):
    #   0     .. 255    delta start time
    #   256   .. 16767  (128 * patch) + pitch + 256
    #   16768 .. 18815  (8 * duration) + velocity + 16768
    #   18816           first token of every sequence
    melody_chords = [18816]

    # Same tokens as melody_chords, but grouped per chord: [delta_time, note tokens...]
    chords = []

    #=======================================================
    # MAIN PROCESSING CYCLE
    #=======================================================

    for c in dcscore:

        # Clamp to the time-token range like every other field below: a delta
        # above 255 (e.g. a long gap left behind by the removed drum notes) would
        # otherwise be emitted as a token the decoder reads as a patch/pitch token
        delta_time = max(0, min(255, c[0][0]))

        melody_chords.append(delta_time)

        cho = []

        cho.append(delta_time)

        for e in c:

            #=======================================================

            # Durations
            dur = max(1, min(255, e[1]))

            # Patches
            pat = max(0, min(128, e[5]))

            # Pitches
            ptc = max(1, min(127, e[3]))

            # Velocities
            # Calculating octo-velocity (8 levels, 0..7)

            vel = max(8, min(127, e[4]))
            velocity = round(vel / 15)-1

            #=======================================================
            # FINAL NOTE SEQ
            #=======================================================

            # Writing final note
            pat_ptc = (128 * pat) + ptc
            dur_vel = (8 * dur) + velocity

            melody_chords.extend([pat_ptc+256, dur_vel+16768]) # 18816
            cho.extend([pat_ptc+256, dur_vel+16768])

        chords.append(cho)

    print('Done!')
    print('=' * 70)
    print('Score has', len(melody_chords), 'tokens')
    print('Score has', len(chords), 'chords')
    print('=' * 70)

else:
    # NOTE: melody_chords / chords are not (re)defined on this path
    print('Error! Check MIDI file!')

# Generate drums

In [None]:
# Sampling settings read by gen_drums(): passed to model.generate() as
# `temperature` and as the `thres` argument of the top_p logits filter
model_temperature = 1.0
model_sampling_top_p = 0.96

#==================================================================

# Quick look at the first tokens of the encoded source score
print('=' * 70)
print('Sample score tokens', melody_chords[:10])

#==================================================================

def gen_drums(seq, max_new_tokens=10):

    """Extend `seq` in place with drum tokens sampled from the model.

    One token is sampled at a time, conditioned on the whole of `seq`, and
    appended for as long as the model keeps producing drum-note tokens. The
    loop ends at the first token outside that range (which is discarded) or
    after `max_new_tokens` tokens have been appended.

    Uses the notebook-level `model`, `ctx`, `top_p`, `model_temperature` and
    `model_sampling_top_p`.

    Args:
        seq: list of int tokens; mutated in place.
        max_new_tokens: maximum number of tokens appended per call.

    Returns:
        The same `seq` list object.
    """

    # 16640 == (128 * 128) + 256 is the drums patch (128) with pitch 0; anything
    # above it, up to 18815, is a drum patch/pitch token or a duration/velocity token
    drum_floor = 16640

    # Ids from 18816 up are not notes (18816 starts every sequence, 18818 is the
    # EOS id passed to generate(), 18819 is the pad id) and must never be appended
    # to the song as if they were drum tokens
    first_special_token = 18816

    for _ in range(max_new_tokens):

        x = torch.LongTensor(seq).cuda()

        with ctx:
            out = model.generate(x,
                                 1,
                                 temperature=model_temperature,
                                 filter_logits_fn=top_p,
                                 filter_kwargs={'thres': model_sampling_top_p},
                                 return_prime=False,
                                 eos_token=18818,
                                 verbose=False)

        y = out.tolist()[0]

        if not (drum_floor < y < first_special_token):
            break

        seq.append(y)

    return seq

#==================================================================

print('=' * 70)
print('Generating...')
print('=' * 70)

# Tokens used to force a drum note into the sequence
forced_drum_ptc = (128*128)+38+256 # Drum pitch/patch
forced_drum_dur_vel = (8*8)+5+16768 # Drum dur/vel

final_song = [18816]

for i, chord in enumerate(tqdm.tqdm(chords)):

    final_song += chord

    if i == 0:
        # Seed the song with one complete drum note right after the first chord
        final_song += [forced_drum_ptc, forced_drum_dur_vel]

    elif final_song[-2] < 16640 and i % 8 == 0:
        # Every 8th chord that ends on a non-drum note: open a drum note and
        # let the model continue from it
        final_song.append(forced_drum_ptc)

    # Let the model add drum tokens after this chord
    final_song = gen_drums(final_song)

#==================================================================

print('=' * 70)
print('Done!')
print('=' * 70)

# Save to MIDI

In [None]:
print('Sample INTs', final_song[:15])

if final_song:

    # Decoded notes: ['note', start time, duration, channel, pitch, velocity, patch]
    song_f = []

    # Running decoder state; each token kind updates only its own fields
    time, dur, vel = 0, 1, 90
    pitch, channel, patch = 60, 0, 0

    # patches[ch] is the patch assigned to MIDI channel ch (-1 = none yet);
    # channels[ch] == 1 marks a taken channel, with channel 9 reserved for drums
    patches = [-1] * 16
    channels = [0] * 9 + [1] + [0] * 6

    for ss in final_song:

        if 0 <= ss < 256:

            # Delta start time token
            time += ss * 16

        elif 256 <= ss < 16768:

            # Patch/pitch token
            patch, pitch = divmod(ss-256, 128)

            if patch == 128:
                # Drums
                channel = 9

            else:
                if patch not in patches:
                    # First use of this patch: take the first free channel,
                    # falling back to channel 15 once all of them are taken
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        cha = 15

                    patches[cha] = patch

                channel = patches.index(patch)

        elif 16768 <= ss < 18816:

            # Duration/velocity token; it completes the current note
            dur_vel = ss-16768

            dur = (dur_vel // 8) * 16
            vel = ((dur_vel % 8)+1) * 15

            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    patches = [p if p != -1 else 0 for p in patches]

# `patches` is replaced here by the list returned from TMIDIX
output_score, patches, overflow_patches = TMIDIX.patch_enhanced_score_notes(song_f)

fn1 = "Orpheus-Drums-Transformer-Composition"

detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
    output_score,
    output_signature = 'Orpheus Drums Transformer',
    output_file_name = fn1,
    track_name='Project Los Angeles',
    list_of_MIDI_patches=patches
)

print('Done!')

# Plot tokens embeddings

In [None]:
# Token embedding matrix as a float64 NumPy array. Going through NumPy avoids
# materialising every weight as a separate Python float in nested lists;
# .double() keeps the values and dtype the list-based conversion used to yield.
tok_emb = model.net.token_emb.emb.weight.detach().cpu().double().numpy()

# Pairwise cosine *distance* between all token embeddings (despite the variable name).
# NOTE: the result is a dense (num_tokens x num_tokens) matrix
cos_sim = metrics.pairwise_distances(
  tok_emb, metric='cosine'
)
plt.figure(figsize=(7, 7))
plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")
im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)
plt.xlabel("Position")
plt.ylabel("Position")
plt.tight_layout()
plt.plot()
plt.savefig("/home/ubuntu/Orpheus-Drums-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")

# Congrats! You did it! :)