From 6de783bc24c7ed2714b560652536f6c1ebed5e71 Mon Sep 17 00:00:00 2001
From: LSCLDev
Date: Tue, 13 Dec 2022 10:45:36 +0000
Subject: [PATCH] [NEVIS] Minor changes to ckpt loading logic.

PiperOrigin-RevId: 494966802
Change-Id: I806a4c0a67c6f5f5de0619dea328cbac18ec9511
---
 .../configs/finetuning_ind_pretrained.py          |  7 ++++---
 ...point_loader.py => pretrained_model_loader.py} | 15 +++++++++------
 experiments_jax/training/trainer.py               |  2 +-
 .../configs/finetuning_ind_pretrained.py          |  7 ++++---
 ...point_loader.py => pretrained_model_loader.py} | 15 +++++++--------
 5 files changed, 25 insertions(+), 21 deletions(-)
 rename experiments_jax/environment/{checkpoint_loader.py => pretrained_model_loader.py} (80%)
 rename experiments_torch/environment/{checkpoint_loader.py => pretrained_model_loader.py} (79%)

diff --git a/experiments_jax/configs/finetuning_ind_pretrained.py b/experiments_jax/configs/finetuning_ind_pretrained.py
index e8e4154..e214b29 100644
--- a/experiments_jax/configs/finetuning_ind_pretrained.py
+++ b/experiments_jax/configs/finetuning_ind_pretrained.py
@@ -17,7 +17,7 @@
 
 from dm_nevis.benchmarker.datasets import test_stream
 from dm_nevis.benchmarker.environment import logger_utils
-from experiments_jax.environment import checkpoint_loader
+from experiments_jax.environment import pretrained_model_loader
 from experiments_jax.learners.finetuning import finetuning_learner
 from experiments_jax.training import augmentations
 from experiments_jax.training import modules
@@ -33,7 +33,7 @@
 DEFAULT_CHECKPOINT_DIR = os.environ.get('NEVIS_CHECKPOINT_DIR',
                                         '/tmp/nevis_checkpoint_dir')
 DEFAULT_PRETRAIN_CHECKPOINT_PATH = os.path.join(DEFAULT_CHECKPOINT_DIR,
-                                                'pretraining.ckpt')
+                                                'pretraining.pkl')
 
 FREEZE_PRETRAINED_BACKBONE = False
 
@@ -95,7 +95,8 @@ def get_config() -> ml_collections.ConfigDict:
               # Optionally load and/or freeze pretrained parameters.
               'load_params_fn': None,
               'load_params_fn_with_kwargs': {
-                  'fun': checkpoint_loader.load_ckpt_params,
+                  'fun':
+                      pretrained_model_loader.load_model_params_from_ckpt,
                   'kwargs': {
                       'freeze_pretrained_backbone':
                           FREEZE_PRETRAINED_BACKBONE,
diff --git a/experiments_jax/environment/checkpoint_loader.py b/experiments_jax/environment/pretrained_model_loader.py
similarity index 80%
rename from experiments_jax/environment/checkpoint_loader.py
rename to experiments_jax/environment/pretrained_model_loader.py
index 285653f..2851a25 100644
--- a/experiments_jax/environment/checkpoint_loader.py
+++ b/experiments_jax/environment/pretrained_model_loader.py
@@ -4,11 +4,11 @@
 
 from absl import logging
 import chex
-from experiments_jax.environment import pickle_checkpointer
+from experiments_jax.training import trainer
 import haiku as hk
 
 
-def load_ckpt_params(
+def load_model_params_from_ckpt(
     params: hk.Params,
     state: hk.State,
     freeze_pretrained_backbone: bool = False,
@@ -26,12 +26,15 @@
     updated params split into trainable and frozen, updated states.
""" - checkpointer = pickle_checkpointer.PickleCheckpointer(checkpoint_path) - restored_params = checkpointer.restore() - - if restored_params is None: + trainer_state = trainer.restore_train_state(checkpoint_path) + if trainer_state is None or trainer_state.trainable_params is None or trainer_state.frozen_params is None: return params, {}, state + restored_params = { + **trainer_state.trainable_params, + **trainer_state.frozen_params + } + def filter_fn(module_name, *unused_args): del unused_args return module_name.startswith('backbone') diff --git a/experiments_jax/training/trainer.py b/experiments_jax/training/trainer.py index ba2d400..24dce3b 100644 --- a/experiments_jax/training/trainer.py +++ b/experiments_jax/training/trainer.py @@ -76,7 +76,7 @@ def init_train_state( if load_params_fn: trainable_params, frozen_params, state = load_params_fn(params, state) else: - trainable_params, frozen_params = params, [] + trainable_params, frozen_params = params, {} opt_state = opt.init(trainable_params) diff --git a/experiments_torch/configs/finetuning_ind_pretrained.py b/experiments_torch/configs/finetuning_ind_pretrained.py index 0adc3de..fdd2927 100644 --- a/experiments_torch/configs/finetuning_ind_pretrained.py +++ b/experiments_torch/configs/finetuning_ind_pretrained.py @@ -16,7 +16,7 @@ import os from dm_nevis.benchmarker.environment import logger_utils -from experiments_torch.environment import checkpoint_loader +from experiments_torch.environment import pretrained_model_loader from experiments_torch.learners.finetuning import finetuning_learner from experiments_torch.training import augmentations from experiments_torch.training import resnet @@ -31,7 +31,7 @@ DEFAULT_CHECKPOINT_DIR = os.environ.get('NEVIS_CHECKPOINT_DIR', '/tmp/nevis_checkpoint_dir') DEFAULT_PRETRAIN_CHECKPOINT_PATH = os.path.join(DEFAULT_CHECKPOINT_DIR, - 'pretraining.ckpt') + 'pretraining.pkl') FREEZE_PRETRAINED_BACKBONE = False @@ -93,7 +93,8 @@ def get_config() -> ml_collections.ConfigDict: # Optionally load and/or freeze pretrained parameters. 'load_params_fn': None, 'load_params_fn_with_kwargs': { - 'fun': checkpoint_loader.load_ckpt_params, + 'fun': + pretrained_model_loader.load_model_params_from_ckpt, 'kwargs': { 'freeze_pretrained_backbone': FREEZE_PRETRAINED_BACKBONE, diff --git a/experiments_torch/environment/checkpoint_loader.py b/experiments_torch/environment/pretrained_model_loader.py similarity index 79% rename from experiments_torch/environment/checkpoint_loader.py rename to experiments_torch/environment/pretrained_model_loader.py index 5eff8ab..c904eac 100644 --- a/experiments_torch/environment/checkpoint_loader.py +++ b/experiments_torch/environment/pretrained_model_loader.py @@ -3,12 +3,12 @@ from typing import Tuple, Union, Dict from absl import logging -from experiments_torch.environment import pickle_checkpointer from experiments_torch.training import models +from experiments_torch.training import trainer from torch.nn import parameter -def load_ckpt_params( +def load_model_params_from_ckpt( model: models.Model, freeze_pretrained_backbone: bool = False, checkpoint_path: str = '', @@ -23,19 +23,18 @@ def load_ckpt_params( Returns: updated params split into trainable and frozen. 
""" - - checkpointer = pickle_checkpointer.PickleCheckpointer(checkpoint_path) - restored_model = checkpointer.restore() - - if restored_model is None: + trainer_state = trainer.restore_train_state(checkpoint_path) + if trainer_state is None or trainer_state.model is None: return model.backbone.parameters(), {} + restored_model = trainer_state.model + assert isinstance(restored_model, models.Model) logging.info('Loading pretrained model finished.') for model_param, restored_model_param in zip( model.backbone.parameters(), restored_model.backbone.parameters()): - assert model_param.data.shape == restored_model_param.data + assert model_param.data.shape == restored_model_param.data.shape model_param.data = restored_model_param.data model_param.requires_grad = not freeze_pretrained_backbone