
difference of fine-tuning the pretrained models #30

Open
nianniandoushabao opened this issue Feb 19, 2023 · 2 comments

@nianniandoushabao
Sorry to bother you. I want to ask about the difference between the two ways of getting a pre-trained model; I'm not sure whether I understand them correctly.
The first is in the "Getting a pre-trained model for fine tuning" section. The code is:

from hear21passt.base import get_basic_model, get_model_passt
import torch

# get the PaSST model wrapper; it includes the Melspectrogram and the default pre-trained transformer
model = get_basic_model(mode="logits")
print(model.mel)  # extracts mel spectrograms from raw waveforms

# optionally replace the transformer with one that has the required number of classes, e.g. 50
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)
print(model.net)  # the transformer network

# now model contains mel + the pre-trained transformer, ready to be fine-tuned.
# It still expects input of shape [batch, seconds*32000]; the sampling rate is 32 kHz.

model.train()
model = model.cuda()

The second is in the "Pre-trained models" section:

from models.passt import get_model

model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1,
                  fstride=10, tstride=10, input_fdim=128, input_tdim=998,
                  u_patchout=0, s_patchout_t=40, s_patchout_f=4)

I have two questions. With the first way, is only the last layer of the transformer (the one that depends on n_classes) fine-tuned, while the weights of the other layers stay unchanged? And with the second way, are the weights of all layers loaded and then all of them trained again? Or are the two ways the same?

@kkoutini
Owner

Hi!
get_model_passt is an alias for the get_model function in passt_hear21, see here.
So in general there should be no difference between the two functions; they both return the transformer model. The passt_hear21 repo contains only the model code (including the preprocessing), without the dependencies on sacred and PyTorch Lightning.

If n_classes differs from that of the pre-trained model, the last MLP layer is randomly initialized (according to the new n_classes) and the model is ready to be fine-tuned on the new task; this is done in timm here.
So both ways are the same. Of course, you can also use all the other arguments here.
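In other words, all pre-trained weights are loaded in both cases; nothing is frozen, so an optimizer over model.parameters() will update every layer, not just the new head. A minimal fine-tuning sketch (the dummy waveform, labels, and hyperparameters are placeholders, not recommendations):

import torch
from hear21passt.base import get_basic_model, get_model_passt

model = get_basic_model(mode="logits")
model.net = get_model_passt(arch="passt_s_swa_p16_128_ap476", n_classes=50)  # head re-initialized for 50 classes
model.train()
model = model.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # every layer is trainable, not only the new head
criterion = torch.nn.CrossEntropyLoss()

waveform = torch.randn(8, 10 * 32000, device="cuda")  # dummy batch: 8 clips of 10 s at 32 kHz
labels = torch.randint(0, 50, (8,), device="cuda")    # dummy single-label targets

logits = model(waveform)          # raw waveforms in, [8, 50] logits out
loss = criterion(logits, labels)
loss.backward()
optimizer.step()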

hear21passt.base.get_basic_model will return a wrapper module that contains both the transformer and the mel-spectrogram preprocessing, see here.
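For example, a quick shape check at inference time (assuming the default AudioSet head with 527 classes; the random waveform is just a stand-in for real audio):

import torch
from hear21passt.base import get_basic_model

model = get_basic_model(mode="logits")
model.eval()

waveform = torch.randn(1, 10 * 32000)  # raw audio: [batch, seconds * 32000] at 32 kHz
with torch.no_grad():
    logits = model(waveform)  # the wrapper runs mel extraction + the transformer in one forward pass
print(logits.shape)  # [1, 527] with the default AudioSet head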

@nianniandoushabao
Author

Oh, thank you, thank you very much.
