Merge pull request #116 from sony/feature/20181205-utils-qnn-train
Serialization of SolverState
TakuyaNarihira authored Jan 11, 2019
2 parents d1e9474 + 6a2118f commit 13b5e9f
Showing 9 changed files with 43 additions and 24 deletions.
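
This commit moves every CUDA solver off its ad-hoc per-parameter state (a bare VariablePtr map, or a solver-specific struct, both named state_) onto a shared states_ map of SolverState entries whose buffers are keyed by name, which is what makes them serializable. The actual definition lives in the nnabla core repository; what follows is a minimal sketch, inferred only from the usage in the hunks below, of the shape this code programs against (the sketch, including the forward-declared Variable, is an assumption, not the authoritative header):

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>

class Variable; // nnabla's array + gradient container (declaration only)
using VariablePtr = std::shared_ptr<Variable>;

struct SolverState {
  // Auxiliary buffers keyed by name ("mean"/"var" for Adam, "v" for
  // Adagrad/RMSprop, "m" for momentum-style solvers, ...). String keys
  // give serialization a uniform, solver-independent schema.
  std::unordered_map<std::string, VariablePtr> pstate;
  // Per-parameter update count; fixed-width and unsigned so it
  // round-trips through serialization portably (Adam and friends
  // previously kept a plain int t).
  uint32_t t{0};
};
// Each solver now holds an unordered_map<string, SolverState> states_.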
8 changes: 5 additions & 3 deletions src/nbla/cuda/solver/generic/adadelta.cu
@@ -38,16 +38,18 @@ __global__ void kernel_adadelta_update(const int num, T *data, const T *grad,
 template <typename T>
 void AdadeltaCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  VariablePtr e1 = state.e_sqr_grad;
-  VariablePtr e2 = state.e_sqr_delta;
+  auto &state = this->states_.at(key);
+  VariablePtr e1 = state.pstate["e_sqr_grad"];
+  VariablePtr e2 = state.pstate["e_sqr_delta"];
   T *e_sqr_grad = e1->cast_data_and_get_pointer<T>(this->ctx_);
   T *e_sqr_delta = e2->cast_data_and_get_pointer<T>(this->ctx_);
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adadelta_update, size, data, grad,
                                  e_sqr_grad, e_sqr_delta, this->lr_,
                                  this->decay_, this->eps_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }

 NBLA_DEF_WEIGHT_DECAY(AdadeltaCuda, weight_decay_cuda);
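
Every update_impl in this commit now ends with the same saturating step count, introduced here in adadelta.cu and repeated verbatim in the files below. The clamp is to max() - 1 rather than max() for a reason: if t could ever sit at UINT32_MAX, the next t + 1 would wrap to 0 before std::min saw it, silently resetting the counter. A standalone sketch of the idiom:

#include <algorithm>
#include <cstdint>
#include <limits>

// Saturating increment, as used in each update_impl of this commit.
inline void bump_step(uint32_t &t) {
  // Because t is capped at UINT32_MAX - 1, the expression t + 1 can
  // reach at most UINT32_MAX and never wraps around to 0.
  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
}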
5 changes: 4 additions & 1 deletion src/nbla/cuda/solver/generic/adagrad.cu
@@ -32,10 +32,13 @@ __global__ void kernel_adagrad_update(const int num, T *data, const T *grad,
 template <typename T>
 void AdagradCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  VariablePtr g_ = this->state_.at(key);
+  auto &state = this->states_.at(key);
+  VariablePtr g_ = state.pstate["v"];
   T *g = g_->cast_data_and_get_pointer<T>(this->ctx_);
+  auto &t = state.t;
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adagrad_update, size, data, grad, g,
                                  this->lr_, this->eps_);
 }
11 changes: 6 additions & 5 deletions src/nbla/cuda/solver/generic/adam.cu
@@ -38,15 +38,16 @@ template <typename T>
 void AdamCuda<T>::update_impl(const string &key, VariablePtr param) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  int &t = state.t;
+  auto &state = this->states_.at(key);
+  uint32_t &t = state.t;
   const T *g = param->get_grad_pointer<T>(this->ctx_);
-  shared_ptr<Variable> mean_ = state.mean; // To prevent compile error.
-  shared_ptr<Variable> var_ = state.var;   // To prevent compile error.
+  shared_ptr<Variable> mean_ =
+      state.pstate["mean"]; // To prevent compile error.
+  shared_ptr<Variable> var_ = state.pstate["var"]; // To prevent compile error.
   T *m = mean_->cast_data_and_get_pointer<T>(this->ctx_);
   T *v = var_->cast_data_and_get_pointer<T>(this->ctx_);
   T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
-  t = std::min(t + 1, std::numeric_limits<int>::max());
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   const T bias_correction = std::sqrt(1 - std::pow(this->beta2_, t)) /
                             (1 - std::pow(this->beta1_, t));
   const T alpha_t = this->alpha_ * bias_correction;
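
For reference, the bias_correction computed above folds both of Adam's moment corrections into the step size. With decay rates $\beta_1, \beta_2$ and base step size $\alpha$, the effective step at update $t$ is

    \alpha_t = \alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}

and the elementwise kernel (outside this hunk) applies the usual Adam step, up to the exact placement of $\epsilon$:

    \theta \leftarrow \theta - \alpha_t \cdot \frac{m}{\sqrt{v} + \epsilon}

Once $t$ is large the correction is effectively 1, so the saturation of $t$ changes nothing numerically for long runs.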
10 changes: 5 additions & 5 deletions src/nbla/cuda/solver/generic/adamax.cu
@@ -37,15 +37,15 @@ __global__ void kernel_adamax_update(const int num, T *theta, T *m, T *u,
 template <typename T>
 void AdamaxCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  int &t = state.t;
-  VariablePtr s1 = state.mean;
-  VariablePtr s2 = state.u;
+  auto &state = this->states_.at(key);
+  uint32_t &t = state.t;
+  VariablePtr s1 = state.pstate["m"];
+  VariablePtr s2 = state.pstate["u"];
   const T *g = param->get_grad_pointer<T>(this->ctx_);
   T *m = s1->cast_data_and_get_pointer<T>(this->ctx_);
   T *u = s2->cast_data_and_get_pointer<T>(this->ctx_);
   T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
-  t = std::min(t + 1, std::numeric_limits<int>::max());
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   const T bias_correction = 1 / (1 - std::pow(this->beta1_, t));
   const T alpha_t = this->alpha_ * bias_correction;
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adamax_update, size, theta, m, u, g,
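
Adamax needs only half of Adam's correction: its second buffer u is an infinity-norm accumulator rather than a decaying second moment, and per Kingma & Ba it requires no bias correction, so the code above computes only

    \alpha_t = \frac{\alpha}{1 - \beta_1^t}

with the elementwise step in kernel_adamax_update (outside this hunk) taking, up to $\epsilon$ placement, the form $\theta \leftarrow \theta - \alpha_t \, m / (u + \epsilon)$.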
13 changes: 7 additions & 6 deletions src/nbla/cuda/solver/generic/amsgrad.cu
@@ -39,17 +39,18 @@ template <typename T>
 void AMSGRADCuda<T>::update_impl(const string &key, VariablePtr param) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
   Size_t size = param->size();
-  auto &state = this->state_.at(key);
-  int &t = state.t;
+  auto &state = this->states_.at(key);
+  auto &t = state.t;
   const T *g = param->get_grad_pointer<T>(this->ctx_);
-  shared_ptr<Variable> mean_ = state.mean;       // To prevent compile error.
-  shared_ptr<Variable> var_ = state.var;         // To prevent compile error.
-  shared_ptr<Variable> var_hat_ = state.var_hat; // To prevent compile error.
+  shared_ptr<Variable> mean_ = state.pstate["m"]; // To prevent compile error.
+  shared_ptr<Variable> var_ = state.pstate["v"];  // To prevent compile error.
+  shared_ptr<Variable> var_hat_ =
+      state.pstate["v_hat"]; // To prevent compile error.
   T *m = mean_->cast_data_and_get_pointer<T>(this->ctx_);
   T *v = var_->cast_data_and_get_pointer<T>(this->ctx_);
   T *v_hat = var_hat_->cast_data_and_get_pointer<T>(this->ctx_);
   T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
-  t = std::min(t + 1, std::numeric_limits<int>::max());
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
   const T bias_correction = std::sqrt(1 - std::pow(this->beta2_, t)) /
                             (1 - std::pow(this->beta1_, t));
   const T alpha_t =
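
The third buffer v_hat is what distinguishes AMSGrad from Adam: following Reddi et al., the kernel (outside this hunk) maintains the running elementwise maximum

    \hat{v} \leftarrow \max(\hat{v}, v)

and divides by $\sqrt{\hat{v}}$ rather than $\sqrt{v}$, so each coordinate's effective step size is non-increasing across updates. Note the keys here are "m", "v", "v_hat" while adam.cu above uses "mean" and "var"; presumably the names only need to be stable per solver for serialization to round-trip.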
5 changes: 4 additions & 1 deletion src/nbla/cuda/solver/generic/momentum.cu
@@ -34,12 +34,15 @@ template <typename T>
 void MomentumCuda<T>::update_impl(const string &key, VariablePtr param) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
   Size_t size = param->size();
-  VariablePtr r_ = this->state_.at(key);
+  auto &state = this->states_.at(key);
+  VariablePtr r_ = state.pstate["m"];
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *v = r_->cast_data_and_get_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_momentum_update, size, data, grad, v,
                                  this->lr_, this->momentum_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }

 NBLA_DEF_WEIGHT_DECAY(MomentumCuda, weight_decay_cuda);
5 changes: 4 additions & 1 deletion src/nbla/cuda/solver/generic/nesterov.cu
@@ -34,12 +34,15 @@ __global__ void kernel_nesterov_update(const int num, T *data, const T *grad,
 template <typename T>
 void NesterovCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  VariablePtr v_ = this->state_.at(key);
+  auto &state = this->states_.at(key);
+  VariablePtr v_ = state.pstate["m"];
   T *v = v_->cast_data_and_get_pointer<T>(this->ctx_);
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_nesterov_update, size, data, grad, v,
                                  this->lr_, this->momentum_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }

 NBLA_DEF_WEIGHT_DECAY(NesterovCuda, weight_decay_cuda);
7 changes: 5 additions & 2 deletions src/nbla/cuda/solver/generic/rmsprop.cu
@@ -34,13 +34,16 @@ __global__ void kernel_rmsprop_update(const int num, T *data, const T *grad,
 template <typename T>
 void RMSpropCuda<T>::update_impl(const string &key, VariablePtr param) {
   Size_t size = param->size();
-  VariablePtr state = this->state_.at(key);
-  T *e_sqr_grad = state->cast_data_and_get_pointer<T>(this->ctx_);
+  auto &state = this->states_.at(key);
+  VariablePtr v = state.pstate["v"];
+  T *e_sqr_grad = v->cast_data_and_get_pointer<T>(this->ctx_);
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_rmsprop_update, size, data, grad,
                                  e_sqr_grad, this->lr_, this->decay_,
                                  this->eps_);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }

 NBLA_DEF_WEIGHT_DECAY(RMSpropCuda, weight_decay_cuda);
3 changes: 3 additions & 0 deletions src/nbla/cuda/solver/generic/sgd.cu
@@ -33,6 +33,9 @@ void SgdCuda<T>::update_impl(const string &key, VariablePtr param) {
   const T *grad = param->get_grad_pointer<T>(this->ctx_);
   T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_update, size, data, grad, this->lr_);
+  auto &state = this->states_.at(key);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
 }

 NBLA_DEF_WEIGHT_DECAY(SgdCuda, weight_decay_cuda);
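
Plain SGD keeps no auxiliary buffers, so its entire change is the counter bump; states_ is consulted only for t. The likely point is uniformity: with every solver carrying the same SolverState schema, including a t it may never read, serialization needs no per-solver special cases.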
