Skip to content

Commit

Permalink
[Type Casting Pass] Add support for fp64->fp32 conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
Weiming Zhao committed Nov 29, 2021
1 parent 2c7191b commit 1577a50
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 29 deletions.
36 changes: 36 additions & 0 deletions lib/transforms/typecast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ bool TypeCast::RunOnFunction(Function* func) {
halo::Type new_ty{DataType::INT32, ty.GetDimSizes()};
arg->SetType(new_ty);
changed |= true;
} else if (ty.GetDataType() == DataType::FLOAT64) {
halo::Type new_ty{DataType::FLOAT32, ty.GetDimSizes()};
arg->SetType(new_ty);
changed |= true;
}
}
{
Module* m = func->GetParent();
Function::ConstantList& constants = m->Constants();
for (auto it = constants.begin(), ie = constants.end(); it != ie; ++it) {
const auto& orig_type = (*it)->GetResultType(0);
if (orig_type.GetDataType() == DataType::FLOAT64) {
std::vector<float> ret;
ret.reserve(orig_type.GetTotalNumOfElements());
for (unsigned int i = 0; i < orig_type.GetTotalNumOfElements(); ++i) {
ret.push_back(static_cast<float>((*it)->GetData<double>(i)));
}
(*it)->SetData(halo::Type{DataType::FLOAT32, orig_type.GetDimSizes()},
ret.data());
changed = true;
}
}
}
// Replace constants.
Expand All @@ -50,6 +71,17 @@ bool TypeCast::RunOnFunction(Function* func) {
halo::Type{DataType::INT32, orig_type.GetDimSizes()}, ret.data());
(*it)->ReplaceAllUsesWith(0, *c_ret);
changed = true;
} else if (orig_type.GetDataType() == DataType::FLOAT64) {
std::vector<float> ret;
ret.reserve(orig_type.GetTotalNumOfElements());
for (unsigned int i = 0; i < orig_type.GetTotalNumOfElements(); ++i) {
ret.push_back(static_cast<float>((*it)->GetData<double>(i)));
}
Constant* c_ret = cb.CreateConstant(
(*it)->GetName() + "_castdown",
halo::Type{DataType::FLOAT32, orig_type.GetDimSizes()}, ret.data());
(*it)->ReplaceAllUsesWith(0, *c_ret);
changed = true;
}
}

Expand All @@ -60,6 +92,10 @@ bool TypeCast::RunOnFunction(Function* func) {
if (orig_type.IsValid() && orig_type.GetDataType() == DataType::INT64) {
inst->GetResultsTypes()[i] =
halo::Type{DataType::INT32, orig_type.GetDimSizes()};
} else if (orig_type.IsValid() &&
orig_type.GetDataType() == DataType::FLOAT64) {
inst->GetResultsTypes()[i] =
halo::Type{DataType::FLOAT32, orig_type.GetDimSizes()};
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

// clang-format off
// Testing CXX Code Gen using ODLA API on popart
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb
// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc
// RUN: %halo_compiler -disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb
// RUN: %halo_compiler -disable-type-cast=false -disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc
// RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include
// RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_batch_diagonal/test_data_set_0 %odla_link %device_link -lodla_popart -o %t_popart.exe -Wno-deprecated-declarations
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_einsum_batch_diagonal | FileCheck %s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

// clang-format off
// Testing CXX Code Gen using ODLA API on tensorrt
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb
// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/output_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_diagonal/test_data_set_0/input_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_diagonal/model.onnx -o %t.cc
// RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include
// RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_batch_diagonal/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_diagonal | FileCheck %s
// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_diagonal | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

// clang-format off
// Testing CXX Code Gen using ODLA API on tensorrt
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.pb
// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_matmul/model.onnx -o %t.cc
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/output_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_batch_matmul/test_data_set_0/input_1.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_batch_matmul/model.onnx -o %t.cc
// RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include
// RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_batch_matmul/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_matmul | FileCheck %s
// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_batch_matmul | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_einsum_batch_matmul_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

// clang-format off
// Testing CXX Code Gen using ODLA API on tensorrt
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/output_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_1.pb
// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_inner_prod/model.onnx -o %t.cc
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/output_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_inner_prod/test_data_set_0/input_1.cc -x onnx -emit-data-as-c %data_path/test_einsum_inner_prod/test_data_set_0/input_1.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_inner_prod/model.onnx -o %t.cc
// RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include
// RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_inner_prod/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_inner_prod | FileCheck %s
// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_inner_prod | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_einsum_inner_prod_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

// clang-format off
// Testing CXX Code Gen using ODLA API on tensorrt
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_sum/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/output_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_sum/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/input_0.pb
// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_sum/model.onnx -o %t.cc
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_sum/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/output_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_sum/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_sum/test_data_set_0/input_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_sum/model.onnx -o %t.cc
// RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include
// RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_sum/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_sum | FileCheck %s
// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_sum | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_einsum_sum_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

// clang-format off
// Testing CXX Code Gen using ODLA API on tensorrt
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/output_0.pb
// RUN: %halo_compiler -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/input_0.pb
// RUN: %halo_compiler -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_transpose/model.onnx -o %t.cc
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/output_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/output_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -o %data_path/test_einsum_transpose/test_data_set_0/input_0.cc -x onnx -emit-data-as-c %data_path/test_einsum_transpose/test_data_set_0/input_0.pb
// RUN: %halo_compiler --disable-type-cast=false -target cxx -batch-size 1 %halo_compile_flags %data_path/test_einsum_transpose/model.onnx -o %t.cc
// RUN: %cxx -c -fPIC -o %t.o %t.cc -I%odla_path/include
// RUN: %cxx -g %s %t.o %t.bin -I%T -I%odla_path/include -I%unittests_path -I%data_path/test_einsum_transpose/test_data_set_0 %odla_link %device_link -lodla_tensorrt -o %t_tensorrt.exe -Wno-deprecated-declarations
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_transpose | FileCheck %s
// RUN: ODLA_TRT_USE_EXPLICIT_BATCH=1 %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_einsum_transpose | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_einsum_transpose_tensorrt.cc.tmp.main.cc.in"

0 comments on commit 1577a50

Please sign in to comment.