Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
zetwhite committed Aug 7, 2024
1 parent 51e917c commit 8f78928
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
10 changes: 5 additions & 5 deletions runtime/onert/backend/train/ExtraTensorGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx, ExtraTens
{
// register tensor
ExtraTensorIndex tensor_idx(op_idx, i);
_tensor_builder->registerExtraTensorInfo(tensor_idx, reqs[i].info);
_tensor_builder->registerExtraTensorInfo(tensor_idx, reqs[i].info());

std::stringstream op_info;
op_info << op_idx << "_" << operations.at(op_idx).name();
Expand All @@ -60,7 +60,7 @@ void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx, ExtraTens

// return registered tensor
auto generated_tensor = _tensor_reg->getExtraTensor(tensor_idx);
*reqs[i].address = generated_tensor;
reqs[i].update_address(generated_tensor);
}
return;
}
Expand All @@ -74,7 +74,7 @@ void ExtraTensorGenerator::plan()
auto &reqs = _idx_to_requests[op_index];
for (auto i = 0u; i < reqs.size(); ++i)
{
auto &lt = reqs[i].lifetime;
const auto &lt = reqs[i].lifetime();
if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD)
_tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i));
}
Expand All @@ -88,14 +88,14 @@ void ExtraTensorGenerator::plan()

for (auto i = 0u; i < reqs.size(); ++i)
{
auto &lt = reqs[i].lifetime;
const auto &lt = reqs[i].lifetime();
if (lt == ExtraTensorLifeTime::BACKWARD)
_tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i));
}

for (auto i = 0u; i < reqs.size(); ++i)
{
auto &lt = reqs[i].lifetime;
const auto &lt = reqs[i].lifetime();
if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD || lt == ExtraTensorLifeTime::BACKWARD)
_tensor_builder->notifyLastUse(ExtraTensorIndex(op_index, i));
}
Expand Down
28 changes: 22 additions & 6 deletions runtime/onert/core/include/backend/train/ExtraTensorRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ class ExtraTensorRequest

public:
ExtraTensorRequest(ir::OperandInfo info, ExtraTensorLifeTime lt,
backend::train::ExtraTensor **addr)
: info(info), lifetime(lt), address(addr)
ExtraTensor **addr)
: _info(info), _lifetime(lt), _address(addr)
{
}

static ExtraTensorRequest createLike(const IPortableTensor *origin,
backend::train::ExtraTensor** addr)
ExtraTensor** addr)
{
assert(origin != nullptr);
assert(addr != nullptr);
Expand All @@ -52,9 +52,25 @@ class ExtraTensorRequest
}

public:
const ir::OperandInfo info;
const ExtraTensorLifeTime lifetime;
backend::train::ExtraTensor ** const address;
const ir::OperandInfo& info() const
{
return _info;
}

ExtraTensorLifeTime lifetime() const
{
return _lifetime;
}

void update_address(ExtraTensor* tensor)
{
*_address = tensor;
}

private:
ir::OperandInfo _info;
ExtraTensorLifeTime _lifetime;
backend::train::ExtraTensor ** const _address;
};

using ExtraTensorRequests = std::vector<ExtraTensorRequest>;
Expand Down

0 comments on commit 8f78928

Please sign in to comment.