From 9a6a6831cc893b655e3ed30f240e813c07cba0f6 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Fri, 6 Aug 2021 23:51:42 +0000 Subject: [PATCH] [CodeGen] Check odla status after odla API calls --- lib/target/generic_cpp/generic_cxx_codegen.cc | 69 +++++++++++-------- tests/compile/test_cxx_gen.cc | 2 +- tests/compile/test_cxx_gen_gpu.cc | 2 +- 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/lib/target/generic_cpp/generic_cxx_codegen.cc b/lib/target/generic_cpp/generic_cxx_codegen.cc index 53c69bf0b..ced389392 100644 --- a/lib/target/generic_cpp/generic_cxx_codegen.cc +++ b/lib/target/generic_cpp/generic_cxx_codegen.cc @@ -208,7 +208,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func, bool with_type, bool public_function) { const static std::string inference_func_decl = - "void model_run(int num_inputs, const void* inputs[]," + "int model_run(int num_inputs, const void* inputs[]," "int num_outputs, void* outputs[], int batch_size)"; if (opts_.emit_inference_func_sig && func.IsEntryFunction() && public_function) { @@ -221,7 +221,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func, ss << "static "; } if (with_func_name) { - ss << "void " << NormalizeVariableName(func.GetName()); + ss << "int " << NormalizeVariableName(func.GetName()); } ss << "("; if (is_sub) { @@ -612,10 +612,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { oss << "extern \"C\" {\n"; } oss << " " << func_decl << ";\n"; - oss << "void " << init_func_name << "();\n"; - oss << "void " << fini_func_name << "();\n"; - oss << (is_compile_mode ? "odla_computation " : "static void ") - << helper_func_name << "();\n"; + oss << "int " << init_func_name << "();\n"; + oss << "int " << fini_func_name << "();\n"; + if (is_compile_mode) { + oss << "int " << helper_func_name << "(odla_computation comp);\n"; + } else { + oss << "static void " << helper_func_name << "()\n;"; + } if (opts_.dialect == Dialect::CXX_11) { oss << "};\n"; } @@ -626,9 +629,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { if (emit_builder_func) { if (is_compile_mode) { os_ << "static odla_computation Comp;\n"; - os_ << "odla_computation " << helper_func_name << "() {\n"; - os_ << " odla_computation comp;\n"; - os_ << " odla_CreateComputation(&comp);\n"; + os_ << "int " << helper_func_name << "(odla_computation comp) {\n"; EmitComputationItems(&os_, opts_); } else { os_ << "static void " << helper_func_name @@ -641,11 +642,6 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { os_ << " odla_SetCurrentDevice(device);"; } - if (is_compile_mode) { - os_ << " static odla_computation comp;\n"; - os_ << " if (comp == " << EmitNull() << ") {\n"; - os_ << " odla_CreateComputation(&comp);\n"; - } EmitComputationItems(&os_, opts_); } @@ -675,7 +671,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { RunOnBasicBlock(*bb); } if (is_compile_mode) { - os_ << " return comp;\n"; + os_ << " return ODLA_SUCCESS;\n"; } os_ << "}\n"; // End of computation build function. @@ -684,28 +680,41 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { dynamic_check_os_ << GenerateTestFunc(function, func_decl, *return_inst); } + const std::string& status_check{ + "if (status != ODLA_SUCCESS) { return status;}"}; + if (emit_builder_func) { // Emit function for launching computation. if (opts_.exec_mode == ExecMode::Compile) { if (function.IsEntryFunction()) { - os_ << "void " << fini_func_name << "(){\n"; - os_ << " odla_DestroyComputation(Comp);\n"; + os_ << "int " << fini_func_name << "(){\n"; + os_ << " if (Comp !=" << EmitNull() << ") {"; + os_ << " return odla_DestroyComputation(Comp);}\n"; + os_ << " return ODLA_SUCCESS;\n"; os_ << "}\n"; - os_ << "void " << init_func_name << "(){\n"; + os_ << "int " << init_func_name << "(){\n"; } else { os_ << GetFunctionDecl(function, *return_inst, true, true, true) << " {\n"; } - os_ << " if (Comp == " << EmitNull() << ") { Comp = " << helper_func_name - << "(); }\n"; + os_ << " odla_status status = ODLA_SUCCESS;\n"; + os_ << " if (Comp == " << EmitNull() << ") { \n"; + os_ << " status = odla_CreateComputation(&Comp);\n"; + os_ << " " << status_check << "\n"; + os_ << " status = (odla_status)" << helper_func_name << "(Comp);\n"; + os_ << " }\n"; + os_ << " return status;\n"; os_ << "}\n"; } + if (function.IsEntryFunction()) { os_ << GetFunctionDecl(function, *return_inst, true, true, true) << " {\n"; if (opts_.exec_mode == ExecMode::Compile) { - os_ << " " << init_func_name << "();\n"; + os_ << " odla_status status = ODLA_SUCCESS;\n"; + os_ << " status = (odla_status)" << init_func_name << "();\n"; + os_ << " " << status_check << "\n"; } } if (opts_.exec_mode == ExecMode::Interpret) { @@ -740,11 +749,14 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { if (opts_.exec_mode == ExecMode::Compile) { os_ << " static odla_context Ctx;\n"; - os_ << " if (Ctx == " << EmitNull() - << ") { odla_CreateContext(&Ctx); };\n"; + os_ << " if (Ctx == " << EmitNull() << ") {"; + os_ << " status = odla_CreateContext(&Ctx);\n"; + os_ << " " << status_check << "\n"; + os_ << " }\n"; if (opts_.emit_dynamic_batch) { - os_ << "odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, " + os_ << " status = odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, " "(odla_item_value) &batch_size);\n"; + os_ << " " << status_check << "\n"; } } int index = 0; @@ -755,11 +767,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { ? (is_sub ? "inputs.values[" : "inputs[") + std::to_string(index++) + "]" : cv.name; - os_ << (is_sub ? " odla_BindValueToArgumentById(" + os_ << " status = " + << (is_sub ? " odla_BindValueToArgumentById(" : " odla_BindToArgumentById(") << Join("(const odla_value_id)\"" + arg->GetName() + "\"", arg_name, "Ctx") << ");\n"; + os_ << " " << status_check << "\n"; } index = 0; // Pre-launch binding. @@ -769,12 +783,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { ? (is_sub ? "outputs.values[" : "outputs[") + std::to_string(index++) + "]" : "out_" + cv.name; - os_ << " odla_Bind" << (is_sub ? "Value" : "") << "ToOutputById(" + os_ << " status = odla_Bind" << (is_sub ? "Value" : "") << "ToOutputById(" << Join("(const odla_value_id)\"" + cv.name + "\"", arg_name, "Ctx") << ");\n"; + os_ << " " << status_check << "\n"; } if (opts_.exec_mode == ExecMode::Compile) { - os_ << " odla_ExecuteComputation(Comp, Ctx, " + os_ << " return odla_ExecuteComputation(Comp, Ctx, " "ODLA_COMPUTE_INFERENCE, " << EmitNull() << ");\n"; } diff --git a/tests/compile/test_cxx_gen.cc b/tests/compile/test_cxx_gen.cc index 75d1b0444..73fb08a65 100644 --- a/tests/compile/test_cxx_gen.cc +++ b/tests/compile/test_cxx_gen.cc @@ -37,7 +37,7 @@ // GEN: static odla_computation Comp; -// GEN: void func(const float input[3], float out_add1[3]) { +// GEN: int func(const float input[3], float out_add1[3]) { // GEN: func_init(); // GEN: odla_BindToArgumentById((const odla_value_id)"input", input, Ctx); // GEN: odla_BindToOutputById((const odla_value_id)"add1", out_add1, Ctx); diff --git a/tests/compile/test_cxx_gen_gpu.cc b/tests/compile/test_cxx_gen_gpu.cc index 428a0faef..47bb66291 100644 --- a/tests/compile/test_cxx_gen_gpu.cc +++ b/tests/compile/test_cxx_gen_gpu.cc @@ -34,7 +34,7 @@ // GEN: static odla_computation Comp; -// GEN: void func(const float input[3], float out_add1[3]) { +// GEN: int func(const float input[3], float out_add1[3]) { // GEN: func_init(); // GEN: odla_BindToArgumentById((const odla_value_id)"input", input, Ctx); // GEN: odla_BindToOutputById((const odla_value_id)"add1", out_add1, Ctx);