update
zhaojw1998 committed Oct 27, 2024
1 parent 33b9cc7 commit f7ea8de
Showing 2,961 changed files with 53,309 additions and 1,172 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,3 +1,4 @@
*.pyc
*.zip
/data_file_dir
/data_file_dir
*.wav
59 changes: 54 additions & 5 deletions README.md
@@ -3,12 +3,61 @@
[![GitHub](https://img.shields.io/badge/GitHub-demo%20page-blue?logo=Github&style=flat-round)](https://zhaojw1998.github.io/AccoMontage-3)
[![Colab](https://img.shields.io/badge/Colab-tutorial-blue?logo=googlecolab&style=flat-round)](https://colab.research.google.com/drive/1LSY1TTkSesDUfpJplq5xi-3-DI09fWQ9?usp=sharing)

Repository for Paper: Zhao et al., [AccoMontage-3: Full-Band Accompaniment Arrangement via Sequential Style Transfer with Multi-Track Function Prior](https://arxiv.org/abs/2310.16334).
Repository for Paper: Zhao et al., Structured Multi-Track Accompaniment Arrangement via Style Prior Modelling, in Proceedings of NeurIPS 2024.

Demp page: https://zhaojw1998.github.io/AccoMontage-3
We present a two-stage system for *whole-song*, *multi-track* accompaniment arrangement. In the first stage, a piano accompaniment is generated from a lead sheet. In the second stage, a multi-track accompaniment is orchestrated with a customizable number of tracks and choice of instruments. Our main novelty (and the essence of this repo) lies in the second stage, where we implement long-term *style prior modelling* based on disentangled music content and style factors. Please refer to our paper for details.

AccoMontage-3 can be quickly tested with this [Tutorial on Colab](https://colab.research.google.com/drive/1LSY1TTkSesDUfpJplq5xi-3-DI09fWQ9?usp=sharing).
Demo page: https://zhaojw1998.github.io/structured-arrangement/

Model checkpoints can be downloaded [via this link](https://drive.google.com/drive/folders/17yB-Oae_4eGKJmqRS-LB8PwE2rqwZrUu?usp=sharing).
Our system can be quickly tested [on Colab](https://colab.research.google.com/drive/1LSY1TTkSesDUfpJplq5xi-3-DI09fWQ9?usp=sharing).

More information to be updated soon.

### Code and File Directory
This repository is organized as follows:
```
root
├──data_processing/ scripts for data processing
├──demo/ MIDI pieces for demonstration
├──orchestrator/ the orchestrator module at Stage 2
├──piano_arranger/ the piano arranger module at Stage 1
├──test/ scripts and results for objective evaluation
├──arrangement_utils.py helper functions for model inference
├──inference_arrangement.ipynb two-stage model inference (arrangement from lead sheet)
├──inference_orchestration.ipynb Stage-2 module inference (orchestration from piano)
├──train_autoencoder.py training script for Stage-2 autoencoder
└──train_prior.py training script for Stage-2 prior model
```
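
The helper functions in `arrangement_utils.py` chain the two stages together. Below is a minimal sketch of the end-to-end flow, assuming the function signatures in this repository; the data root, demo song, phrase segmentation, and filter values are placeholders, and [`./inference_arrangement.ipynb`](./inference_arrangement.ipynb) remains the authoritative walkthrough.

```python
import torch
from arrangement_utils import (load_premise, read_lead_sheet,
                               piano_arrangement, prompt_sampling, orchestration)

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
DATA_FILE_ROOT = './data_file_dir'          # placeholder: downloaded checkpoints and phrase data
DEMO_ROOT, SONG_NAME = './demo', 'my_song'  # placeholder: folder containing 'lead sheet.mid'

# Load the Stage-1 piano arranger, the Stage-2 orchestrator, and their search spaces
piano_arranger, orchestrator, (acc_pool, edge_weights, texture_filter), (REF, REF_PROG, REF_MIX) = \
    load_premise(DATA_FILE_ROOT, DEVICE)

# Stage 1: lead sheet -> piano accompaniment
leadsheet, chord_table, melody_queries, query_phrases = read_lead_sheet(
    DEMO_ROOT, SONG_NAME, SEGMENTATION='A8A8B8B8\n', NOTE_SHIFT=0)  # placeholder phrase labels
midi_piano, acc_piano = piano_arrangement(
    leadsheet, chord_table, melody_queries, query_phrases,
    acc_pool, edge_weights, texture_filter, piano_arranger, PREFILTER=(4, 1), tempo=100)

# Stage 2: sample a track/instrument prompt, then orchestrate the piano accompaniment
prog, function = prompt_sampling(acc_piano, REF, REF_PROG, REF_MIX, DEVICE=DEVICE)
midi_band = orchestration(acc_piano, None, prog, function, orchestrator,
                          DEVICE=DEVICE, tempo=100)[0]  # chord track omitted; num_sample=1
midi_band.write('arrangement_band.mid')
```

At Stage 2, `prompt_sampling` also accepts `MUST_HAVE` and `MUSTNOT_HAVE` instrument lists for coarse control over the sampled track set, and `orchestration` can draw multiple candidates at once via `num_sample`.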


### How to run
* You can quickly test our system online on [Google Colab](https://colab.research.google.com/drive/1N3XeEfTCWNLTuBp9NWPwzW-hq7Ho7nQA?usp=sharing).

* Alternatively, follow the guidance in [`./inference_arrangement.ipynb`](./inference_arrangement.ipynb) offline for more in-depth testing.

* If you wish to train our model from scratch, run [`./train_prior.py`](./train_prior.py). Please first download our processed LMD dataset and configure the data directory in the script. You may also wish to adjust a few parameters such as `BATCH_SIZE` at the beginning of the script. When `DEBUG_MODE` is set to 1, the script loads only a small portion of the data and runs through quickly for debugging purposes.
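
The following is a hypothetical illustration of the kind of settings block that last bullet refers to; apart from `BATCH_SIZE` and `DEBUG_MODE`, the variable names and values here are placeholders and may differ from the actual script:

```python
# Hypothetical settings near the top of train_prior.py (only BATCH_SIZE and
# DEBUG_MODE are named in this README; the other names/values are placeholders).
DEBUG_MODE = 0                     # set to 1 to load a small data subset for a quick dry run
BATCH_SIZE = 32                    # placeholder; reduce if GPU memory is tight
DATA_DIR = './data_file_dir/LMD'   # placeholder path to the processed LMD dataset
```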


### Data and Checkpoints

* Model checkpoints can be downloaded [via this link](https://drive.google.com/drive/folders/17yB-Oae_4eGKJmqRS-LB8PwE2rqwZrUu?usp=sharing).

* Processed dataset (LMD) for training the prior model can be downloaded [via this link](https://drive.google.com/drive/folders/17yB-Oae_4eGKJmqRS-LB8PwE2rqwZrUu?usp=sharing).

* Processed dataset (Slakh2100) for training the autoencoder is accessible [in this repo](https://github.com/zhaojw1998/Query-and-reArrange/tree/main/data/Slakh2100).


### Contact
Jingwei Zhao (PhD student in Data Science at NUS)

jzhao@u.nus.edu

Oct. 27, 2024
152 changes: 108 additions & 44 deletions arrangement_utils.py
@@ -8,66 +8,74 @@

from piano_arranger.acc_utils import split_phrases
import piano_arranger.format_converter as cvt
from piano_arranger.models import DisentangleVAE
from piano_arranger.models import DisentangleVAE, PolyDisVAE
from piano_arranger.AccoMontage import find_by_length, dp_search, re_harmonization, get_texture_filter, ref_spotlight

from orchestrator import Slakh2100_Pop909_Dataset, collate_fn, compute_pr_feat, EMBED_PROGRAM_MAPPING, Prior
from orchestrator.QA_dataset import SLAKH_CLASS_PROGRAMS
from orchestrator.autoencoder_dataset import SLAKH_CLASS_PROGRAMS
from orchestrator.utils import grid2pr, pr2grid, matrix2midi, midi2matrix

from orchestrator.prior_dataset import TOTAL_LEN_BIN, ABS_POS_BIN, REL_POS_BIN

SLAKH_CLASS_MAPPING = {v: k for k, v in EMBED_PROGRAM_MAPPING.items()}


def load_premise(DATA_FILE_ROOT, DEVICE):
"""Load AccoMontage Search Space"""
print('Loading AccoMontage piano texture search space. This may take 1 or 2 minutes ...')
data = np.load(os.path.join(DATA_FILE_ROOT, 'phrase_data.npz'), allow_pickle=True)
melody = data['melody']
acc = data['acc']
chord = data['chord']
vel = data['velocity']
cc = data['cc']
acc_pool = {}
for LEN in tqdm(range(2, 13)):
(mel, acc_, chord_, vel_, cc_, song_reference) = find_by_length(melody, acc, chord, vel, cc, LEN)
acc_pool[LEN] = (mel, acc_, chord_, vel_, cc_, song_reference)
texture_filter = get_texture_filter(acc_pool)
edge_weights=np.load(os.path.join(DATA_FILE_ROOT, 'edge_weights.npz'), allow_pickle=True)

"""Load Q&A Prompt Search Space"""
print('loading orchestration prompt search space ...')
def load_premise(DATA_FILE_ROOT, DEVICE, load_piano_arranger=True):
if load_piano_arranger:
print('Loading lead sheet to piano arrangement module (piano arranger). This may take 1 or 2 minutes ...')
data = np.load(os.path.join(DATA_FILE_ROOT, 'phrase_data.npz'), allow_pickle=True)
melody = data['melody']
acc = data['acc']
chord = data['chord']
vel = data['velocity']
cc = data['cc']
acc_pool = {}
for LEN in tqdm(range(2, 17)):
(mel, acc_, chord_, vel_, cc_, song_reference) = find_by_length(melody, acc, chord, vel, cc, LEN)
acc_pool[LEN] = (mel, acc_, chord_, vel_, cc_, song_reference)
texture_filter = get_texture_filter(acc_pool)
edge_weights=np.load(os.path.join(DATA_FILE_ROOT, 'edge_weights.npz'), allow_pickle=True)

piano_arranger = PolyDisVAE(DEVICE, chd_size=256, voi_size=256, txt_size=256, num_channel=10)
piano_arranger.load_state_dict(torch.load(os.path.join(DATA_FILE_ROOT, "params_chord_texture.pt")))
piano_arranger.to(DEVICE)

else:
piano_arranger, acc_pool, texture_filter, edge_weights = None, None, None, None

print('Loading piano to multi-track arrangement module (orchestrator). This may take 1 or 2 minutes ...')
slakh_dir = os.path.join(DATA_FILE_ROOT, 'Slakh2100_inference_set')
dataset = Slakh2100_Pop909_Dataset(slakh_dir=slakh_dir, pop909_dir=None, debug_mode=False, split='validation', mode='train')
dataset = Slakh2100_Pop909_Dataset(slakh_dir=slakh_dir, pop909_dir=None, debug_mode=False, split='inference', mode='train')

loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda b:collate_fn(b, DEVICE))
REF = []
REF_PROG = []
REF_MIX = []
for (_, prog, function, _, _, _) in loader:
for (mix, prog, function, _, _, _) in loader:
mix = mix[:, :32]
prog = prog[0, :]
mix = mix.detach().cpu().numpy()[0]
mix = grid2pr(mix, max_note_count=32)[np.newaxis, :, :]
ref_mix = torch.from_numpy(np.concatenate(compute_pr_feat(mix)[1:], axis=-1)).to(function.device)

REF.extend([batch for batch in function])
REF_PROG.extend([prog for _ in range(len(function))])
REF_MIX.append(torch.sum(function, dim=1))
REF_MIX.append(ref_mix)
REF_MIX = torch.cat(REF_MIX, dim=0)

"""Initialize orchestration model (Prior + Q&A)"""
print('Initialize model ...')
print('Initializing prior model')
prior_model_path = os.path.join(DATA_FILE_ROOT, 'params_prior.pt')
QaA_model_path = os.path.join(DATA_FILE_ROOT, 'params_qa.pt')
QaA_model_path = os.path.join(DATA_FILE_ROOT, "params_autoencoder.pt")
orchestrator = Prior.init_inference_model(prior_model_path, QaA_model_path, DEVICE=DEVICE)
orchestrator.to(DEVICE)
orchestrator.eval()
piano_arranger = DisentangleVAE.init_model(torch.device('cuda')).cuda()
piano_arranger.load_state_dict(torch.load(os.path.join(DATA_FILE_ROOT, 'params_reharmonizer.pt')))

print('Finished.')
return piano_arranger, orchestrator, (acc_pool, edge_weights, texture_filter), (REF, REF_PROG, REF_MIX)


def read_lead_sheet(DEMO_ROOT, SONG_NAME, SEGMENTATION, NOTE_SHIFT, melody_track_ID=0):
melody_roll, chord_roll = cvt.leadsheet2matrix(os.path.join(DEMO_ROOT, SONG_NAME, 'lead sheet.mid'), melody_track_ID)
def read_lead_sheet(DEMO_ROOT, SONG_NAME, SEGMENTATION, NOTE_SHIFT, melody_track_ID=0, filename='lead sheet.mid'):
melody_roll, chord_roll = cvt.leadsheet2matrix(os.path.join(DEMO_ROOT, SONG_NAME, filename), melody_track_ID)
assert len(melody_roll) == len(chord_roll)
if NOTE_SHIFT != 0:
melody_roll = melody_roll[int(NOTE_SHIFT*4):, :]
@@ -109,6 +117,50 @@ def read_lead_sheet(DEMO_ROOT, SONG_NAME, SEGMENTATION, NOTE_SHIFT, melody_track
return (LEADSHEET, CHORD_TABLE, melody_queries, query_phrases)


def read_piano_reduction(DEMO_ROOT, SONG_NAME, NOTE_SHIFT, melody_track_ID=0):
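# Read 'arrangement_piano.mid' under DEMO_ROOT/SONG_NAME, quantize notes to a 16th-note grid,
# and build (time, 128) piano-roll matrices whose entries store note durations in grid steps.
# Non-drum, non-melody tracks are summed into a single piano reduction; the melody track is
# returned separately, and both rolls are zero-padded to a multiple of 32 time steps.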
ACC = 4  # quantization resolution: 4 grid steps per beat (i.e., 16th notes in 4/4)
path = os.path.join(DEMO_ROOT, SONG_NAME, 'arrangement_piano.mid')
midi = pyd.PrettyMIDI(path)
beats = midi.get_beats()
beats = np.append(beats, beats[-1] + (beats[-1] - beats[-2]))
quantize = interp1d(np.array(range(0, len(beats))) * ACC, beats, kind='linear')
quaver = quantize(np.array(range(0, (len(beats) - 1) * ACC)))

piano_reduction = []
for idx, track in enumerate(midi.instruments):
if (not track.is_drum) and (idx != melody_track_ID):
pr_matrix = np.zeros((len(quaver), 128))
for note in track.notes:
note_start = np.argmin(np.abs(quaver - note.start))
note_end = np.argmin(np.abs(quaver - note.end))
if note_end == note_start:
note_end = min(note_start + 1, len(quaver) - 1)
pr_matrix[note_start, note.pitch] = note_end - note_start
piano_reduction.append(pr_matrix)
piano_reduction = np.sum(np.array(piano_reduction), axis=0)

melody = np.zeros((len(quaver), 128))
for note in midi.instruments[melody_track_ID].notes:
note_start = np.argmin(np.abs(quaver - note.start))
note_end = np.argmin(np.abs(quaver - note.end))
if note_end == note_start:
note_end = min(note_start + 1, len(quaver) - 1)
melody[note_start, note.pitch] = note_end - note_start
melody = np.array(melody)


if NOTE_SHIFT != 0:
piano_reduction = piano_reduction[int(NOTE_SHIFT*4)+1:, :]
melody = melody[int(NOTE_SHIFT*4)+1:, :]

if len(piano_reduction) % 32 != 0:
pad_len = (len(piano_reduction)//32+1)*32-len(piano_reduction)
piano_reduction = np.pad(piano_reduction, ((0, pad_len), (0, 0)))
melody = np.pad(melody, ((0, pad_len), (0, 0)))

return melody, piano_reduction


def piano_arrangement(pianoRoll, chord_table, melody_queries, query_phrases, acc_pool, edge_weights, texture_filter, piano_arranger, PREFILTER, tempo=100):
print('Phrasal Unit selection begins:\n\t', f'{len(query_phrases)} phrases in the lead sheet;\n\t', f'set note density filter: {PREFILTER}.')
phrase_indice, chord_shift = dp_search( melody_queries,
@@ -123,17 +175,26 @@ def piano_arrangement(pianoRoll, chord_table, melody_queries, query_phrases, acc
midi_recon, acc = re_harmonization(pianoRoll, chord_table, query_phrases, path, shift, acc_pool, model=piano_arranger, get_est=True, tempo=tempo)
acc = np.array([grid2pr(matrix) for matrix in acc])
print('Piano accompaniment generated!')

return midi_recon, acc


def prompt_sampling(acc_piano, REF, REF_PROG, REF_MIX, DEVICE='cuda:0'):
ref_mix = torch.from_numpy(compute_pr_feat(acc_piano[0:1])[-1]).to(DEVICE)
def prompt_sampling(acc_piano, REF, REF_PROG, REF_MIX, MUST_HAVE=[], MUSTNOT_HAVE=[], DEVICE='cuda:0'):
ref_mix = torch.from_numpy(np.concatenate(compute_pr_feat(acc_piano[0:1])[1:], axis=-1)).to(DEVICE)
sim_func = torch.nn.CosineSimilarity(dim=-1)
distance = sim_func(ref_mix, REF_MIX)
distance = distance + torch.normal(mean=torch.zeros(distance.shape), std=0.2*torch.ones(distance.shape)).to(distance.device)

MUSTNOT_HAVE = [(EMBED_PROGRAM_MAPPING[item]) for item in MUSTNOT_HAVE]
MUST_HAVE = [EMBED_PROGRAM_MAPPING[item] for item in MUST_HAVE]
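# Add one point to a reference's score for each MUST_HAVE instrument it contains
# and subtract one point for each MUSTNOT_HAVE instrument it contains.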
for i in range(len(REF_PROG)):
ref_i = [item.item() for item in REF_PROG[i]]
distance[i] += len(set(MUST_HAVE).intersection(set(ref_i)))
distance[i] -= len(set(MUSTNOT_HAVE).intersection(set(ref_i)))

sim_values, anchor_points = torch.sort(distance, descending=True)
IDX = 0
sim_value = sim_values[IDX]
#sim_value = sim_values[IDX]
anchor_point = anchor_points[IDX]
function = REF[anchor_point]
prog = REF_PROG[anchor_point]
@@ -142,7 +203,7 @@ def prompt_sampling(acc_piano, REF, REF_PROG, REF_MIX, DEVICE='cuda:0'):
print(f'Prior model initialized with {len(program_name)} tracks:\n\t{program_name}')
return prog, function

def orchestration(acc_piano, chord_track, prog, function, orchestrator, DEVICE='cuda:0', blur=.5, p=.1, t=4, tempo=100):
def orchestration(acc_piano, chord_track, prog, function, orchestrator, DEVICE='cuda:0', blur=.5, p=.1, t=4, tempo=100, num_sample=1):
print('Orchestration begins ...')
if chord_track is not None:
if len(acc_piano) > len(chord_track):
@@ -160,16 +221,19 @@ def orchestration(acc_piano, chord_track, prog, function, orchestrator, DEVICE='
total_len = torch.from_numpy(total_len).long().to(DEVICE)

if function is not None:
function = function.unsqueeze(0).unsqueeze(0)
recon_pitch, recon_dur = orchestrator.run_autoregressive_nucleus(mix.unsqueeze(0), prog.unsqueeze(0), function, total_len.unsqueeze(0), a_pos.unsqueeze(0), r_pos.unsqueeze(0), blur, p, t) #function.unsqueeze(0).unsqueeze(0)
function = function.repeat(num_sample, 1, 1, *[1]*len(function.shape))
recon_pitch, recon_dur = orchestrator.inference(mix.repeat(num_sample, *[1]*len(mix.shape)), prog.repeat(num_sample, *[1]*len(prog.shape)), function, total_len.unsqueeze(0), a_pos.unsqueeze(0), r_pos.unsqueeze(0), blur, p, t) #function.unsqueeze(0).unsqueeze(0)

grid_recon = torch.cat([recon_pitch.max(-1)[-1].unsqueeze(-1), recon_dur.max(-1)[-1]], dim=-1)
bat_ch, track, _, max_simu_note, grid_dim = grid_recon.shape
grid_recon = grid_recon.permute(1, 0, 2, 3, 4)
grid_recon = grid_recon.reshape(track, -1, max_simu_note, grid_dim)

pr_recon_ = np.array([grid2pr(matrix) for matrix in grid_recon.detach().cpu().numpy()])
pr_recon = matrix2midi(pr_recon_, [SLAKH_CLASS_MAPPING[item.item()] for item in prog.cpu().detach().numpy()], tempo)
print('Full-band accompaiment generated!')
return pr_recon
batch, n_segments, track, _, max_simu_note, grid_dim = grid_recon.shape
grid_recon = grid_recon.permute(0, 2, 1, 3, 4, 5)
grid_recon = grid_recon.reshape(batch, track, -1, max_simu_note, grid_dim)

midi_collection = []
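# Decode each sampled arrangement separately: grid tokens -> piano-roll matrices -> multi-track MIDI.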
for batch_i in range(len(grid_recon)):
pr_recon_ = np.array([grid2pr(matrix) for matrix in grid_recon[batch_i].detach().cpu().numpy()])
pr_recon = matrix2midi(pr_recon_, [SLAKH_CLASS_MAPPING[item] for item in prog.cpu().detach().numpy()], tempo)
midi_collection.append(pr_recon)
print('Full-band accompaniment generated!')
return midi_collection
