diff --git a/src/nbla/cuda/solver/generic/adadelta.cu b/src/nbla/cuda/solver/generic/adadelta.cu
index 0e8de162e0..6f14d32dd8 100644
--- a/src/nbla/cuda/solver/generic/adadelta.cu
+++ b/src/nbla/cuda/solver/generic/adadelta.cu
@@ -38,9 +38,9 @@ __global__ void kernel_adadelta_update(const int num, T *data, const T *grad,
 template <typename T>
 void AdadeltaCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  VariablePtr e1 = state.e_sqr_grad;
-  VariablePtr e2 = state.e_sqr_delta;
+  auto &state = this->states_.at(key);
+  VariablePtr e1 = state.pstate["e_sqr_grad"];
+  VariablePtr e2 = state.pstate["e_sqr_delta"];
   T *e_sqr_grad = e1->cast_data_and_get_pointer<T>(this->ctx_);
   T *e_sqr_delta = e2->cast_data_and_get_pointer<T>(this->ctx_);
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
@@ -48,6 +48,8 @@ void AdadeltaCuda<T>::update_impl(const string &key, VariablePtr param) {
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adadelta_update, size, data, grad,
                                  e_sqr_grad, e_sqr_delta, this->lr_,
                                  this->decay_, this->eps_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }
 
 NBLA_DEF_WEIGHT_DECAY(AdadeltaCuda, weight_decay_cuda);
diff --git a/src/nbla/cuda/solver/generic/adagrad.cu b/src/nbla/cuda/solver/generic/adagrad.cu
index 59d526b074..f72abb5bd5 100644
--- a/src/nbla/cuda/solver/generic/adagrad.cu
+++ b/src/nbla/cuda/solver/generic/adagrad.cu
@@ -32,10 +32,13 @@ __global__ void kernel_adagrad_update(const int num, T *data, const T *grad,
 template <typename T>
 void AdagradCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  VariablePtr g_ = this->state_.at(key);
+  auto &state = this->states_.at(key);
+  VariablePtr g_ = state.pstate["v"];
   T *g = g_->cast_data_and_get_pointer<T>(this->ctx_);
+  auto &t = state.t;
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adagrad_update, size, data, grad, g,
                                  this->lr_, this->eps_);
 }
diff --git a/src/nbla/cuda/solver/generic/adam.cu b/src/nbla/cuda/solver/generic/adam.cu
index badb95c54a..6d2208539d 100644
--- a/src/nbla/cuda/solver/generic/adam.cu
+++ b/src/nbla/cuda/solver/generic/adam.cu
@@ -38,15 +38,16 @@ template <typename T>
 void AdamCuda<T>::update_impl(const string &key, VariablePtr param) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  int &t = state.t;
+  auto &state = this->states_.at(key);
+  uint32_t &t = state.t;
   const T *g = param->get_grad_pointer<T>(this->ctx_);
-  shared_ptr<Variable> mean_ = state.mean; // To prevent compile error.
-  shared_ptr<Variable> var_ = state.var;   // To prevent compile error.
+  shared_ptr<Variable> mean_ =
+      state.pstate["mean"]; // To prevent compile error.
+  shared_ptr<Variable> var_ = state.pstate["var"]; // To prevent compile error.
   T *m = mean_->cast_data_and_get_pointer<T>(this->ctx_);
   T *v = var_->cast_data_and_get_pointer<T>(this->ctx_);
   T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
-  t = std::min(t + 1, std::numeric_limits<int>::max());
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   const T bias_correction = std::sqrt(1 - std::pow(this->beta2_, t)) /
                             (1 - std::pow(this->beta1_, t));
   const T alpha_t = this->alpha_ * bias_correction;
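Note: these hunks track the solver-state refactor on the nnabla core side, where the old per-solver state members (`state_.at(key).mean`, etc.) give way to a generic per-parameter state with a named map of variables plus a step counter. A minimal sketch of the assumed shape — the names `SolverState`, `pstate`, and `t` are inferred from the usage in the hunks, not copied verbatim from the nnabla headers:

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <unordered_map>

    using VariablePtr = std::shared_ptr<class Variable>; // as in nbla

    // Sketch (assumption): per-parameter solver state held by this->states_.
    struct SolverState {
      std::unordered_map<std::string, VariablePtr> pstate; // named state, e.g. "mean", "var", "m", "v"
      std::uint32_t t{0};                                  // number of updates applied so far
    };

    // Solver member, keyed by parameter name:
    std::unordered_map<std::string, SolverState> states_;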
diff --git a/src/nbla/cuda/solver/generic/adamax.cu b/src/nbla/cuda/solver/generic/adamax.cu
index 9fe0e617a9..4cca7336ca 100644
--- a/src/nbla/cuda/solver/generic/adamax.cu
+++ b/src/nbla/cuda/solver/generic/adamax.cu
@@ -37,15 +37,15 @@ __global__ void kernel_adamax_update(const int num, T *theta, T *m, T *u,
 template <typename T>
 void AdamaxCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  int &t = state.t;
-  VariablePtr s1 = state.mean;
-  VariablePtr s2 = state.u;
+  auto &state = this->states_.at(key);
+  uint32_t &t = state.t;
+  VariablePtr s1 = state.pstate["m"];
+  VariablePtr s2 = state.pstate["u"];
   const T *g = param->get_grad_pointer<T>(this->ctx_);
   T *m = s1->cast_data_and_get_pointer<T>(this->ctx_);
   T *u = s2->cast_data_and_get_pointer<T>(this->ctx_);
   T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
-  t = std::min(t + 1, std::numeric_limits<int>::max());
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   const T bias_correction = 1 / (1 - std::pow(this->beta1_, t));
   const T alpha_t = this->alpha_ * bias_correction;
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adamax_update, size, theta, m, u, g,
diff --git a/src/nbla/cuda/solver/generic/amsgrad.cu b/src/nbla/cuda/solver/generic/amsgrad.cu
index 6c92b7c44e..f186e68f0c 100644
--- a/src/nbla/cuda/solver/generic/amsgrad.cu
+++ b/src/nbla/cuda/solver/generic/amsgrad.cu
@@ -39,17 +39,18 @@ template <typename T>
 void AMSGRADCuda<T>::update_impl(const string &key, VariablePtr param) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  int &t = state.t;
+  auto &state = this->states_.at(key);
+  auto &t = state.t;
   const T *g = param->get_grad_pointer<T>(this->ctx_);
-  shared_ptr<Variable> mean_ = state.mean;       // To prevent compile error.
-  shared_ptr<Variable> var_ = state.var;         // To prevent compile error.
-  shared_ptr<Variable> var_hat_ = state.var_hat; // To prevent compile error.
+  shared_ptr<Variable> mean_ = state.pstate["m"]; // To prevent compile error.
+  shared_ptr<Variable> var_ = state.pstate["v"];  // To prevent compile error.
+  shared_ptr<Variable> var_hat_ =
+      state.pstate["v_hat"]; // To prevent compile error.
   T *m = mean_->cast_data_and_get_pointer<T>(this->ctx_);
   T *v = var_->cast_data_and_get_pointer<T>(this->ctx_);
   T *v_hat = var_hat_->cast_data_and_get_pointer<T>(this->ctx_);
   T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
-  t = std::min(t + 1, std::numeric_limits<int>::max());
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   const T bias_correction = std::sqrt(1 - std::pow(this->beta2_, t)) /
                             (1 - std::pow(this->beta1_, t));
   const T alpha_t =
diff --git a/src/nbla/cuda/solver/generic/momentum.cu b/src/nbla/cuda/solver/generic/momentum.cu
index 2e26fdeff3..939f73641e 100644
--- a/src/nbla/cuda/solver/generic/momentum.cu
+++ b/src/nbla/cuda/solver/generic/momentum.cu
@@ -34,12 +34,15 @@ template <typename T>
 void MomentumCuda<T>::update_impl(const string &key, VariablePtr param) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
   Size_t size = param->size();
-  VariablePtr r_ = this->state_.at(key);
+  auto &state = this->states_.at(key);
+  VariablePtr r_ = state.pstate["m"];
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *v = r_->cast_data_and_get_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_momentum_update, size, data, grad, v,
                                  this->lr_, this->momentum_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }
 
 NBLA_DEF_WEIGHT_DECAY(MomentumCuda, weight_decay_cuda);
diff --git a/src/nbla/cuda/solver/generic/nesterov.cu b/src/nbla/cuda/solver/generic/nesterov.cu
index 35fc699f2b..6c7bf53cfa 100644
--- a/src/nbla/cuda/solver/generic/nesterov.cu
+++ b/src/nbla/cuda/solver/generic/nesterov.cu
@@ -34,12 +34,15 @@ __global__ void kernel_nesterov_update(const int num, T *data, const T *grad,
 template <typename T>
 void NesterovCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  VariablePtr v_ = this->state_.at(key);
+  auto &state = this->states_.at(key);
+  VariablePtr v_ = state.pstate["m"];
   T *v = v_->cast_data_and_get_pointer<T>(this->ctx_);
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_nesterov_update, size, data, grad, v,
                                  this->lr_, this->momentum_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }
 
 NBLA_DEF_WEIGHT_DECAY(NesterovCuda, weight_decay_cuda);
diff --git a/src/nbla/cuda/solver/generic/rmsprop.cu b/src/nbla/cuda/solver/generic/rmsprop.cu
index aa5f97a9af..a30a88e7aa 100644
--- a/src/nbla/cuda/solver/generic/rmsprop.cu
+++ b/src/nbla/cuda/solver/generic/rmsprop.cu
@@ -34,13 +34,16 @@ __global__ void kernel_rmsprop_update(const int num, T *data, const T *grad,
 template <typename T>
 void RMSpropCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  VariablePtr state = this->state_.at(key);
-  T *e_sqr_grad = state->cast_data_and_get_pointer<T>(this->ctx_);
+  auto &state = this->states_.at(key);
+  VariablePtr v = state.pstate["v"];
+  T *e_sqr_grad = v->cast_data_and_get_pointer<T>(this->ctx_);
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_rmsprop_update, size, data, grad,
                                  e_sqr_grad, this->lr_, this->decay_,
                                  this->eps_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }
 
 NBLA_DEF_WEIGHT_DECAY(RMSpropCuda, weight_decay_cuda);
diff --git a/src/nbla/cuda/solver/generic/sgd.cu b/src/nbla/cuda/solver/generic/sgd.cu
index b1da2a8521..b5d9db1630 100644
--- a/src/nbla/cuda/solver/generic/sgd.cu
+++ b/src/nbla/cuda/solver/generic/sgd.cu
@@ -33,6 +33,9 @@ void SgdCuda<T>::update_impl(const string &key, VariablePtr param) {
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_update, size, data, grad, this->lr_);
+  auto &state = this->states_.at(key);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }
 
 NBLA_DEF_WEIGHT_DECAY(SgdCuda, weight_decay_cuda);
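Every update_impl now ends with the same saturating increment of the shared step counter. Clamping at one below the maximum keeps `t + 1` from wrapping back to zero on unsigned overflow once the counter saturates, which matters for solvers like Adam and Adamax that feed `t` into `pow(beta, t)` bias corrections. A standalone sketch of the pattern follows; the helper name `increment_step` is hypothetical, introduced only for illustration:

    #include <algorithm>
    #include <cstdint>
    #include <limits>

    // Saturating step increment: t sticks at max() - 1 forever after, so
    // t + 1 can reach max() but never overflows the unsigned counter to 0.
    inline void increment_step(std::uint32_t &t) {
      t = std::min(t + 1, std::numeric_limits<std::uint32_t>::max() - 1);
    }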