diff --git a/train.py b/train.py index 4718e94..75209b2 100644 --- a/train.py +++ b/train.py @@ -306,7 +306,7 @@ def get_midi_list(path): "--data-val-split", type=int, default=128, - help="split length for validation", + help="the number of midi files divided into the validation set", ) parser.add_argument( "--max-len", @@ -378,7 +378,7 @@ def get_midi_list(path): "--log-step", type=int, default=1, help="log training loss every n steps" ) parser.add_argument( - "--val-step", type=int, default=1600, help="valid and save every n steps" + "--val-step", type=int, default=1600, help="valid and save every n steps, set 0 to valid and save every epoch" ) opt = parser.parse_args() @@ -467,7 +467,7 @@ def get_midi_list(path): num_nodes=opt.nodes, max_steps=opt.max_step, benchmark=not opt.disable_benchmark, - val_check_interval=opt.val_step, + val_check_interval=opt.val_step or None, log_every_n_steps=1, strategy="auto", callbacks=callbacks,