Getting started with a custom dataset #33
Hi,
from hear21passt.base import load_model
model = load_model(mode="logits").cuda()
logits = model(wave_signal)
Unfortunately, all the pre-trained models are trained on mono audio. |
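Since the pre-trained models expect mono input, a multi-channel recording has to be downmixed before it is passed to the model. Below is a minimal sketch of one way to do that with only the Python standard library, by averaging the channels of each frame. The helper name `read_wav_as_mono` is hypothetical and not part of PaSST; the sketch assumes 16-bit PCM WAV files and does not resample.

```python
import array
import sys
import wave


def read_wav_as_mono(path):
    """Read a 16-bit PCM WAV file and average its channels into one track.

    Returns (samples, sample_rate); samples is a list of floats in [-1, 1).
    Hypothetical helper for illustration, not part of the PaSST code base.
    """
    with wave.open(path, "rb") as wav_file:
        if wav_file.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit PCM")
        n_channels = wav_file.getnchannels()
        sample_rate = wav_file.getframerate()
        raw = wav_file.readframes(wav_file.getnframes())

    # Samples are interleaved per frame: ch0, ch1, ..., ch0, ch1, ...
    interleaved = array.array("h")
    interleaved.frombytes(raw)
    if sys.byteorder == "big":
        # WAV data is little-endian, so swap on big-endian hosts.
        interleaved.byteswap()

    mono = []
    for start in range(0, len(interleaved), n_channels):
        frame = interleaved[start:start + n_channels]
        # Average the channels and scale 16-bit integers to floats.
        mono.append(sum(frame) / (n_channels * 32768.0))
    return mono, sample_rate
```

The returned list could then be turned into a batch of one waveform, for example with `torch.tensor(mono).unsqueeze(0)`, before calling the model. Whether the sample rate matches what the model expects still has to be checked separately.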
Hi, On the contrary, I would like to use your training framework and fine-tune on my own dataset. Which file should I modify to run training on my data? |
I think the simplest way to get started is to do something similar to the ESC-50 setup.
Dataset file
You would need a dataset file: https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py (Line 137 in d27d832).
This file should also define methods to get the training and test sets (Line 255 in d27d832 and Line 271 in d27d832).
The config is injected automatically using sacred. Take a look at Line 309 in d27d832 to check if your dataset is parsing and loading the audio files correctly.
Experiment file
Then you need to create the experiment file, similar to https://github.com/kkoutini/PaSST/blob/main/ex_esc50.py. Here, you need to change the dataset ingredient to match your new dataset (a string formatted like a Python import): Line 68 in d27d832.
You can then change the default project name in the wandb config: Line 71 in d27d832.
The rest is to update the pytorch-lightning module if needed: Line 106 in d27d832.
|
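As a rough illustration of the dataset-file step described above, here is a minimal standard-library sketch of a metadata-driven dataset with `get_training_set` and `get_test_set` functions. Everything in it is an assumption for illustration: the CSV columns (`filename`, `fold`, `label`), the class name `AudioMetadataDataset`, and the fold-based split are hypothetical, and the real functions in `esc50/dataset.py` receive their configuration through sacred and return loaded waveforms rather than file paths.

```python
import csv


class AudioMetadataDataset:
    """Map-style dataset sketch: index -> (audio_path, label_index)."""

    def __init__(self, rows, label_to_index):
        self.rows = rows
        self.label_to_index = label_to_index

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        row = self.rows[index]
        # A real dataset would load and return the waveform here.
        return row["filename"], self.label_to_index[row["label"]]


def load_metadata(csv_path):
    """Read all metadata rows as dictionaries keyed by column name."""
    with open(csv_path, newline="") as handle:
        return list(csv.DictReader(handle))


def build_label_index(rows):
    """Map each label to a stable integer, shared by train and test sets."""
    return {label: i for i, label in enumerate(sorted({r["label"] for r in rows}))}


def get_training_set(csv_path, test_fold):
    """Return every row whose fold differs from the held-out fold."""
    rows = load_metadata(csv_path)
    kept = [r for r in rows if r["fold"] != str(test_fold)]
    return AudioMetadataDataset(kept, build_label_index(rows))


def get_test_set(csv_path, test_fold):
    """Return only the rows that belong to the held-out fold."""
    rows = load_metadata(csv_path)
    kept = [r for r in rows if r["fold"] == str(test_fold)]
    return AudioMetadataDataset(kept, build_label_index(rows))
```

Building the label index from all rows, rather than from each split separately, keeps the class indices consistent between training and evaluation.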
Thank you for your reply. I managed to fine-tune the model on my custom dataset, and I would like to run an evaluation on my test set (not inference). I added 'get_test_set()' in "dataset.py" and also added a test loader in the experiment file ("/basedataset.get_test_set"). I already have a model checkpoint, and I wonder how I can load it so that it fits the class "M(Ba3lModule)" and the function "evaluate_only" in the experiment file. Is there any basic code for running the fine-tuned checkpoint in the experiment file? Thanks! |
Hi, that's great!
@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed, load_ckpt_path="path_to_ckpt"):
    # Parameters with defaults must come last; sacred injects arguments by name.
    # force overriding the config, not logged = not recommended
    trainer = get_trainer(logger=get_logger())
    val_loader = get_validate_loader()
    modul = M(ex)
    ## loading pre-trained weights
    checkpoint = torch.load(load_ckpt_path)  # maybe with map_location=torch.device('cpu')
    modul.load_state_dict(checkpoint['state_dict'])
    ##
    modul.val_dataloader = None
    trainer.val_dataloaders = None
    print(f"\n\nValidation len={len(val_loader)}\n")
    res = trainer.validate(modul, val_dataloaders=val_loader)
    print("\n\n Validation:")
    print(res)
Maybe you also want to load only the transformer weights:
    ## loading pre-trained weights
    checkpoint = torch.load(load_ckpt_path)  # maybe with map_location=torch.device('cpu')
    # keep only the "net." entries and strip that prefix from the keys
    net_statedict = {k[4:]: v for k, v in checkpoint['state_dict'].items() if k.startswith("net.")}
    modul.net.load_state_dict(net_statedict)
    ##
|
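The prefix filtering in the second snippet can be tried without a real checkpoint. The sketch below applies the same dictionary comprehension to a plain dictionary; the helper name `strip_prefix` and the key names are made up for illustration, and the integer values stand in for tensors.

```python
def strip_prefix(state_dict, prefix="net."):
    """Keep only the entries under `prefix` and drop the prefix from their keys."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical checkpoint keys; real checkpoints map names to tensors.
checkpoint_state = {
    "net.patch_embed.weight": 1,
    "net.head.bias": 2,
    "mel.window": 3,
}

net_state = strip_prefix(checkpoint_state)
print(sorted(net_state))  # prints ['head.bias', 'patch_embed.weight']
```

Entries outside the `net.` namespace, such as the `mel.window` placeholder here, are dropped, which is why the result can be loaded into the sub-module alone.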
Hi, Thank you for responding. It works! Regarding my question about the multi-channel audio case: can I train PaSST on multi-channel audio from scratch (not from the pre-trained weights)? Does the model/code support training on this kind of data? |
yes, you can change the |
Hello, @OhadCohen97 It is great news that you managed to fine-tune on your dataset using the framework. Can you please explain how you achieved it? I am trying to do the same thing and I am lost about where to start and what to do. I would appreciate it if you could explain it step by step, from start to finish. |
Hi,
Thank you for your great work!
I want to use PaSST for my custom dataset, which is a different classification task.
Are there any minimal instructions/code for running the model on a different dataset? Which file should I start from?
Does PaSST support multi-channel WAV audio?
Best