Skip to content

Commit

Permalink
[tflchef] Extract cook_operations
Browse files Browse the repository at this point in the history
This will extract cook_operations method.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark committed Aug 13, 2024
1 parent 4ff27c7 commit 8aae7cc
Showing 1 changed file with 89 additions and 6 deletions.
95 changes: 89 additions & 6 deletions compiler/tflchef/core/src/ModelChef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ class ModelChef
private:
template <typename T> void cook_operands(const T &graph);

template <typename T>
void cook_operations(const T &graph, std::map<std::string, int32_t> &symbol_table);

template <typename T>
void cook_graph(const T &graph, std::map<std::string, int32_t> &symbol_table);

Expand All @@ -212,6 +215,8 @@ class ModelChef
// per graph that needs clear afer graph is processed
// Operand-related
std::vector<flatbuffers::Offset<::tflite::Tensor>> _tensor_vec;
// Operation-related
std::vector<flatbuffers::Offset<::tflite::Operator>> _operator_vec;

std::string _graph_name;
};
Expand Down Expand Up @@ -596,16 +601,89 @@ template <typename T> void ModelChef::cook_operands(const T &graph)
}
}

template <typename T>
void ModelChef::cook_operations(const T &graph, std::map<std::string, int32_t> &symbol_table)
{
auto lookup = [&](const std::string &name) {
if (symbol_table.find(name) != symbol_table.end())
return symbol_table.at(name);
else if (name == "")
return -1; // -1 in TFLite means that optional input tensor is empty.
else
{
std::string msg = "tflchef : input not found in " + _graph_name + " graph";
throw std::runtime_error(msg.c_str());
}
};

// Create Operator
for (const auto &operation : graph.operation())
{
assert(operation.has_type());

std::string op_type = operation.type();
if (not operation.custom_code().empty())
op_type = operation.custom_code();

auto op_chef = op_chef_registry().lookup(op_type).create(&operation);

// Create 'inputs'
std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
auto inputs = _flatbuffer_builder->CreateVector(input_vec);

// Create 'outputs'
std::vector<int32_t> output_vec = as_dataset(operation.output()).map(lookup).vectorize();
auto outputs = _flatbuffer_builder->CreateVector(output_vec);

// Create Option
auto options = op_chef->value(*_flatbuffer_builder);

// Create Custom option
auto circle_custom_options = op_chef->custom_value(*_flatbuffer_builder);

// Create Operator
tflite::OperatorBuilder op_builder{*_flatbuffer_builder};

// Note that opcode_index is an index into the operator_codes vector.
// operator_codes consists of buildtin_code and custom_code, which is inserted sequentially.
uint32_t opcode_index = 0;
auto op_it = _builtin_code_map.find(op_chef->code());
// builtin operator
if (op_it != _builtin_code_map.end())
{
opcode_index = std::distance(_builtin_code_map.begin(), op_it);
}
// custom operator
else
{
assert(not operation.custom_code().empty());
auto custom_code = operation.custom_code();
auto op_it = std::find(_custom_code_vec.begin(), _custom_code_vec.end(), custom_code);
assert(op_it != _custom_code_vec.end());
opcode_index = _builtin_code_map.size();
opcode_index += std::distance(_custom_code_vec.begin(), op_it);
}

op_builder.add_opcode_index(opcode_index);
op_builder.add_inputs(inputs);
op_builder.add_outputs(outputs);
op_builder.add_builtin_options_type(op_chef->type());
op_builder.add_builtin_options(options);
op_builder.add_custom_options(circle_custom_options);
op_builder.add_custom_options_format(tflite::CustomOptionsFormat_FLEXBUFFERS);
// Append Operator
_operator_vec.emplace_back(op_builder.Finish());
}
}

template <typename T>
void ModelChef::cook_graph(const T &graph, std::map<std::string, int32_t> &symbol_table)
{
LOGGER(l);

assert(symbol_table.empty()); // FIX_CALLER_UNLESS
assert(_tensor_vec.empty()); // FIX_CALLER_UNLESS

// Operation-related
std::vector<flatbuffers::Offset<::tflite::Operator>> operator_vec;
assert(symbol_table.empty()); // FIX_CALLER_UNLESS
assert(_tensor_vec.empty()); // FIX_CALLER_UNLESS
assert(_operator_vec.empty()); // FIX_CALLER_UNLESS

// default name for graph
std::string graph_name = _graph_name;
Expand Down Expand Up @@ -984,6 +1062,7 @@ void ModelChef::cook_graph(const T &graph, std::map<std::string, int32_t> &symbo
symbol_table[tensor_name] = tensor_index;
}

#if 0
// Create Operator
for (const auto &operation : graph.operation())
{
Expand Down Expand Up @@ -1042,6 +1121,9 @@ void ModelChef::cook_graph(const T &graph, std::map<std::string, int32_t> &symbo
// Append Operator
operator_vec.emplace_back(op_builder.Finish());
}
#endif

cook_operations(graph, symbol_table);

// Create network input/output vector
std::vector<int32_t> input_vec = as_dataset(graph.input()).map(lookup).vectorize();
Expand All @@ -1051,7 +1133,7 @@ void ModelChef::cook_graph(const T &graph, std::map<std::string, int32_t> &symbo
auto tensors = _flatbuffer_builder->CreateVector(_tensor_vec);
auto inputs = _flatbuffer_builder->CreateVector(input_vec);
auto outputs = _flatbuffer_builder->CreateVector(output_vec);
auto operators = _flatbuffer_builder->CreateVector(operator_vec);
auto operators = _flatbuffer_builder->CreateVector(_operator_vec);
auto name = _flatbuffer_builder->CreateString(graph_name);

tflite::SubGraphBuilder subgraph_builder{*_flatbuffer_builder};
Expand Down Expand Up @@ -1141,6 +1223,7 @@ void ModelChef::cook(const ::tflchef::ModelRecipe &model_recipe)

symbol_table.clear();
_tensor_vec.clear();
_operator_vec.clear();
cook_graph<::tflchef::Graph>(graph, symbol_table);
_symbol_tables.push_back(symbol_table);
}
Expand Down

0 comments on commit 8aae7cc

Please sign in to comment.