Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
zetwhite committed Jul 29, 2024
1 parent f8dcd0e commit dde2bbf
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
4 changes: 4 additions & 0 deletions runtime/onert/backend/train/ExtraTensorGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ void ExtraTensorGenerator::generate(ir::OperationIndex op_idx, const ExtraTensor
auto generated_tensor = _tensor_reg->getExtraTensor(tensor_idx);
*reqs[i].address = generated_tensor;
}

// To avoid unused error
// temporal
_tgraph.getInputs();
return;
}

Expand Down
12 changes: 6 additions & 6 deletions runtime/onert/backend/train/ops/DepthwiseConvolutionLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void DepthwiseConvolutionLayer::configureBackward(IPortableTensor *back_prop_inp
*/
}

ExtraTensorRequests requestExtraTensors()
ExtraTensorRequests DepthwiseConvolutionLayer::requestExtraTensors()
{
ExtraTensorRequests reqs;

Expand All @@ -120,7 +120,7 @@ ExtraTensorRequests requestExtraTensors()
const auto batch = incoming_shape.Dims(0);
const auto depth = incoming_shape.Dims(3);

// const auto filter_shape = getShape(_kernel);
const auto filter_shape = getShape(_kernel);
const int filter_rows = filter_shape.Dims(1);
const int filter_cols = filter_shape.Dims(2);
const int filter_spatial_size = filter_rows * filter_cols;
Expand All @@ -135,23 +135,23 @@ ExtraTensorRequests requestExtraTensors()
{
auto type_info = _kernel->get_info().typeInfo();
ir::Shape shape({batch, filter_spatial_size, padded_filter_inner_dim_size});
auto info = ir::OperandInfo::createStaticInfo(shape, type_info)
reqs.emplace_back(info, ExtraTensorLifeTime::BACKWARD, &_padded_filter);
auto info = ir::OperandInfo::createStaticInfo(shape, type_info);
reqs.emplace_back(info, ExtraTensorLifeTime::BACKWARD, &_padded_filter);
}

// _filter_buffers
{
auto type_info = _kernel->get_info().typeInfo();
ir::Shape shape({thread_count, filter_spatial_size, padded_filter_inner_dim_size});
auto info = ir::OperandIndex::createStaticInfo(shape, type_info);
auto info = ir::OperandInfo::createStaticInfo(shape, type_info);
reqs.emplace_back(info, ExtraTensorLifeTime::BACKWARD, &_filter_buffers);
}

// _filter_dim_buffers
{
auto type = _back_prop_input->get_info().typeInfo();
ir::Shape shape({thread_count, padded_filter_inner_dim_size});
auto info = ir::OperandIndex::createStaticInfo(shape, type);
auto info = ir::OperandInfo::createStaticInfo(shape, type);
reqs.emplace_back(info, ExtraTensorLifeTime::BACKWARD, &_filter_dim_buffers);
}

Expand Down
4 changes: 3 additions & 1 deletion runtime/onert/backend/train/ops/FullyConnectedLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void FullyConnectedLayer::configureBackward(
}
}

ExtraTensorRequests requestExtraTensors()
ExtraTensorRequests FullyConnectedLayer::requestExtraTensors()
{
ExtraTensorRequests reqs;

Expand All @@ -112,6 +112,8 @@ ExtraTensorRequests requestExtraTensors()
reqs.push_back(
ExtraTensorRequest::createRequestLike(_back_prop_output, &_act_back_prop_output));
}

return reqs;
}

void FullyConnectedLayer::forward(bool) { cpu::ops::FullyConnectedLayer::run(); }
Expand Down
5 changes: 5 additions & 0 deletions runtime/onert/core/include/backend/train/ExtraTensorRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ class ExtraTensorRequest
return ExtraTensorRequest(o.info, o.lifetime, o.address);
}

ExtraTensorRequest(const ExtraTensorRequest &o)
: info(o.info), lifetime(o.lifetime), address(o.address)
{
}

public:
ir::OperandInfo info;
ExtraTensorLifeTime lifetime;
Expand Down

0 comments on commit dde2bbf

Please sign in to comment.