Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SkyTNT committed Sep 6, 2023
1 parent 0e7ad89 commit 3801d96
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def file_ext(fname):


class MidiDataset(Dataset):
def __init__(self, midi_list, tokenizer: MIDITokenizer, max_len=2048, aug=True, check_alignment=True):
def __init__(self, midi_list, tokenizer: MIDITokenizer, max_len=2048, min_file_size=3000, max_file_size=384000,
aug=True, check_alignment=True):

self.tokenizer = tokenizer
self.midi_list = midi_list
self.max_len = max_len
self.min_file_size = min_file_size
self.max_file_size = max_file_size
self.aug = aug
self.check_alignment = check_alignment

Expand All @@ -42,9 +45,9 @@ def load_midi(self, index):
try:
with open(path, 'rb') as f:
datas = f.read()
if len(datas) > 384000:
if len(datas) > self.max_file_size: # large MIDI files take too long to load
raise ValueError("file too large")
elif len(datas) < 3000:
elif len(datas) < self.min_file_size:
raise ValueError("file too small")
mid = MIDI.midi2score(datas)
if max([0] + [len(track) for track in mid[1:]]) == 0:
Expand Down

0 comments on commit 3801d96

Please sign in to comment.