From 1577a50d1b5065b0e44f87171bf20ee0924d992e Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Mon, 29 Nov 2021 18:17:39 +0000 Subject: [PATCH] [Type Casting Pass] Add support for fp64->fp32 conversion --- lib/transforms/typecast.cc | 36 +++++++++++++++++++ .../test_einsum_batch_diagonal_popart.cc | 6 ++-- .../test_einsum_batch_diagonal_tensorrt.cc | 8 ++--- .../test_einsum_batch_matmul_tensorrt.cc | 12 +++---- .../test_einsum_inner_prod_tensorrt.cc | 12 +++---- .../test_tensorrt/test_einsum_sum_tensorrt.cc | 10 +++--- .../test_einsum_transpose_tensorrt.cc | 10 +++--- 7 files changed, 65 insertions(+), 29 deletions(-) diff --git a/lib/transforms/typecast.cc b/lib/transforms/typecast.cc index e837ef681..7153b77c4 100644 --- a/lib/transforms/typecast.cc +++ b/lib/transforms/typecast.cc @@ -32,6 +32,27 @@ bool TypeCast::RunOnFunction(Function* func) { halo::Type new_ty{DataType::INT32, ty.GetDimSizes()}; arg->SetType(new_ty); changed |= true; + } else if (ty.GetDataType() == DataType::FLOAT64) { + halo::Type new_ty{DataType::FLOAT32, ty.GetDimSizes()}; + arg->SetType(new_ty); + changed |= true; + } + } + { + Module* m = func->GetParent(); + Function::ConstantList& constants = m->Constants(); + for (auto it = constants.begin(), ie = constants.end(); it != ie; ++it) { + const auto& orig_type = (*it)->GetResultType(0); + if (orig_type.GetDataType() == DataType::FLOAT64) { + std::vector<float> ret; + ret.reserve(orig_type.GetTotalNumOfElements()); + for (unsigned int i = 0; i < orig_type.GetTotalNumOfElements(); ++i) { + ret.push_back(static_cast<float>((*it)->GetData<double>(i))); + } + (*it)->SetData(halo::Type{DataType::FLOAT32, orig_type.GetDimSizes()}, + ret.data()); + changed = true; + } } } // Replace constants. 
@@ -50,6 +71,17 @@ bool TypeCast::RunOnFunction(Function* func) { halo::Type{DataType::INT32, orig_type.GetDimSizes()}, ret.data()); (*it)->ReplaceAllUsesWith(0, *c_ret); changed = true; + } else if (orig_type.GetDataType() == DataType::FLOAT64) { + std::vector<float> ret; + ret.reserve(orig_type.GetTotalNumOfElements()); + for (unsigned int i = 0; i < orig_type.GetTotalNumOfElements(); ++i) { + ret.push_back(static_cast<float>((*it)->GetData<double>(i))); + } + Constant* c_ret = cb.CreateConstant( + (*it)->GetName() + "_castdown", + halo::Type{DataType::FLOAT32, orig_type.GetDimSizes()}, ret.data()); + (*it)->ReplaceAllUsesWith(0, *c_ret); + changed = true; } } @@ -60,6 +92,10 @@ bool TypeCast::RunOnFunction(Function* func) { if (orig_type.IsValid() && orig_type.GetDataType() == DataType::INT64) { inst->GetResultsTypes()[i] = halo::Type{DataType::INT32, orig_type.GetDimSizes()}; + } else if (orig_type.IsValid() && + orig_type.GetDataType() == DataType::FLOAT64) { + inst->GetResultsTypes()[i] = + halo::Type{DataType::FLOAT32, orig_type.GetDimSizes()}; } } } diff --git a/tests/unittests/lit_cases/test_popart/test_einsum_batch_diagonal_popart.cc b/tests/unittests/lit_cases/test_popart/test_einsum_batch_diagonal_popart.cc index 742013532..f5a7a798e 100644 --- a/tests/unittests/lit_cases/test_popart/test_einsum_batch_diagonal_popart.cc +++ b/tests/unittests/lit_cases/test_popart/test_einsum_batch_diagonal_popart.cc @@ -17,9 +17,9 @@ // clang-format off // Testing CXX Code Gen using ODLA API on popart -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb -// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags 
%data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc +// RUN: %halo_compiler -disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb +// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb +// RUN: %halo_compiler -disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc // RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include // RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_batch_diagonal/test_data_set_0 %odla_link %device_link -lodla_popart -o %t_popart.exe -Wno-deprecated-declarations // RUN: %t_popart.exe 0.0001 0 popart %data_path/test_einsum_batch_diagonal | FileCheck %s diff --git a/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_diagonal_tensorrt.cc b/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_diagonal_tensorrt.cc index a02b650b8..8c297d148 100644 --- a/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_diagonal_tensorrt.cc +++ b/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_diagonal_tensorrt.cc @@ -17,12 +17,12 @@ // clang-format off // Testing CXX Code Gen using ODLA API on tensorrt -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb -// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc +// 
RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc // RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include // RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_batch_diagonal/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations -// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_diagonal | FileCheck %s +// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_diagonal | FileCheck %s // CHECK: Result Pass // clang-format on // XFAIL: * diff --git a/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_matmul_tensorrt.cc b/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_matmul_tensorrt.cc index 7e395172b..9c671ab59 100644 --- a/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_matmul_tensorrt.cc +++ b/tests/unittests/lit_cases/test_tensorrt/test_einsum_batch_matmul_tensorrt.cc @@ -17,14 +17,14 @@ // clang-format off // Testing CXX Code Gen using ODLA API on tensorrt -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.pb -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.pb -// RUN: 
%halo_compiler -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.pb -// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_matmul/model.onnx -o %t.cc +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_matmul/model.onnx -o %t.cc // RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include // RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_batch_matmul/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations -// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_matmul | FileCheck %s +// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_matmul | FileCheck %s // CHECK: Result Pass // clang-format on -// XFAIL: * + #include "test_einsum_batch_matmul_tensorrt.cc.tmp.main.cc.in" diff --git a/tests/unittests/lit_cases/test_tensorrt/test_einsum_inner_prod_tensorrt.cc b/tests/unittests/lit_cases/test_tensorrt/test_einsum_inner_prod_tensorrt.cc index 45f29783f..c35f089b9 100644 --- a/tests/unittests/lit_cases/test_tensorrt/test_einsum_inner_prod_tensorrt.cc +++ 
b/tests/unittests/lit_cases/test_tensorrt/test_einsum_inner_prod_tensorrt.cc @@ -17,14 +17,14 @@ // clang-format off // Testing CXX Code Gen using ODLA API on tensorrt -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_0.pb -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/output_0.pb -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_1.pb -// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_inner_prod/model.onnx -o %t.cc +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/output_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_1.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_inner_prod/model.onnx -o %t.cc // RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include // RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_inner_prod/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations -// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_inner_prod | FileCheck %s 
+// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_inner_prod | FileCheck %s // CHECK: Result Pass // clang-format on -// XFAIL: * + #include "test_einsum_inner_prod_tensorrt.cc.tmp.main.cc.in" diff --git a/tests/unittests/lit_cases/test_tensorrt/test_einsum_sum_tensorrt.cc b/tests/unittests/lit_cases/test_tensorrt/test_einsum_sum_tensorrt.cc index 7a889b86f..a6723e47e 100644 --- a/tests/unittests/lit_cases/test_tensorrt/test_einsum_sum_tensorrt.cc +++ b/tests/unittests/lit_cases/test_tensorrt/test_einsum_sum_tensorrt.cc @@ -17,13 +17,13 @@ // clang-format off // Testing CXX Code Gen using ODLA API on tensorrt -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_sum/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/output_0.pb -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_sum/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/input_0.pb -// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_sum/model.onnx -o %t.cc +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_sum/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/output_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_sum/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/input_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_sum/model.onnx -o %t.cc // RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include // RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_sum/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations -// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_sum | 
FileCheck %s +// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_sum | FileCheck %s // CHECK: Result Pass // clang-format on -// XFAIL: * + #include "test_einsum_sum_tensorrt.cc.tmp.main.cc.in" diff --git a/tests/unittests/lit_cases/test_tensorrt/test_einsum_transpose_tensorrt.cc b/tests/unittests/lit_cases/test_tensorrt/test_einsum_transpose_tensorrt.cc index 19ffb9bd3..bc0db5135 100644 --- a/tests/unittests/lit_cases/test_tensorrt/test_einsum_transpose_tensorrt.cc +++ b/tests/unittests/lit_cases/test_tensorrt/test_einsum_transpose_tensorrt.cc @@ -17,13 +17,13 @@ // clang-format off // Testing CXX Code Gen using ODLA API on tensorrt -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/output_0.pb -// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/input_0.pb -// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_transpose/model.onnx -o %t.cc +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/output_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/input_0.pb +// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_transpose/model.onnx -o %t.cc // RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include // RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_transpose/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe 
-Wno-deprecated-declarations -// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_transpose | FileCheck %s +// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_transpose | FileCheck %s // CHECK: Result Pass // clang-format on -// XFAIL: * + #include "test_einsum_transpose_tensorrt.cc.tmp.main.cc.in"