[Mixed Precision] Fix mixed precision to use Tensor V2
This PR includes fixes so that the mixed-precision code paths use TensorV2.

Resolves:

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <[email protected]>
jijoongmoon committed Nov 11, 2024
1 parent a4a3750 commit eb7cf07
Showing 31 changed files with 139 additions and 148 deletions.
15 changes: 8 additions & 7 deletions api/ccapi/include/model.h
@@ -188,13 +188,14 @@ class Model {
* @details This function accepts vector of properties in the format -
* { std::string property_name, void * property_val, ...}
*/
virtual int train(const std::vector<std::string> &values = {},
std::function<bool(void *)> stop_cb =
[](void *stop_user_data) { return false; },
void *stop_user_data = nullptr,
std::function<void(void *)> epoch_complete_cb =
[](void *epoch_user_data) { return false; },
void *epoch_user_data = nullptr) = 0;
virtual int train(
const std::vector<std::string> &values = {},
std::function<bool(void *)> stop_cb =
[](void *stop_user_data) { return false; },
void *stop_user_data = nullptr,
std::function<void(void *)> epoch_complete_cb =
[](void *epoch_user_data) { return false; },
void *epoch_user_data = nullptr) = 0;

/**
* @brief Run Model train with callback function by user
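The hunk above only reflows the train() overload, but its callback parameters are easy to misread. A minimal usage sketch follows, assuming the usual ccapi model factory and standard property strings; everything outside the train() call itself is illustrative and not part of this diff.

```cpp
#include <model.h>

#include <iostream>
#include <vector>

int main() {
  // Assumed ccapi factory; layer setup, compile() and initialize() are omitted.
  auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET);

  bool user_stop = false;
  int status = model->train(
    {"epochs=2", "batch_size=16"},               // property strings
    [](void *data) {                             // stop_cb: return true to stop
      return *static_cast<bool *>(data);
    },
    &user_stop,                                  // stop_user_data
    [](void *) { std::cout << "epoch done\n"; }, // epoch_complete_cb
    nullptr);                                    // epoch_user_data
  return status;
}
```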
18 changes: 11 additions & 7 deletions meson.build
@@ -70,14 +70,18 @@ warning_c_flags = [

arch = host_machine.cpu_family()

target = target_machine.cpu_family()

if get_option('enable-avx')
extra_defines += '-DUSE_AVX=1'
if get_option('platform') == 'tizen'
add_project_arguments(['-mavx2'], language: ['c','cpp'])
else
add_project_arguments(['-march=native'], language: ['c','cpp'])
endif
message('-march=native added for AVX hardware acceleration.')
if get_option('platform') != 'android'
if target == 'x86_64' or target == 'x86'
extra_defines += '-DUSE_AVX=1'
add_project_arguments(['-march=native'], language: ['c','cpp'])
add_project_arguments(['-mavx2'], language: ['c','cpp'])
message('-march=native added for AVX hardware acceleration.')
endif
message('This arch does not support avx2')
endif
endif

if get_option('enable-fp16')
4 changes: 3 additions & 1 deletion nntrainer/app_context.cpp
@@ -559,6 +559,7 @@ AppContext::registerPluggableFromDirectory(const std::string &base_path) {
struct dirent *entry;

std::vector<int> keys;

while ((entry = readdir(dir)) != NULL) {
if (endswith(entry->d_name, solib_suffix)) {
if (endswith(entry->d_name, layerlib_suffix)) {
@@ -581,7 +582,8 @@ }
}
}

closedir(dir);
if (dir != NULL)
closedir(dir);

return keys;
}
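The added NULL check above keeps closedir() from being called on a failed opendir(). A self-contained sketch of an RAII alternative to that pattern; it is not part of this commit, and the function name and the ".so" filter are illustrative only.

```cpp
#include <dirent.h>

#include <memory>
#include <string>
#include <vector>

std::vector<std::string> listSharedObjects(const std::string &base_path) {
  // closedir() runs automatically on every return path once opendir()
  // succeeded; a null handle is simply never closed.
  std::unique_ptr<DIR, int (*)(DIR *)> dir(opendir(base_path.c_str()),
                                           closedir);
  std::vector<std::string> names;
  if (!dir)
    return names; // opendir failed

  while (struct dirent *entry = readdir(dir.get())) {
    std::string name = entry->d_name;
    if (name.size() > 3 && name.compare(name.size() - 3, 3, ".so") == 0)
      names.push_back(name);
  }
  return names;
}
```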
2 changes: 1 addition & 1 deletion nntrainer/graph/graph_core.cpp
@@ -98,7 +98,7 @@ void GraphCore::topologicalSort() {
if (Sorted.size() != node_list.size())
throw std::runtime_error("Internal error in topologicalSort");
unsigned int idx = 0;
for (auto n : Sorted) {
for (auto &n : Sorted) {
sorted_node_map[n->getName()] = idx;
idx++;
}
13 changes: 5 additions & 8 deletions nntrainer/graph/network_graph.cpp
@@ -341,6 +341,9 @@ void NetworkGraph::applyGradients(
/**
* @note the weights whose gradient are to be clipped by global norm will
* be clipped at once at the end of iteration and applied then.
* For weights where mixed precision is used, gradient updates might be
* delayed until the loss scale is confirmed to be appropriate.
*/
continue;
}
@@ -438,7 +441,7 @@ bool NetworkGraph::backwarding(
*/
float scale = (*iter_)->getRunContext().getLossScale();

NNTR_THROW_IF(scale == 1.0f, std::invalid_argument)
NNTR_THROW_IF(scale - 1.0f < 10e-6, std::invalid_argument)
<< "Loss Scale Factor is 1.0f";

float s = scale > 1.5f ? scale * 0.5f : 1.0f;
@@ -487,18 +490,12 @@ }
}
}
/** apply the gradient with the above global norm */
std::cout << "======================================= update gradient "
<< std::endl;
for (auto w : lazy_weights) {
std::cout << w->getName() << " : ";
lazy_apply_grad_op(*w, iteration);
}
nan_count++;

std::cout << "====================================== update gradient finished"
<< std::endl;
/** @todo : handle as property : growth_interval : default --> 2000 */

if (nan_count > 2000) {
float scale = (*iter_)->getRunContext().getLossScale();
/** @todo growth_factor : default --> 2.0 */
@@ -1647,7 +1644,7 @@ void NetworkGraph::requestOptimizerVariable(
w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN,
w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
Tensor::Initializer::ZEROS));
Initializer::ZEROS));
}
}
}
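Taken together, the hunks in this file implement dynamic loss scaling: back off the scale when a NaN/inf gradient is found, and grow it again after a run of clean iterations. A standalone sketch of that policy, reusing the 0.5 backoff, 2.0 growth factor, and 2000 growth interval visible in the diff; the struct and member names are illustrative, not nntrainer API.

```cpp
struct LossScaler {
  float scale = 65536.0f;              // initial scale, assumed default
  unsigned int good_steps = 0;         // iterations since the last overflow
  unsigned int growth_interval = 2000; // matches the @todo default above

  // Returns true when the step can be applied, false when it must be redone
  // with the reduced scale.
  bool update(bool grads_finite) {
    if (!grads_finite) {
      scale = scale > 1.5f ? scale * 0.5f : 1.0f; // back off, but keep >= 1
      good_steps = 0;
      return false; // caller rescales the derivatives and recomputes
    }
    if (++good_steps > growth_interval) {
      scale *= 2.0f; // try a larger scale again
      good_steps = 0;
    }
    return true;
  }
};
```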
14 changes: 12 additions & 2 deletions nntrainer/graph/network_graph.h
@@ -58,7 +58,12 @@ class NetworkGraph {
/**
* @brief Constructor of NeuralNetwork Graph Class
* @param[in] enable_swap enable memory swap for tensor
* @param[in] mode execution mode (default ExecutionMode::TRAIN)
* @param[in] swap_path memory swap file path when the swap is enabled
* @param[in] tensor_format define tensor format. One of NCHW and NHWC
* (default NCHW)
* @param[in] tensor_type weight type and activation type (default
* FP32-FP32)
*/
NetworkGraph(bool enable_swap, ExecutionMode mode = ExecutionMode::TRAIN,
const std::string &swap_path = "", unsigned int lookahead = 0,
@@ -207,8 +212,12 @@
/**
* @brief backwarding the network graph
* @param[in] iteration current iteration number
* @param[in] forwarding_op operation for the forwarding
* @param[in] backwarding_op operation for the backwarding
* @param[in] apply_grad_clip_op operation for applying the clip gradients
* @param[in] lazy_apply_grad_op operation for applying the lazy gradients
* @retval ret false if the gradient has a NaN value in mixed precision
* training. If so, the loss scale factor needs to be adjusted and the
* derivatives computed again.
*/
bool backwarding(
int iteration,
@@ -496,7 +505,8 @@ class NetworkGraph {
std::unordered_map<std::string, int>
profile_keys; /**< profile keys based on the layer type */
std::vector<Weight *>
lazy_weights; /**< weights with global norm based clipping enabled */
lazy_weights; /**< weights with delayed grad update, e.g., gradient
clipping, loss scaling */
bool is_clip_grad;

unsigned int nan_count;
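The @retval note above defines the mixed-precision contract: a false return means a NaN gradient was found and the derivatives must be recomputed with an adjusted loss scale. A hedged caller-side sketch of that retry loop; backward_pass is a stand-in for backwarding() with its callbacks already bound, and the retry bound is an assumption, not part of this diff.

```cpp
#include <functional>

bool train_one_iteration(const std::function<bool(int)> &backward_pass,
                         int iteration, int max_retries = 3) {
  for (int attempt = 0; attempt < max_retries; ++attempt) {
    if (backward_pass(iteration))
      return true; // gradients were finite and have been applied
    // false: a NaN was detected; backwarding has already lowered the loss
    // scale, so recompute the derivatives for the same iteration.
  }
  return false; // give up after repeated overflows
}
```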
18 changes: 9 additions & 9 deletions nntrainer/layers/bn_layer.cpp
@@ -118,12 +118,12 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
1.0f, bias_decay, "beta", true);

wt_idx[BNParams::mu_b] =
context.requestTensor(dim, "moviing_mean_backup", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(dim, "moviing_mean_backup", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);

wt_idx[BNParams::var_b] = context.requestTensor(
dim, "moviing_variance_backup", Tensor::Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[BNParams::var_b] =
context.requestTensor(dim, "moviing_variance_backup", Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);

/**
* caches the deviation -> input - avg(input)
@@ -137,8 +137,8 @@ }
}

wt_idx[BNParams::deviation] =
context.requestTensor(in_dim_, "deviation", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(in_dim_, "deviation", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
/** caches the inverse standard deviation */
wt_idx[BNParams::invstd] =
context.requestTensor(dim, "invstd", Initializer::NONE, false,
@@ -150,8 +150,8 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
* as the output of this layer need not be stored all the time.
*/
wt_idx[BNParams::t_full] =
context.requestTensor(in_dim_, "tensor_full", Tensor::Initializer::NONE,
false, TensorLifespan::CALC_DERIV_LIFESPAN);
context.requestTensor(in_dim_, "tensor_full", Initializer::NONE, false,
TensorLifespan::CALC_DERIV_LIFESPAN);
/**
* caches variance + epsilon as well.
*/
2 changes: 2 additions & 0 deletions nntrainer/layers/conv2d_layer.cpp
@@ -242,6 +242,8 @@ static void im2col(const Tensor &in, const TensorDim &kdim,
unsigned int base_im_h = 0;
int patch_height_end = eff_k_height + hs;
/// map the patch to a single line looping through channel
// We need to optimize this padding & copy. Maybe use multi-threading or
// SIMD.
for (unsigned int c = 0; c < channel; ++c) {
for (int h = hs; h < patch_height_end; h += dilation[0]) {
if (h < 0 || in_height <= h) {
6 changes: 5 additions & 1 deletion nntrainer/layers/layer_context.h
@@ -50,6 +50,10 @@ class InitLayerContext {
* @param name name
* @param prefix_ prefix
* @param max_norm max norm
* @param tensor_type array including tensor format and weight/activation
* types.
* @param loss_scale loss scale value for mixed precision training
* @param mode execution mode.
*/
InitLayerContext(
const std::vector<TensorDim> &dim,
@@ -220,7 +224,7 @@ class InitLayerContext {
* start from 0 and will always be incremental.
*/
unsigned int requestWeight(const TensorDim &dim, const TensorDim &dim_g,
const Tensor::Initializer init,
const Initializer init,
const WeightRegularizer reg, const float reg_const,
const float decay, const std::string &name,
bool trainable = true, unsigned int out_axis = 3) {
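For reference, a minimal sketch of calling the requestWeight() overload above with the unscoped Initializer enum this PR migrates to. The dimensions, regularizer arguments, and surrounding function are illustrative assumptions, not code from the repository.

```cpp
#include <layer_context.h>

void finalize_example(nntrainer::InitLayerContext &context) {
  // Assumed weight shape; the gradient usually shares the same dimension.
  nntrainer::TensorDim dim(1, 1, 1, 64);

  unsigned int w_idx = context.requestWeight(
    dim, dim, nntrainer::Initializer::ZEROS,     // enum spelling used elsewhere in this PR
    nntrainer::WeightRegularizer::NONE, 0.0f,    // regularizer and its constant
    0.0f, "example_weight", /*trainable=*/true); // decay, name, trainability

  (void)w_idx; // real layers keep this index in their wt_idx table
}
```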
24 changes: 12 additions & 12 deletions nntrainer/layers/lstm.cpp
@@ -512,16 +512,16 @@ void LSTMLayer::finalize(InitLayerContext &context) {
const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit,
activation_tensor_type);

wt_idx[LSTMParams::hidden_state] = context.requestTensor(
hidden_state_dim, "hidden_state", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::hidden_state] =
context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
true, TensorLifespan::ITERATION_LIFESPAN);
// cell_state_dim : [ batch_size, 1, max_timestep, unit ]
const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit,
activation_tensor_type);

wt_idx[LSTMParams::cell_state] = context.requestTensor(
cell_state_dim, "cell_state", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::cell_state] =
context.requestTensor(cell_state_dim, "cell_state", Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);

// ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit,
@@ -594,18 +594,18 @@ void LSTMLayer::finalize(InitLayerContext &context) {
// reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep,
NUM_GATE * unit, activation_tensor_type);
wt_idx[LSTMParams::reverse_ifgo] = context.requestTensor(
reverse_ifgo_dim, "reverse_ifgo", Tensor::Initializer::NONE, true,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::reverse_ifgo] =
context.requestTensor(reverse_ifgo_dim, "reverse_ifgo", Initializer::NONE,
true, TensorLifespan::ITERATION_LIFESPAN);
}

if (dropout_rate > epsilon) {
// dropout_mask_dim = [ batch, 1, time_iteration, unit ]
const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit,
activation_tensor_type);
wt_idx[LSTMParams::dropout_mask] = context.requestTensor(
dropout_mask_dim, "dropout_mask", Tensor::Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
wt_idx[LSTMParams::dropout_mask] =
context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
}

if (context.getActivationDataType() == TensorDim::DataType::FP32) {
8 changes: 4 additions & 4 deletions nntrainer/layers/pooling2d_layer.cpp
@@ -126,15 +126,15 @@ void Pooling2DLayer::finalize(InitLayerContext &context) {
auto helper_dim = in_dim;
helper_dim.setDataType(ml::train::TensorDim::DataType::FP32);
pool_helper_idx =
context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
pool_helper_size.resize(helper_dim.batch() * helper_dim.channel());
} else {
auto helper_dim = out_dim;
helper_dim.setDataType(ml::train::TensorDim::DataType::FP32);
pool_helper_idx =
context.requestTensor(helper_dim, "helper_idx", Tensor::Initializer::NONE,
false, TensorLifespan::ITERATION_LIFESPAN);
context.requestTensor(helper_dim, "helper_idx", Initializer::NONE, false,
TensorLifespan::ITERATION_LIFESPAN);
}
}

1 change: 0 additions & 1 deletion nntrainer/models/neuralnet.cpp
@@ -1160,7 +1160,6 @@ int NeuralNetwork::train_run(
auto epochs = getEpochs();
ml_logd("[NNTrainer] Starts training. Current epoch: %d. Total epochs: %d.",
epoch_idx + 1, getEpochs());
epoch_idx = 0;
for (epoch_idx = epoch_idx + 1; epoch_idx <= epochs; ++epoch_idx) {
if (stop_cb(stop_user_data)) {
--epoch_idx;
3 changes: 1 addition & 2 deletions nntrainer/tensor/blas_interface.cpp
@@ -874,8 +874,7 @@ void scopy(const unsigned int N, const float *X, const int incX, float *Y,
#ifdef BLAS_NUM_THREADS
openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
// cblas_scopy(N, (float*)(X), incX, (float*)(Y), incY);
// replace cblas scopy with raw temporary.
// cblas_scopy(N, X, incX, Y, incY);
for (unsigned int i = 0; i < N; ++i)
Y[i * incY] = X[i * incX];
#else
5 changes: 5 additions & 0 deletions nntrainer/tensor/char_tensor.h
@@ -231,6 +231,11 @@ class CharTensor : public TensorBase {
* @return std::string of tensor data type (QINT8)
*/
std::string getStringDataType() const override { return "QINT8"; }

/**
* @copydoc Tensor::isValid()
*/
bool isValid() const override { return true; }; // NYI
};

} // namespace nntrainer
6 changes: 3 additions & 3 deletions nntrainer/tensor/float_tensor.cpp
@@ -150,7 +150,7 @@ void FloatTensor::setZero() {
// sscal(size(), 0, getData<float>(), 1);
/// @note we cannot use sscal, when we set zero. if the data is inf or
/// NaN, then the inf or NaN still remain.
memset(getData<float>(), 0, sizeof(float) * size());
memset((float *)getData(), 0, sizeof(float) * size());
} else {
/// @todo implement apply_i
// apply_i<float>([](float val) -> float { return 0; });
@@ -1210,8 +1210,8 @@ void FloatTensor::apply_broadcast(
return apply_broadcast_util(m, v_func, output, this->computeBroadcastInfo(m));
}

bool Tensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP32, getData<float>());
bool FloatTensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP32, (float *)getData());
}

} // namespace nntrainer
2 changes: 1 addition & 1 deletion nntrainer/tensor/float_tensor.h
@@ -511,7 +511,7 @@ class FloatTensor : public TensorBase {
/**
* @copydoc Tensor::isValid()
*/
bool Tensor::isValid() const;
bool isValid() const override;
};

} // namespace nntrainer
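The one-line fix above replaces a qualified declaration, which is ill-formed inside another class body, with a plain override; the same fix appears in half_tensor.h below. A small generic illustration of the pattern, independent of nntrainer:

```cpp
struct Base {
  virtual ~Base() = default;
  virtual bool isValid() const { return false; }
};

struct Derived : Base {
  // bool Base::isValid() const;                  // ill-formed inside a class body
  bool isValid() const override { return true; }  // compiler-checked override
};
```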
6 changes: 3 additions & 3 deletions nntrainer/tensor/half_tensor.cpp
@@ -149,7 +149,7 @@ void HalfTensor::setZero() {
// sscal(size(), 0, (_FP16 *)getData(), 1);
/// @note we cannot use sscal, when we set zero. if the data is inf or
/// NaN, then the inf or NaN still remain.
memset(getData<_FP16>(), 0, sizeof(_FP16) * size());
memset((_FP16 *)getData(), 0, sizeof(_FP16) * size());
} else {
/// @todo implement apply_i
// apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; });
@@ -1176,8 +1176,8 @@ void HalfTensor::apply_broadcast_util(
}
}

bool Tensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP16, getData<_FP16>());
bool HalfTensor::isValid() const {
return is_valid(dim.getDataLen(), Tdatatype::FP16, (_FP16 *)getData());
}

} // namespace nntrainer
2 changes: 1 addition & 1 deletion nntrainer/tensor/half_tensor.h
@@ -502,7 +502,7 @@ class HalfTensor : public TensorBase {
/**
* @copydoc Tensor::isValid()
*/
bool Tensor::isValid() const;
bool isValid() const override;
};

} // namespace nntrainer