diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp index 193a0ab7f..19a94738b 100644 --- a/src/caffe/solvers/sgd_solver.cpp +++ b/src/caffe/solvers/sgd_solver.cpp @@ -300,9 +300,13 @@ void SGDSolver::ApplyUpdate(int param_id) { LOG_PARAM_BLOB(this->net_->learnable_params()[param_id], diff, param_id, "ApplyUpdate: delwt after Normalize:"); - Regularize(param_id); - - LOG_PARAM_BLOB(this->net_->learnable_params()[param_id], diff, param_id, "ApplyUpdate: delwt after Regularize:"); + // In the original intel-caffe code, only SGD (not NESTEROV, ADAGRAD, RMSPROP, ADADELTA, or ADAM) adopted LARS, so only the SGD flow is changed. + // When the solver type is "SGD", Regularize is executed after GetLocalRate (LARS). + if (this->param_.type().compare("SGD") != 0) + { + Regularize(param_id); + LOG_PARAM_BLOB(this->net_->learnable_params()[param_id], diff, param_id, "ApplyUpdate: delwt after Regularize:"); + } ComputeUpdateValue(param_id, rate); @@ -413,7 +417,6 @@ void SGDSolver::SGDFusion(int param_id, Dtype rate) { //ComputeUpdateValue initialization Dtype momentum = this->param_.momentum(); - Dtype local_rate = rate * GetLocalRate(param_id); //#pragma endregion //#pragma region 2. 
Common condition judgement @@ -451,7 +454,10 @@ void SGDSolver::SGDFusion(int param_id, Dtype rate) { net_params[param_id]->mutable_cpu_diff()); } } + //#pragma endregion +//execute GetLocalRate(LARS) after Normalize stage +Dtype local_rate = rate * GetLocalRate(param_id); //For most common topologies from BVLC, all skipped the Normalize stage, and use L2 regularization //If prv_diff_condition_flag == true, then prv_data_condition_flag == true (1) @@ -672,7 +678,8 @@ void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { const vector*>& net_params = this->net_->learnable_params(); Dtype momentum = this->param_.momentum(); Dtype local_rate = rate * GetLocalRate(param_id); - + Regularize(param_id); + LOG_PARAM_BLOB(this->net_->learnable_params()[param_id], diff, param_id, "ApplyUpdate: delwt after Regularize:"); if (this->param_.warmup_iter() > 0 && this->iter_ < this->param_.warmup_iter()) { // Momentum correction during warmup stage