Skip to content

Commit

Permalink
[CodeGen] Check odla status after odla API calls
Browse files Browse the repository at this point in the history
  • Loading branch information
Weiming Zhao authored and weimingzha0 committed Aug 10, 2021
1 parent 78d22f6 commit 9a6a683
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 29 deletions.
69 changes: 42 additions & 27 deletions lib/target/generic_cpp/generic_cxx_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
bool with_type,
bool public_function) {
const static std::string inference_func_decl =
"void model_run(int num_inputs, const void* inputs[],"
"int model_run(int num_inputs, const void* inputs[],"
"int num_outputs, void* outputs[], int batch_size)";
if (opts_.emit_inference_func_sig && func.IsEntryFunction() &&
public_function) {
Expand All @@ -221,7 +221,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
ss << "static ";
}
if (with_func_name) {
ss << "void " << NormalizeVariableName(func.GetName());
ss << "int " << NormalizeVariableName(func.GetName());
}
ss << "(";
if (is_sub) {
Expand Down Expand Up @@ -612,10 +612,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
oss << "extern \"C\" {\n";
}
oss << " " << func_decl << ";\n";
oss << "void " << init_func_name << "();\n";
oss << "void " << fini_func_name << "();\n";
oss << (is_compile_mode ? "odla_computation " : "static void ")
<< helper_func_name << "();\n";
oss << "int " << init_func_name << "();\n";
oss << "int " << fini_func_name << "();\n";
if (is_compile_mode) {
oss << "int " << helper_func_name << "(odla_computation comp);\n";
} else {
oss << "static void " << helper_func_name << "()\n;";
}
if (opts_.dialect == Dialect::CXX_11) {
oss << "};\n";
}
Expand All @@ -626,9 +629,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
if (emit_builder_func) {
if (is_compile_mode) {
os_ << "static odla_computation Comp;\n";
os_ << "odla_computation " << helper_func_name << "() {\n";
os_ << " odla_computation comp;\n";
os_ << " odla_CreateComputation(&comp);\n";
os_ << "int " << helper_func_name << "(odla_computation comp) {\n";
EmitComputationItems(&os_, opts_);
} else {
os_ << "static void " << helper_func_name
Expand All @@ -641,11 +642,6 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
os_ << " odla_SetCurrentDevice(device);";
}

if (is_compile_mode) {
os_ << " static odla_computation comp;\n";
os_ << " if (comp == " << EmitNull() << ") {\n";
os_ << " odla_CreateComputation(&comp);\n";
}
EmitComputationItems(&os_, opts_);
}

Expand Down Expand Up @@ -675,7 +671,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
RunOnBasicBlock(*bb);
}
if (is_compile_mode) {
os_ << " return comp;\n";
os_ << " return ODLA_SUCCESS;\n";
}

os_ << "}\n"; // End of computation build function.
Expand All @@ -684,28 +680,41 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
dynamic_check_os_ << GenerateTestFunc(function, func_decl, *return_inst);
}

const std::string& status_check{
"if (status != ODLA_SUCCESS) { return status;}"};

if (emit_builder_func) {
// Emit function for launching computation.
if (opts_.exec_mode == ExecMode::Compile) {
if (function.IsEntryFunction()) {
os_ << "void " << fini_func_name << "(){\n";
os_ << " odla_DestroyComputation(Comp);\n";
os_ << "int " << fini_func_name << "(){\n";
os_ << " if (Comp !=" << EmitNull() << ") {";
os_ << " return odla_DestroyComputation(Comp);}\n";
os_ << " return ODLA_SUCCESS;\n";
os_ << "}\n";

os_ << "void " << init_func_name << "(){\n";
os_ << "int " << init_func_name << "(){\n";
} else {
os_ << GetFunctionDecl(function, *return_inst, true, true, true)
<< " {\n";
}
os_ << " if (Comp == " << EmitNull() << ") { Comp = " << helper_func_name
<< "(); }\n";
os_ << " odla_status status = ODLA_SUCCESS;\n";
os_ << " if (Comp == " << EmitNull() << ") { \n";
os_ << " status = odla_CreateComputation(&Comp);\n";
os_ << " " << status_check << "\n";
os_ << " status = (odla_status)" << helper_func_name << "(Comp);\n";
os_ << " }\n";
os_ << " return status;\n";
os_ << "}\n";
}

if (function.IsEntryFunction()) {
os_ << GetFunctionDecl(function, *return_inst, true, true, true)
<< " {\n";
if (opts_.exec_mode == ExecMode::Compile) {
os_ << " " << init_func_name << "();\n";
os_ << " odla_status status = ODLA_SUCCESS;\n";
os_ << " status = (odla_status)" << init_func_name << "();\n";
os_ << " " << status_check << "\n";
}
}
if (opts_.exec_mode == ExecMode::Interpret) {
Expand Down Expand Up @@ -740,11 +749,14 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {

if (opts_.exec_mode == ExecMode::Compile) {
os_ << " static odla_context Ctx;\n";
os_ << " if (Ctx == " << EmitNull()
<< ") { odla_CreateContext(&Ctx); };\n";
os_ << " if (Ctx == " << EmitNull() << ") {";
os_ << " status = odla_CreateContext(&Ctx);\n";
os_ << " " << status_check << "\n";
os_ << " }\n";
if (opts_.emit_dynamic_batch) {
os_ << "odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, "
os_ << " status = odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, "
"(odla_item_value) &batch_size);\n";
os_ << " " << status_check << "\n";
}
}
int index = 0;
Expand All @@ -755,11 +767,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
? (is_sub ? "inputs.values[" : "inputs[") +
std::to_string(index++) + "]"
: cv.name;
os_ << (is_sub ? " odla_BindValueToArgumentById("
os_ << " status = "
<< (is_sub ? " odla_BindValueToArgumentById("
: " odla_BindToArgumentById(")
<< Join("(const odla_value_id)\"" + arg->GetName() + "\"", arg_name,
"Ctx")
<< ");\n";
os_ << " " << status_check << "\n";
}
index = 0;
// Pre-launch binding.
Expand All @@ -769,12 +783,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
? (is_sub ? "outputs.values[" : "outputs[") +
std::to_string(index++) + "]"
: "out_" + cv.name;
os_ << " odla_Bind" << (is_sub ? "Value" : "") << "ToOutputById("
os_ << " status = odla_Bind" << (is_sub ? "Value" : "") << "ToOutputById("
<< Join("(const odla_value_id)\"" + cv.name + "\"", arg_name, "Ctx")
<< ");\n";
os_ << " " << status_check << "\n";
}
if (opts_.exec_mode == ExecMode::Compile) {
os_ << " odla_ExecuteComputation(Comp, Ctx, "
os_ << " return odla_ExecuteComputation(Comp, Ctx, "
"ODLA_COMPUTE_INFERENCE, "
<< EmitNull() << ");\n";
}
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_cxx_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

// GEN: static odla_computation Comp;

// GEN: void func(const float input[3], float out_add1[3]) {
// GEN: int func(const float input[3], float out_add1[3]) {
// GEN: func_init();
// GEN: odla_BindToArgumentById((const odla_value_id)"input", input, Ctx);
// GEN: odla_BindToOutputById((const odla_value_id)"add1", out_add1, Ctx);
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_cxx_gen_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

// GEN: static odla_computation Comp;

// GEN: void func(const float input[3], float out_add1[3]) {
// GEN: int func(const float input[3], float out_add1[3]) {
// GEN: func_init();
// GEN: odla_BindToArgumentById((const odla_value_id)"input", input, Ctx);
// GEN: odla_BindToOutputById((const odla_value_id)"add1", out_add1, Ctx);
Expand Down

0 comments on commit 9a6a683

Please sign in to comment.