Getting started with a custom dataset #33

OhadCohen97 opened this issue Jun 2, 2023 · 8 comments

Comments

@OhadCohen97

OhadCohen97 commented Jun 2, 2023

Hi,

Thank you for your great work!

I want to use PaSST on my custom dataset, for a different classification task.

Are there any minimal instructions/code for running the model for a different dataset? From which file should I start?

Does PaSST support multi-channel WAV audio?

Best

@kkoutini
Owner

kkoutini commented Jun 5, 2023

Hi,
Thank you!
If you want to use your own training framework, check out this repo for some examples: https://github.com/kkoutini/passt_hear21
For example, to get a trainable model that accepts wave inputs:

from hear21passt.base import load_model

model = load_model(mode="logits").cuda()
logits = model(wave_signal)

Unfortunately, all the pre-trained models are trained on mono audio.
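
Since the pre-trained models expect mono input, one option for multi-channel recordings is to downmix them before calling the model. Below is a minimal sketch in plain Python, assuming the samples are interleaved per frame (as they are in WAV files) and that a simple channel average is acceptable for the task. `downmix_to_mono` is a hypothetical helper, not part of PaSST or hear21passt.

```python
def downmix_to_mono(interleaved, num_channels):
    """Average interleaved multi-channel samples into a single mono channel.

    `interleaved` holds one sample per channel for frame 0, then frame 1, etc.
    """
    if num_channels < 1 or len(interleaved) % num_channels != 0:
        raise ValueError("sample count must be a multiple of num_channels")
    return [
        sum(interleaved[i:i + num_channels]) / num_channels
        for i in range(0, len(interleaved), num_channels)
    ]


# Two stereo frames: (left=1.0, right=3.0) and (left=2.0, right=4.0).
mono = downmix_to_mono([1.0, 3.0, 2.0, 4.0], num_channels=2)
print(mono)
```

Averaging discards spatial information, so this only helps when the class labels do not depend on it.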

@OhadCohen97
Author

OhadCohen97 commented Jun 5, 2023

Hi,
Thank you for your reply!

Actually, I would like to use your training framework to fine-tune the model on my dataset.

Which file should I modify to run training on my data?

@kkoutini
Owner

kkoutini commented Jun 5, 2023

I think the simplest way to get started is to do something similar to the ESC-50 setup.

Dataset file

You would need a dataset file: https://github.com/kkoutini/PaSST/blob/main/esc50/dataset.py
This file defines the dataset config and the main PyTorch dataset class:

class AudioSetDataset(TorchDataset):

This file should also define methods to get the training and test sets:
def get_training_set(normalize, roll, wavmix=False):

def get_test_set(normalize):

The config is injected automatically using sacred.

Take a look at the block under

if __name__ == "__main__":

to check if your dataset is parsing and loading the audio files correctly.
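
The kind of check described above can be sketched without any audio dependencies. This is only an illustration of a map-style dataset (an object with `__len__` and `__getitem__`) plus a small smoke test. `CustomAudioDataset`, `smoke_check`, and the CSV columns `filename` and `label` are hypothetical names, and a real dataset file would decode and return the audio and receive its config through sacred.

```python
import csv
import io


class CustomAudioDataset:
    """Map-style dataset sketch: one (audio_path, label_index) pair per CSV row."""

    def __init__(self, meta_file, class_names):
        self.class_to_idx = {name: i for i, name in enumerate(class_names)}
        reader = csv.DictReader(meta_file)
        self.items = [
            (row["filename"], self.class_to_idx[row["label"]]) for row in reader
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # A real dataset would load and return the waveform here.
        return self.items[index]


def smoke_check(dataset, num_classes, limit=5):
    """Index the first few items and verify that every label is in range."""
    for index in range(min(limit, len(dataset))):
        path, label = dataset[index]
        if not 0 <= label < num_classes:
            raise ValueError(f"item {index} ({path}) has out-of-range label {label}")
    return len(dataset)


if __name__ == "__main__":
    # In-memory metadata standing in for a CSV file on disk.
    meta = io.StringIO("filename,label\ndog_01.wav,dog\nrain_07.wav,rain\n")
    ds = CustomAudioDataset(meta, class_names=["dog", "rain"])
    print(f"{smoke_check(ds, num_classes=2)} items parsed")
```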

Experiment file

Then you need to create the Experiment file, similar to https://github.com/kkoutini/PaSST/blob/main/ex_esc50.py
The experiment file holds the default configs, loads the dataset and the NN model, and contains the training logic as a pytorch-lightning module.

Here, you need to change the dataset ingredient to point to your new dataset (a dotted string, like a Python import path):

basedataset = DynamicIngredient("esc50.dataset.dataset")
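
To illustrate what "string format like a python import" could mean, here is a sketch that resolves a dotted string into an object, assuming the last component is an attribute and everything before it is a module path. This is an assumption based on the description above, not the framework's actual implementation, and `resolve_dotted_path` is a hypothetical helper.

```python
import importlib


def resolve_dotted_path(path):
    """Import 'package.module.attribute' and return the attribute."""
    module_name, _, attribute = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attribute', got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attribute)


# Standard-library stand-in for a path such as "esc50.dataset.dataset".
ordered_dict_cls = resolve_dotted_path("collections.OrderedDict")
print(ordered_dict_cls.__name__)
```

Under this reading, a new dataset placed in a package of your own would be referenced by its own dotted path in place of "esc50.dataset.dataset".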

You can then change the default project name in the wandb config:
wandb = dict(project="passt_esc50", log_model=True)

The rest is to update the pytorch-lightning module if needed:

class M(Ba3lModule):

@OhadCohen97
Author

OhadCohen97 commented Jun 14, 2023

Thank you for your reply.

I managed to fine-tune the model on my custom dataset, and I would like to run an evaluation on my test set (not inference).

I added get_test_set() to dataset.py and also added a test loader to the experiment file ("/basedataset.get_test_set").

I already have a model checkpoint, and I wonder how I can load it so that it fits into the class M(Ba3lModule) and the function evaluate_only in the experiment file.

Is there any basic code for running the fine-tuned checkpoint in the experiment file?

Thanks!

@kkoutini
Owner

Hi, that's great!
I think the easiest way to load the model is to edit the evaluate_only function to something like this:

@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed, load_ckpt_path="path_to_ckpt"):
    # the parameter with a default value must come last, otherwise Python raises a SyntaxError
    # force overriding the config, not logged = not recommended
    trainer = get_trainer(logger=get_logger())
    val_loader = get_validate_loader()

    modul = M(ex)
    ## loading pre-trained weights (torch must be imported in the experiment file)
    checkpoint = torch.load(load_ckpt_path)  # maybe with map_location=torch.device('cpu')
    modul.load_state_dict(checkpoint['state_dict'])

    ##
    modul.val_dataloader = None
    trainer.val_dataloaders = None
    print(f"\n\nValidation len={len(val_loader)}\n")
    res = trainer.validate(modul, val_dataloaders=val_loader)
    print("\n\n Validation:")
    print(res)

Alternatively, you may want to load only the transformer weights:

    ## loading pre-trained weights
    checkpoint = torch.load(load_ckpt_path)  # maybe with map_location=torch.device('cpu')
    # keep only the keys under "net." and strip that 4-character prefix
    net_statedict = {k[4:]: v for k, v in checkpoint['state_dict'].items() if k.startswith("net.")}
    modul.net.load_state_dict(net_statedict)
 
    ##
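
The key filtering in the second snippet can be tried on a plain dictionary first. This sketch uses a toy stand-in for checkpoint['state_dict']; the key names and values are made up for illustration, and `strip_prefix` is a hypothetical helper.

```python
def strip_prefix(state_dict, prefix="net."):
    """Keep only keys that start with `prefix` and drop the prefix from them."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy stand-in for checkpoint["state_dict"]; real values would be tensors.
toy_state_dict = {
    "net.patch_embed.weight": 1,
    "net.head.bias": 2,
    "mel.window": 3,
}
net_only = strip_prefix(toy_state_dict)
print(sorted(net_only))
```

Using len(prefix) instead of a hard-coded 4 keeps the slice correct if the prefix ever changes.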

@OhadCohen97
Author

Hi,

Thank you for responding. It works!

Regarding my earlier question about multi-channel audio: can I train PaSST on multi-channel data from scratch (not from the pre-trained weights)? Does the model/code support training on this kind of data?

@kkoutini
Owner

@CodeBot-del

> Thank you for your reply.
>
> I managed to fine-tune the model on my custom dataset, and I would like to run an evaluation on my test set (not inference).
>
> I added 'get_test_set()' in the "dataset.py" and also added a test loader in the experiment file ("/basedataset.get_test_set").
>
> I already have a model checkpoint, and I wonder how I can load it so that it can fit into the class "M(Ba3lModule)" and the function "evaluate_only" in the experiment file.
>
> Is there any basic code to run the checkpoint fine-tuned model in the experiment file?
>
> Thanks!

Hello @OhadCohen97, it is great news that you managed to fine-tune the model on your dataset using the framework. Could you please explain, step by step, how you achieved it? I am trying to do the same thing and am lost about where to start and what to do. I would appreciate a walkthrough from start to finish.
