From ab10546d54dc7a60c16591ff8b50003d4e734e75 Mon Sep 17 00:00:00 2001 From: Mustafa Chasmai Date: Tue, 1 Jun 2021 12:00:52 +0530 Subject: [PATCH] configs | incorporated in main code --- config.py | 109 ++++++++++++++------- main_lightning.py | 64 +++++++----- test_models_lightning.py | 41 +++++--- utils/model_init.py | 207 +++++++++++++++++++-------------------- 4 files changed, 237 insertions(+), 184 deletions(-) diff --git a/config.py b/config.py index acf6034..e8a8b2a 100644 --- a/config.py +++ b/config.py @@ -3,7 +3,6 @@ """ import os -from pickle import NONE from yacs.config import CfgNode as CN @@ -13,53 +12,56 @@ _C = CN() -_C.TO_VALIDATE = False # choices = [True, False] +_C.TO_VALIDATE = True # choices = [True, False] + # ----------------------------------------------------------------------------- -# Dataset +# Paths # ----------------------------------------------------------------------------- - -# dataset paths -_C.DATASET.PATH_DATA_ROOT = "data/" # directory where the feature pickles are stored. Depends on users -_C.DATASET.PATH_LABELS_ROOT = "data/" # directory where the annotations are stored. Depends on users -_C.DATASET.PATH_EXP_ROOT="model/action-model/" # directory where the checkpoints are to be stored. Depends on users +_C.PATHS = CN() +_C.PATHS.PATH_DATA_ROOT = "data/" # directory where the feature pickles are stored. Depends on users +_C.PATHS.PATH_LABELS_ROOT = "annotations/" # directory where the annotations are stored. Depends on users +_C.PATHS.PATH_EXP_ROOT="model/action-model/" # directory where the checkpoints are to be stored. 
Depends on users -_C.DATASET.DATASET_SOURCE="source_train" # depends on users -_C.DATASET.DATASET_TARGET="target_train" # depends on users +_C.PATHS.DATASET_SOURCE="source_train" # depends on users +_C.PATHS.DATASET_TARGET="target_train" # depends on users if _C.TO_VALIDATE: - _C.DATASET.VAL_DATASET_SOURCE="source_val" # depends on users - _C.DATASET.VAL_DATASET_TARGET="target_val" # depends on users + _C.PATHS.VAL_DATASET_SOURCE="source_val" # depends on users + _C.PATHS.VAL_DATASET_TARGET="target_val" # depends on users else: - _C.DATASET.VAL_DATASET_SOURCE= None - _C.DATASET.VAL_DATASET_TARGET= None + _C.PATHS.VAL_DATASET_SOURCE= None + _C.PATHS.VAL_DATASET_TARGET= None -_C.DATASET.NUM_SOURCE= 16115 # number of training data (source) -_C.DATASET.NUM_TARGET= 26115 # number of training data (target) +_C.PATHS.NUM_SOURCE= 16115 # number of training data (source) +_C.PATHS.NUM_TARGET= 26115 # number of training data (target) -_C.DATASET.PATH_DATA_SOURCE=os.path.join(_C.DATASET.PATH_DATA_ROOT, _C.DATASET.DATASET_SOURCE) -_C.DATASET.PATH_DATA_TARGET=os.path.join(_C.DATASET.PATH_DATA_ROOT, _C.DATASET.DATASET_TARGET) +_C.PATHS.PATH_DATA_SOURCE=os.path.join(_C.PATHS.PATH_DATA_ROOT, _C.PATHS.DATASET_SOURCE) +_C.PATHS.PATH_DATA_TARGET=os.path.join(_C.PATHS.PATH_DATA_ROOT, _C.PATHS.DATASET_TARGET) if _C.TO_VALIDATE: - _C.DATASET.PATH_VAL_DATA_SOURCE=os.path.join(_C.DATASET.PATH_DATA_ROOT, _C.DATASET.VAL_DATASET_SOURCE) - _C.DATASET.PATH_VAL_DATA_TARGET=os.path.join(_C.DATASET.PATH_DATA_ROOT, _C.DATASET.VAL_DATASET_TARGET) + _C.PATHS.PATH_VAL_DATA_SOURCE=os.path.join(_C.PATHS.PATH_DATA_ROOT, _C.PATHS.VAL_DATASET_SOURCE) + _C.PATHS.PATH_VAL_DATA_TARGET=os.path.join(_C.PATHS.PATH_DATA_ROOT, _C.PATHS.VAL_DATASET_TARGET) else: - _C.DATASET.PATH_VAL_DATA_SOURCE= None - _C.DATASET.PATH_VAL_DATA_SOURCE= None + _C.PATHS.PATH_VAL_DATA_SOURCE= None + _C.PATHS.PATH_VAL_DATA_TARGET= None -_C.DATASET.TRAIN_SOURCE_LIST=os.path.join(_C.DATASET.PATH_LABELS_ROOT, 
'EPIC_100_uda_source_train.pkl') # '/domain_adaptation_source_train_pre-release_v3.pkl' -_C.DATASET.TRAIN_TARGET_LIST=os.path.join(_C.DATASET.PATH_LABELS_ROOT, 'EPIC_100_uda_target_train_timestamps.pkl') # '/domain_adaptation_target_train_pre-release_v6.pkl' +_C.PATHS.TRAIN_SOURCE_LIST=os.path.join(_C.PATHS.PATH_LABELS_ROOT, 'EPIC_100_uda_source_train.pkl') # '/domain_adaptation_source_train_pre-release_v3.pkl' +_C.PATHS.TRAIN_TARGET_LIST=os.path.join(_C.PATHS.PATH_LABELS_ROOT, 'EPIC_100_uda_target_train_timestamps.pkl') # '/domain_adaptation_target_train_pre-release_v6.pkl' if _C.TO_VALIDATE: - _C.DATASET.VAL_SOURCE_LIST=os.path.join(_C.DATASET.PATH_LABELS_ROOT, "EPIC_100_uda_source_val.pkl") - _C.DATASET.VAL_TARGET_LIST=os.path.join(_C.DATASET.PATH_LABELS_ROOT, "EPIC_100_uda_target_val.pkl") + _C.PATHS.VAL_SOURCE_LIST=os.path.join(_C.PATHS.PATH_LABELS_ROOT, "EPIC_100_uda_source_val.pkl") + _C.PATHS.VAL_TARGET_LIST=os.path.join(_C.PATHS.PATH_LABELS_ROOT, "EPIC_100_uda_target_val.pkl") else: - _C.DATASET.VAL_SOURCE_LIST= None - _C.DATASET.VAL_TARGET_LIST= None -_C.DATASET.VAL_LIST=os.path.join(_C.DATASET.PATH_LABELS_ROOT, "EPIC_100_uda_target_test_timestamps.pkl") -_C.DATASET.PATH_EXP=os.path.join(_C.DATASET.PATH_EXP_ROOT, "Testexp") + _C.PATHS.VAL_SOURCE_LIST= None + _C.PATHS.VAL_TARGET_LIST= None +_C.PATHS.VAL_LIST=os.path.join(_C.PATHS.PATH_LABELS_ROOT, "EPIC_100_uda_target_test_timestamps.pkl") +_C.PATHS.PATH_EXP=os.path.join(_C.PATHS.PATH_EXP_ROOT, "Testexp") -# dataset parameters + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- _C.DATASET = CN() _C.DATASET.DATASET = "epic" # dataset choices = [hmdb_ucf, hmdb_ucf_small, ucf_olympic] -_C.DATASET.NUM_CLASSES = 97300 +_C.DATASET.NUM_CLASSES = "97,300" _C.DATASET.MODALITY = "ALL" # choices = [RGB ] _C.DATASET.FRAME_TYPE = "feature" # choices = [frame] _C.DATASET.NUM_SEGMENTS = 5 # sample 
frame # of each video for training @@ -80,6 +82,10 @@ _C.MODEL.WEIGHTED_CLASS_LOSS_DA = "N" # choices = [Y, N] _C.MODEL.WEIGHTED_CLASS_LOSS = "N" # choices = [Y, N] +_C.MODEL.DROPOUT_I = 0.8 +_C.MODEL.DROPOUT_V = 0.8 +_C.MODEL.NO_PARTIALBN = True + # DA configs if _C.MODEL.USE_TARGET == "none": @@ -96,6 +102,10 @@ _C.MODEL.USE_ATTN = "TransAttn" # choices = [None, TransAttn, general] _C.MODEL.USE_ATTN_FRAME = None # choices = [None, TransAttn, general] _C.MODEL.USE_BN = None # choices = [None, AdaBN, AutoDIAL] +_C.MODEL.N_ATTN = 1 +_C.MODEL.PLACE_DIS = ["Y", "Y", "N"] +_C.MODEL.PLACE_ADV = ["Y", "Y", "Y"] + # ---------------------------------------------------------------------------- # # Hyperparameters @@ -103,7 +113,6 @@ _C.HYPERPARAMETERS = CN() _C.HYPERPARAMETERS.ALPHA = 0 _C.HYPERPARAMETERS.BETA = [0.75, 0.75, 0.5] -_C.HYPERPARAMETERS.N_ATTN = 1 _C.HYPERPARAMETERS.GAMMA = 0.003 # U->H: 0.003 | H->U: 0.3 _C.HYPERPARAMETERS.MU = 0 @@ -117,21 +126,51 @@ _C.TRAINER.ARCH = "TBN" # choices = [resnet50] _C.TRAINER.USE_TARGET = "uSv" # choices = [uSv, Sv, none] _C.TRAINER.SHARE_PARAMS = "Y" # choices = [Y, N] +_C.TRAINER.PRETRAIN_SOURCE = False +_C.TRAINER.VERBOSE = True # Learning configs +_C.TRAINER.LOSS_TYPE = 'nll' _C.TRAINER.LR = 0.003 _C.TRAINER.LR_DECAY = 10 _C.TRAINER.LR_ADAPTIVE = None # choices = [None, loss, dann] _C.TRAINER.LR_STEPS = [10, 20] +_C.TRAINER.MOMENTUM = 0.9 +_C.TRAINER.WEIGHT_DECAY = 0.0001 _C.TRAINER.BATCH_SIZE = [128, 128*(_C.DATASET.NUM_TARGET/_C.DATASET.NUM_SOURCE), 128] _C.TRAINER.OPTIMIZER_NAME = "SGD" # choices = [SGD, Adam] -_C.TRAINER.GD = 20 +_C.TRAINER.CLIP_GRADIENT = 20 _C.TRAINER.PRETRAINED = None +_C.TRAINER.RESUME = "" +_C.TRAINER.RESUME_HP = "" _C.TRAINER.MIN_EPOCHS = 25 _C.TRAINER.MAX_EPOCHS = 30 +_C.TRAINER.ACCELERATOR = "ddp" + + + +_C.PATHS.EXP_PATH = os.path.join(_C.PATHS.PATH_EXP + '_' + _C.TRAINER.OPTIMIZER_NAME + '-share_params_' + _C.TRAINER.SHARE_PARAMS + '-lr_' + str(_C.TRAINER.LR) + '-bS_' + 
str(_C.TRAINER.BATCH_SIZE[0]), _C.DATASET.DATASET + '-'+ str(_C.DATASET.NUM_SEGMENTS) + '-seg-disDA_' + _C.MODEL.DIS_DA + '-alpha_' + str(_C.HYPERPARAMETERS.ALPHA) + '-advDA_' + _C.MODEL.ADV_DA + '-beta_' + str(_C.HYPERPARAMETERS.BETA[0])+ '_'+ str(_C.HYPERPARAMETERS.BETA[1])+'_'+ str(_C.HYPERPARAMETERS.BETA[2])+"_gamma_" + str(_C.HYPERPARAMETERS.GAMMA) + "_mu_" + str(_C.HYPERPARAMETERS.MU)) + + +# ---------------------------------------------------------------------------- # +# Tester +# ---------------------------------------------------------------------------- # +_C.TESTER = CN() + +_C.TESTER.TEST_TARGET_DATA = os.path.join(_C.PATHS.PATH_DATA_ROOT, "target_test") + +_C.TESTER.WEIGHTS = os.path.join(_C.PATHS.EXP_PATH , "checkpoint.pth.tar") +_C.TESTER.NOUN_WEIGHTS = None +_C.TESTER.RESULT_JSON = "test.json" +_C.TESTER.TEST_SEGMENTS = 5 # sample frame # of each video for testing +_C.TESTER.SAVE_SCORES = os.path.join(_C.PATHS.EXP_PATH , "scores") +_C.TESTER.SAVE_CONFUSION = os.path.join(_C.PATHS.EXP_PATH , "confusion_matrix") + +_C.TESTER.VERBOSE = True + # ---------------------------------------------------------------------------- # # Miscellaneous configs # ---------------------------------------------------------------------------- # @@ -146,10 +185,10 @@ _C.TRAINER.PF = 50 _C.TRAINER.SF = 50 _C.TRAINER.COPY_LIST = ["N", "N"] +_C.TRAINER.SAVE_MODEL = True -_C.DATASET.EXP_PATH = os.path.join(_C.DATASET.PATH_EXP + '_' + _C.TRAINER.OPTIMIZER_NAME + '-share_params_' + _C.MODEL.SHARE_PARAMS + '-lr_' + str(_C.TRAINER.LR) + '-bS_' + str(_C.TRAINER.BATCH_SIZE[0]), _C.DATASET.DATASET + '-'+ str(_C.DATASET.NUM_SEGMENTS) + '-seg-disDA_' + _C.MODEL.DIS_DA + '-alpha_' + str(_C.HYPERPARAMETERS.ALPHA) + '-advDA_' + _C.MODEL.ADV_DA + '-beta_' + str(_C.HYPERPARAMETERS.BETA[0])+ '_'+ str(_C.HYPERPARAMETERS.BETA[1])+'_'+ str(_C.HYPERPARAMETERS.BETA[2])+"_gamma_" + str(_C.HYPERPARAMETERS.GAMMA) + "_mu_" + str(_C.HYPERPARAMETERS.MU)) diff --git a/main_lightning.py b/main_lightning.py index 
429c012..7f5f20c 100644 --- a/main_lightning.py +++ b/main_lightning.py @@ -1,21 +1,19 @@ +import os import numpy as np import time +import argparse import torch import torch.nn.parallel import torch.optim -from colorama import init -from colorama import Fore, Back, Style - - from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from tensorboardX import SummaryWriter from utils.loss import * -from utils.opts import parser from utils.model_init import initialise_trainer +from config import get_cfg_defaults from utils.data_loaders import get_train_data_loaders, get_val_data_loaders from utils.logging import * @@ -25,47 +23,53 @@ torch.manual_seed(1) torch.cuda.manual_seed_all(1) -init(autoreset=True) - -best_prec1 = 0 gpu_count = torch.cuda.device_count() log_info( "Number of GPUS available: " + str(gpu_count)) +def arg_parse(): + """Parsing arguments""" + parser = argparse.ArgumentParser(description="TA3N Domain Adaptation") + parser.add_argument("--cfg", required=True, help="path to config file", type=str) + parser.add_argument("--gpus", default="0", help="gpu id(s) to use", type=str) + parser.add_argument("--resume", default="", type=str) + args = parser.parse_args() + return args + def main(): - args = parser.parse_args() + args = arg_parse() + cfg = get_cfg_defaults() + cfg.merge_from_file(args.cfg) + cfg.freeze() - path_exp = args.exp_path + args.modality + '/' + # log_info(str(cfg)) + + path_exp = os.path.join(cfg.PATHS.EXP_PATH, cfg.DATASET.MODALITY) #========== model init ========# log_info('Initialising model......') - model = initialise_trainer(args) + model = initialise_trainer(cfg) #========== log files init ========# - open_log_files(args) + open_log_files(cfg) #========== Data loading ========# log_info('Loading data......') - - if args.use_opencv: - log_debug("use opencv functions") - - source_loader, target_loader = get_train_data_loaders(args) + source_loader, target_loader = get_train_data_loaders(cfg) - 
to_validate = args.val_source_data != "none" and args.val_target_data != "none" - if(to_validate): + if(cfg.TO_VALIDATE): log_info('Loading validation data......') - source_loader_val, target_loader_val = get_val_data_loaders(args) + source_loader_val, target_loader_val = get_val_data_loaders(cfg) #========== Callbacks and checkpoints ========# - if args.train_metric == "all": + if cfg.TRAINER.TRAIN_METRIC == "all": monitor = "Prec@1 Action" - elif args.train_metric == "noun": + elif cfg.TRAINER.TRAIN_METRIC == "noun": monitor = "Prec@1 Noun" - elif args.train_metric == "verb": + elif cfg.TRAINER.TRAIN_METRIC == "verb": monitor = "Prec@1 Verb" else: log_error("invalid metric to train") @@ -81,12 +85,18 @@ def main(): #========== Actual Training ========# - trainer = Trainer(min_epochs=20, max_epochs=30, callbacks=[checkpoint_callback], gpus = gpu_count, accelerator='ddp') + trainer = Trainer( + min_epochs=cfg.TRAINER.MIN_EPOCHS, + max_epochs=cfg.TRAINER.MAX_EPOCHS, + callbacks=[checkpoint_callback], + gpus = args.gpus, + accelerator=cfg.TRAINER.ACCELERATOR + ) log_info('Starting training......') start_train = time.time() - if(to_validate): + if(cfg.TO_VALIDATE): trainer.fit(model, (source_loader, target_loader), (source_loader_val, target_loader_val)) else: trainer.fit(model, (source_loader, target_loader)) @@ -95,14 +105,14 @@ def main(): #========== Logging ========# - write_log_files('total time: {:.3f} '.format(end_train - start_train), best_prec1) + write_log_files('total time: {:.3f} '.format(end_train - start_train), model.best_prec1) model.writer_train.close() model.writer_val.close() log_info('Training complete') log_info('Total training time:' + str(end_train - start_train)) - if(to_validate): + if(cfg.TO_VALIDATE): log_info('Validation scores:\n | Prec@1 Verb: ' + str(model.prec1_verb_val) + "\n | Prec@1 Noun: " + str(model.prec1_noun_val)+ "\n | Prec@1 Action: " + str(model.prec1_val) + "\n | Prec@5 Verb: " + str(model.prec5_verb_val) + "\n | Prec@5 
Noun: " + str(model.prec5_noun_val) + "\n | Prec@5 Action: " + str(model.prec5_val) + "\n | Loss total: " + str(model.losses_val)) diff --git a/test_models_lightning.py b/test_models_lightning.py index c88831f..357df19 100644 --- a/test_models_lightning.py +++ b/test_models_lightning.py @@ -1,4 +1,4 @@ -import numpy as np +import argparse import json from json import encoder @@ -9,37 +9,50 @@ from pytorch_lightning import Trainer -from colorama import init -from colorama import Fore, Back, Style - -from utils.opts_test import parser from utils.model_init import initialise_tester from utils.data_loaders import get_test_data_loaders - +from config import get_cfg_defaults from utils.logging import * encoder.FLOAT_REPR = lambda o: format(o, '.3f') -init(autoreset=True) -def main(): +def arg_parse(): + """Parsing arguments""" + parser = argparse.ArgumentParser(description="TA3N Domain Adaptation Testing") + parser.add_argument("--cfg", required=True, help="path to config file", type=str) + parser.add_argument("--gpus", default="0", help="gpu id(s) to use", type=str) + parser.add_argument("--ckpt", default=None, help="pre-trained parameters for the model (ckpt files)", type=str) args = parser.parse_args() + return args + +def main(): + args = arg_parse() + cfg = get_cfg_defaults() + cfg.merge_from_file(args.cfg) + cfg.freeze() #========== model init ========# log_info('Preparing the model......') - verb_net, noun_net = initialise_tester(args) + verb_net, noun_net = initialise_tester(cfg) #========== Data loading ========# log_info('Loading data......') - data_loader = get_test_data_loaders(args) - log_info('Data loaded from: ' + args.test_target_data+".pkl") + data_loader = get_test_data_loaders(cfg) + log_info('Data loaded from: ' + cfg.TESTER.TEST_TARGET_DATA+".pkl") #========== Actual Testing ========# log_info('starting validation......') - trainer = Trainer(gpus = torch.cuda.device_count()) - - trainer.test(model = verb_net, test_dataloaders=data_loader, 
ckpt_path=args.weights, verbose = True) + trainer = Trainer(gpus = args.gpus) + + if args.ckpt is None: + ckpt_path = cfg.TESTER.WEIGHTS + else: + ckpt_path = args.ckpt + trainer.test(model = verb_net, test_dataloaders=data_loader, ckpt_path=ckpt_path, verbose = cfg.TESTER.VERBOSE) + if noun_net is not None: + trainer.test(model = noun_net, test_dataloaders=data_loader, ckpt_path=ckpt_path, verbose = cfg.TESTER.VERBOSE) log_info('validation complete') diff --git a/utils/model_init.py b/utils/model_init.py index 9657649..11662e3 100644 --- a/utils/model_init.py +++ b/utils/model_init.py @@ -9,53 +9,50 @@ from tensorboardX import SummaryWriter -def set_hyperparameters(model, args): - model.optimizerName = args.optimizer - model.loss_type = args.loss_type - model.lr = args.lr - model.momentum = args.momentum - model.weight_decay = args.weight_decay - model.epochs = args.epochs - model.batch_size = args.batch_size - model.eval_freq = args.eval_freq - - model.lr_adaptive = args.lr_adaptive - model.lr_decay = args.lr_decay - model.lr_steps = args.lr_steps - - model.alpha = args.alpha - model.beta = args.beta - model.gamma = args.gamma - model.mu = args.mu - - model.train_metric = args.train_metric - model.dann_warmup = args.dann_warmup +def set_hyperparameters(model, cfg): + model.optimizerName = cfg.TRAINER.OPTIMIZER_NAME + model.loss_type = cfg.TRAINER.LOSS_TYPE + model.lr = cfg.TRAINER.LR + model.momentum = cfg.TRAINER.MOMENTUM + model.weight_decay = cfg.TRAINER.WEIGHT_DECAY + model.epochs = cfg.TRAINER.MAX_EPOCHS + model.batch_size = cfg.TRAINER.BATCH_SIZE + + model.lr_adaptive = cfg.TRAINER.LR_ADAPTIVE + model.lr_decay = cfg.TRAINER.LR_DECAY + model.lr_steps = cfg.TRAINER.LR_STEPS + + model.alpha = cfg.HYPERPARAMETERS.ALPHA + model.beta = cfg.HYPERPARAMETERS.BETA + model.gamma = cfg.HYPERPARAMETERS.GAMMA + model.mu = cfg.HYPERPARAMETERS.MU + + model.train_metric = cfg.TRAINER.TRAIN_METRIC + model.dann_warmup = cfg.TRAINER.DANN_WARMUP model.tensorboard = True - model.path_exp = 
model.modality + '/' + model.path_exp = cfg.PATHS.EXP_PATH if not os.path.isdir(model.path_exp): os.makedirs(model.path_exp) model.writer_train = SummaryWriter(model.path_exp + '/tensorboard_train') # for tensorboardX model.writer_val = SummaryWriter(model.path_exp + '/tensorboard_val') # for tensorboardX - model.pretrain_source = args.pretrain_source - model.clip_gradient = args.clip_gradient - - model.dis_DA = args.dis_DA - model.use_target = args.use_target - model.add_fc = args.add_fc - model.place_dis = args.place_dis - model.place_adv = args.place_adv - model.pred_normalize = args.pred_normalize - model.add_loss_DA = args.add_loss_DA - model.print_freq = args.print_freq - model.show_freq = args.show_freq - model.ens_DA = args.ens_DA - - model.arch = args.arch - model.save_model = args.save_model + model.pretrain_source = cfg.TRAINER.PRETRAIN_SOURCE + model.clip_gradient = cfg.TRAINER.CLIP_GRADIENT + + model.dis_DA = cfg.MODEL.DIS_DA + model.use_target = cfg.MODEL.USE_TARGET + model.add_fc = cfg.MODEL.ADD_FC + model.place_dis = cfg.MODEL.PLACE_DIS + model.place_adv = cfg.MODEL.PLACE_ADV + model.pred_normalize = cfg.MODEL.PRED_NORMALIZE + model.add_loss_DA = cfg.MODEL.ADD_LOSS_DA + model.ens_DA = cfg.MODEL.ENS_DA + + model.arch = cfg.MODEL.ARCH + model.save_model = cfg.TRAINER.SAVE_MODEL model.labels_available = True - model.adv_DA = args.adv_DA + model.adv_DA = cfg.MODEL.ADV_DA if model.loss_type == 'nll': model.criterion = torch.nn.CrossEntropyLoss() @@ -63,36 +60,36 @@ def set_hyperparameters(model, args): else: raise ValueError("Unknown loss type") -def initialise_trainer(args): +def initialise_trainer(cfg): - log_debug('Baseline:' + args.baseline_type) - log_debug('Frame aggregation method:' + args.frame_aggregation) + log_debug('Baseline:' + cfg.DATASET.BASELINE_TYPE) + log_debug('Frame aggregation method:' + cfg.DATASET.FRAME_AGGREGATION) - log_debug('target data usage:' + args.use_target) - if args.use_target == 'none': + log_debug('target data usage:' 
+ cfg.MODEL.USE_TARGET) + if cfg.MODEL.USE_TARGET is None: log_debug('no Domain Adaptation') else: - if args.dis_DA != 'none': - log_debug('Apply the discrepancy-based Domain Adaptation approach:'+ args.dis_DA) - if len(args.place_dis) != args.add_fc + 2: + if cfg.MODEL.DIS_DA is not None: + log_debug('Apply the discrepancy-based Domain Adaptation approach:'+ cfg.MODEL.DIS_DA) + if len(cfg.MODEL.PLACE_DIS) != cfg.MODEL.ADD_FC + 2: log_error('len(place_dis) should be equal to add_fc + 2') raise ValueError('len(place_dis) should be equal to add_fc + 2') - if args.adv_DA != 'none': - log_debug('Apply the adversarial-based Domain Adaptation approach:'+ args.adv_DA) + if cfg.MODEL.ADV_DA is not None: + log_debug('Apply the adversarial-based Domain Adaptation approach:'+ cfg.MODEL.ADV_DA) - if args.use_bn != 'none': - log_debug('Apply the adaptive normalization approach:'+ args.use_bn) + if cfg.MODEL.USE_BN is not None: + log_debug('Apply the adaptive normalization approach:'+ cfg.MODEL.USE_BN) # determine the categories #want to allow multi-label classes. 
#Original way to compute number of classes - ####class_names = [line.strip().split(' ', 1)[1] for line in open(args.class_file)] + ####class_names = [line.strip().split(' ', 1)[1] for line in open(cfg.class_file)] ####num_class = len(class_names) #New approach - num_class_str = args.num_class.split(",") + num_class_str = cfg.DATASET.NUM_CLASSES.split(",") #single class if len(num_class_str) < 1: raise Exception("Must specify a number of classes to train") @@ -102,27 +99,27 @@ def initialise_trainer(args): num_class.append(int(num)) #=== check the folder existence ===# - path_exp = args.exp_path + args.modality + '/' + path_exp = cfg.PATHS.EXP_PATH if not os.path.isdir(path_exp): os.makedirs(path_exp) #=== initialize the model ===# log_info('preparing the model......') - model = VideoModel(num_class, args.baseline_type, args.frame_aggregation, args.modality, - train_segments=args.num_segments, val_segments=args.val_segments, - base_model=args.arch, path_pretrained=args.pretrained, - add_fc=args.add_fc, fc_dim = args.fc_dim, - dropout_i=args.dropout_i, dropout_v=args.dropout_v, partial_bn=not args.no_partialbn, - use_bn=args.use_bn if args.use_target != 'none' else 'none', ens_DA=args.ens_DA if args.use_target != 'none' else 'none', - n_rnn=args.n_rnn, rnn_cell=args.rnn_cell, n_directions=args.n_directions, n_ts=args.n_ts, - use_attn=args.use_attn, n_attn=args.n_attn, use_attn_frame=args.use_attn_frame, - verbose=args.verbose, share_params=args.share_params) - - if args.optimizer == 'SGD': + model = VideoModel(num_class, cfg.DATASET.BASELINE_TYPE, cfg.DATASET.FRAME_AGGREGATION, cfg.DATASET.MODALITY, + train_segments=cfg.DATASET.NUM_SEGMENTS, val_segments=cfg.DATASET.NUM_SEGMENTS, + base_model=cfg.MODEL.ARCH, path_pretrained=cfg.TRAINER.PRETRAINED, + add_fc=cfg.MODEL.ADD_FC, fc_dim = cfg.MODEL.FC_DIM, + dropout_i=cfg.MODEL.DROPOUT_I, dropout_v=cfg.MODEL.DROPOUT_V, partial_bn=not cfg.MODEL.NO_PARTIALBN, + use_bn=cfg.MODEL.USE_BN if cfg.MODEL.USE_TARGET is not None else 
None, ens_DA=cfg.MODEL.ENS_DA if cfg.MODEL.USE_TARGET is not None else None, + n_rnn=cfg.MODEL.N_RNN, rnn_cell=cfg.MODEL.RNN_CELL, n_directions=cfg.MODEL.N_DIRECTIONS, n_ts=cfg.MODEL.N_TS, + use_attn=cfg.MODEL.USE_ATTN, n_attn=cfg.MODEL.N_ATTN, use_attn_frame=cfg.MODEL.USE_ATTN_FRAME, + verbose=cfg.TRAINER.VERBOSE, share_params=cfg.MODEL.SHARE_PARAMS) + + if cfg.TRAINER.OPTIMIZER_NAME == 'SGD': log_debug('using SGD') model.optimizerName = 'SGD' - elif args.optimizer == 'Adam': + elif cfg.TRAINER.OPTIMIZER_NAME == 'Adam': log_debug( 'using Adam') model.optimizerName = 'Adam' else: @@ -130,20 +127,19 @@ def initialise_trainer(args): exit() #=== check point ===# - start_epoch = 1 log_debug('checking the checkpoint......') - if args.resume: - if os.path.isfile(args.resume): - checkpoint = torch.load(args.resume) + if cfg.TRAINER.RESUME != "": + if os.path.isfile(cfg.TRAINER.RESUME): + checkpoint = torch.load(cfg.TRAINER.RESUME) start_epoch = checkpoint['epoch'] + 1 best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) - log_debug("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) - if args.resume_hp: + log_debug("=> loaded checkpoint '{}' (epoch {})".format(cfg.TRAINER.RESUME, checkpoint['epoch'])) + if cfg.TRAINER.RESUME_HP: log_debug("=> loaded checkpoint hyper-parameters") model.optimizer.load_state_dict(checkpoint['optimizer']) else: - log_error("=> no checkpoint found at '{}'".format(args.resume)) + log_error("=> no checkpoint found at '{}'".format(cfg.TRAINER.RESUME)) cudnn.benchmark = True @@ -151,33 +147,28 @@ def initialise_trainer(args): # --- Optimizer ---# # define loss function (criterion) and optimizer - if args.loss_type == 'nll': + if cfg.TRAINER.LOSS_TYPE == 'nll': model.loss_type = 'nll' else: raise ValueError("Unknown loss type") - # --- Parameters ---# - model.beta = args.beta - model.gamma = args.gamma - model.mu = args.mu - - set_hyperparameters(model, args) + set_hyperparameters(model, 
cfg) return model -def set_hyperparameters_test(model, args): - model.batch_size = [args.bS] - model.alpha = 1 - model.beta = [1, 1, 1] - model.gamma = 1 - model.mu = 0 +def set_hyperparameters_test(model, cfg): + model.batch_size = cfg.TRAINER.BATCH_SIZE + model.alpha = cfg.HYPERPARAMETERS.ALPHA + model.beta = cfg.HYPERPARAMETERS.BETA + model.gamma = cfg.HYPERPARAMETERS.GAMMA + model.mu = cfg.HYPERPARAMETERS.MU model.criterion = torch.nn.CrossEntropyLoss() model.criterion_domain = torch.nn.CrossEntropyLoss() -def initialise_tester(args): +def initialise_tester(cfg): # New approach - num_class_str = args.num_class.split(",") + num_class_str = cfg.DATASET.NUM_CLASSES.split(",") # single class if len(num_class_str) < 1: raise Exception("Must specify a number of classes to train") @@ -187,37 +178,37 @@ def initialise_tester(args): num_class.append(int(num)) - verb_net = VideoModel(num_class, args.baseline_type, args.frame_aggregation, args.modality, - train_segments=args.test_segments if args.baseline_type == 'video' else 1, val_segments=args.test_segments if args.baseline_type == 'video' else 1, - base_model=args.arch, add_fc=args.add_fc, fc_dim=args.fc_dim, share_params=args.share_params, - dropout_i=args.dropout_i, dropout_v=args.dropout_v, use_bn=args.use_bn, partial_bn=False, - n_rnn=args.n_rnn, rnn_cell=args.rnn_cell, n_directions=args.n_directions, n_ts=args.n_ts, - use_attn=args.use_attn, n_attn=args.n_attn, use_attn_frame=args.use_attn_frame, - verbose=args.verbose, before_softmax=False) + verb_net = VideoModel(num_class, cfg.DATASET.BASELINE_TYPE, cfg.DATASET.FRAME_AGGREGATION, cfg.DATASET.MODALITY, + train_segments=cfg.TESTER.TEST_SEGMENTS if cfg.DATASET.BASELINE_TYPE == 'video' else 1, val_segments=cfg.TESTER.TEST_SEGMENTS if cfg.DATASET.BASELINE_TYPE == 'video' else 1, + base_model=cfg.MODEL.ARCH, add_fc=cfg.MODEL.ADD_FC, fc_dim=cfg.MODEL.FC_DIM, share_params=cfg.MODEL.SHARE_PARAMS, + dropout_i=cfg.MODEL.DROPOUT_I, dropout_v=cfg.MODEL.DROPOUT_V, 
use_bn=cfg.MODEL.USE_BN, partial_bn=False, + n_rnn=cfg.MODEL.N_RNN, rnn_cell=cfg.MODEL.RNN_CELL, n_directions=cfg.MODEL.N_DIRECTIONS, n_ts=cfg.MODEL.N_TS, + use_attn=cfg.MODEL.USE_ATTN, n_attn=cfg.MODEL.N_ATTN, use_attn_frame=cfg.MODEL.USE_ATTN_FRAME, + verbose=cfg.TESTER.VERBOSE, before_softmax=False) - verb_checkpoint = torch.load(args.weights) + verb_checkpoint = torch.load(cfg.TESTER.WEIGHTS) verb_base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(verb_checkpoint['state_dict'].items())} verb_net.load_state_dict(verb_base_dict) # verb_net = torch.nn.DataParallel(verb_net) - set_hyperparameters_test(verb_net, args) + set_hyperparameters_test(verb_net, cfg) verb_net.eval() - if args.noun_weights is not None: - noun_net = VideoModel(num_class, args.baseline_type, args.frame_aggregation, args.modality, - train_segments=args.test_segments if args.baseline_type == 'video' else 1, - val_segments=args.test_segments if args.baseline_type == 'video' else 1, - base_model=args.arch, add_fc=args.add_fc, fc_dim=args.fc_dim, share_params=args.share_params, - dropout_i=args.dropout_i, dropout_v=args.dropout_v, use_bn=args.use_bn, partial_bn=False, - n_rnn=args.n_rnn, rnn_cell=args.rnn_cell, n_directions=args.n_directions, n_ts=args.n_ts, - use_attn=args.use_attn, n_attn=args.n_attn, use_attn_frame=args.use_attn_frame, - verbose=args.verbose, before_softmax=False) - noun_checkpoint = torch.load(args.noun_weights) + if cfg.TESTER.NOUN_WEIGHTS is not None: + noun_net = VideoModel(num_class, cfg.DATASET.BASELINE_TYPE, cfg.DATASET.FRAME_AGGREGATION, cfg.DATASET.MODALITY, + train_segments=cfg.TESTER.TEST_SEGMENTS if cfg.DATASET.BASELINE_TYPE == 'video' else 1, + val_segments=cfg.TESTER.TEST_SEGMENTS if cfg.DATASET.BASELINE_TYPE == 'video' else 1, + base_model=cfg.MODEL.ARCH, add_fc=cfg.MODEL.ADD_FC, fc_dim=cfg.MODEL.FC_DIM, share_params=cfg.MODEL.SHARE_PARAMS, + dropout_i=cfg.MODEL.DROPOUT_I, dropout_v=cfg.MODEL.DROPOUT_V, use_bn=cfg.MODEL.USE_BN, partial_bn=False, + 
n_rnn=cfg.MODEL.N_RNN, rnn_cell=cfg.MODEL.RNN_CELL, n_directions=cfg.MODEL.N_DIRECTIONS, n_ts=cfg.MODEL.N_TS, + use_attn=cfg.MODEL.USE_ATTN, n_attn=cfg.MODEL.N_ATTN, use_attn_frame=cfg.MODEL.USE_ATTN_FRAME, + verbose=cfg.TESTER.VERBOSE, before_softmax=False) + noun_checkpoint = torch.load(cfg.TESTER.NOUN_WEIGHTS) noun_base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(noun_checkpoint['state_dict'].items())} noun_net.load_state_dict(noun_base_dict) # noun_net = torch.nn.DataParallel(noun_net.cuda()) - set_hyperparameters_test(noun_net, args) + set_hyperparameters_test(noun_net, cfg) noun_net.eval() else: noun_net = None