From c95b119506493f0b194f1cac3084e73942c8ccb8 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Thu, 9 Dec 2021 10:41:58 -0800 Subject: [PATCH 01/12] wip --- .github/actions/build/build_in_docker.sh | 2 +- ODLA/platforms/tensorrt/odla_tensorrt.cc | 3 + include/halo/lib/pass/pass_manager.h | 1 + include/halo/utils/passes_helper.h | 3 + lib/pass/pass_manager.cc | 3 + lib/target/generic_cpp/generic_cxx_codegen.cc | 6 + lib/target/generic_cpp/return.cc | 3 + lib/transforms/CMakeLists.txt | 1 + lib/transforms/dce.cc | 20 ++- lib/transforms/input_legalizer.cc | 1 + lib/transforms/tfextension_legalizer.cc | 30 +++- lib/transforms/type_legalizer.cc | 24 +++ utils/docker/Dockerfile | 168 +++++++++++------- utils/docker/build_image.sh | 4 +- utils/docker/start_docker_cpu.sh | 2 +- 15 files changed, 201 insertions(+), 70 deletions(-) diff --git a/.github/actions/build/build_in_docker.sh b/.github/actions/build/build_in_docker.sh index 0bac14de0..b65e95dac 100755 --- a/.github/actions/build/build_in_docker.sh +++ b/.github/actions/build/build_in_docker.sh @@ -1,7 +1,7 @@ #!/bin/bash -xe REPO="registry-intl.us-west-1.aliyuncs.com/computation/halo" -VER="0.7.6" +VER="0.7.7" FLAVOR="devel" MOUNT_DIR="$PWD" diff --git a/ODLA/platforms/tensorrt/odla_tensorrt.cc b/ODLA/platforms/tensorrt/odla_tensorrt.cc index 26ffbd9eb..71d3c6edd 100644 --- a/ODLA/platforms/tensorrt/odla_tensorrt.cc +++ b/ODLA/platforms/tensorrt/odla_tensorrt.cc @@ -456,6 +456,9 @@ static odla_value CreateValue(T* t, const odla_value_type& type, auto v = std::make_unique<_odla_value>(t, type, name); auto ret = v.get(); g_comp->vals.push_back(std::move(v)); + if (!g_comp->branchs.empty()) { + g_comp->branchs.top().branch->addInput(*ret); + } return ret; } diff --git a/include/halo/lib/pass/pass_manager.h b/include/halo/lib/pass/pass_manager.h index 3da937099..ac9335053 100644 --- a/include/halo/lib/pass/pass_manager.h +++ b/include/halo/lib/pass/pass_manager.h @@ -67,6 +67,7 @@ class HL_API_EXPORT PassManager final { Pass* AddCodeFormatterPass(std::ostringstream& code, std::ostringstream& header, const CXXCodeGenOpts& opts); + Pass* AddConvertTFCFGPass(); Pass* AddDCEPass(); Pass* AddDevicePlacementPass(); Pass* AddFusionPass(const FusionOptions& opts); diff --git a/include/halo/utils/passes_helper.h b/include/halo/utils/passes_helper.h index 9fd8acee1..70a5b8e3d 100644 --- a/include/halo/utils/passes_helper.h +++ b/include/halo/utils/passes_helper.h @@ -162,6 +162,9 @@ static void PopulateOptPasses(PassManager* pm, const std::string& target, if (opts.enable_type_cast) { pm->AddTypeCastPass(); } + if (format == ModelFormat::TENSORFLOW) { + pm->AddConvertTFCFGPass(); + } if (opts.constant_decombine) { pm->AddConstantDecombinePass(); } diff --git a/lib/pass/pass_manager.cc b/lib/pass/pass_manager.cc index 291af937d..dab4c013d 100644 --- a/lib/pass/pass_manager.cc +++ b/lib/pass/pass_manager.cc @@ -31,6 +31,7 @@ #include "halo/lib/transforms/analyzer.h" #include "halo/lib/transforms/caffeextension_legalizer.h" #include "halo/lib/transforms/constant_decombine.h" +#include "halo/lib/transforms/convert_tf_cfg.h" #include "halo/lib/transforms/dce.h" #include "halo/lib/transforms/device_placement.h" #include "halo/lib/transforms/fusion.h" @@ -265,6 +266,8 @@ Pass* PassManager::AddCodeFormatterPass(std::ostringstream& buf_code, return AddPass(buf_code, buf_header, opts); } +Pass* PassManager::AddConvertTFCFGPass() { return AddPass(); } + Pass* PassManager::AddDCEPass() { return AddPass(); } Pass* PassManager::AddDevicePlacementPass() { diff --git 
a/lib/target/generic_cpp/generic_cxx_codegen.cc b/lib/target/generic_cpp/generic_cxx_codegen.cc index faef90111..ada9de59f 100644 --- a/lib/target/generic_cpp/generic_cxx_codegen.cc +++ b/lib/target/generic_cpp/generic_cxx_codegen.cc @@ -270,6 +270,9 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func, auto nr_outputs = ret_inst.GetNumOfOperands(); model_info.num_outputs = nr_outputs; for (const auto& out : ret_inst.GetOperands()) { + if (out.IsNull()) { + continue; + } const auto& type = out.GetType(); if (ir_mapping_.find(out) == ir_mapping_.end()) { CXXValue cv(out.GetDef()->GetName(), TensorTypeToCXXType(type, false)); @@ -840,6 +843,9 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) { index = 0; // Pre-launch binding. for (auto& op : return_inst->GetOperands()) { + if (op.IsNull()) { + continue; + } auto& cv = ir_mapping_[op]; std::string arg_name = (opts_.emit_inference_func_sig || is_sub) ? (is_sub ? "outputs.values[" : "outputs[") + diff --git a/lib/target/generic_cpp/return.cc b/lib/target/generic_cpp/return.cc index 51c36722f..584ca7a56 100644 --- a/lib/target/generic_cpp/return.cc +++ b/lib/target/generic_cpp/return.cc @@ -25,6 +25,9 @@ namespace halo { void GenericCXXCodeGen::RunOnInstruction(ReturnInst* inst) { bool is_compile_mode = opts_.exec_mode == ExecMode::Compile; for (auto& op : inst->GetOperands()) { + if (op.IsNull()) { + continue; + } const CXXValue& val = ir_mapping_[op]; if (is_compile_mode) { bool is_entry_with_calls = diff --git a/lib/transforms/CMakeLists.txt b/lib/transforms/CMakeLists.txt index 14c82c08e..7c003c1c0 100644 --- a/lib/transforms/CMakeLists.txt +++ b/lib/transforms/CMakeLists.txt @@ -22,6 +22,7 @@ set(SRCS analyzer.cc caffeextension_legalizer.cc constant_decombine.cc + convert_tf_cfg.cc dce.cc device_placement.cc fusion.cc diff --git a/lib/transforms/dce.cc b/lib/transforms/dce.cc index ae6be811c..b7792048d 100644 --- a/lib/transforms/dce.cc +++ b/lib/transforms/dce.cc @@ -24,17 +24,28 @@ namespace halo { -static void RemoveLoopBody(LoopInst* loop_inst) { - auto body = loop_inst->GetBody(); - auto return_inst = body->GetReturnInst(); +static void RemoveBody(BasicBlock* bb) { + auto return_inst = bb->GetReturnInst(); if (return_inst != nullptr) { // Drop all the operands of the return instruction so the rest of the body // loop will be DCE'ed automatically. // Note that the return inst cannot be erased because the current legalizer // will try to append one if no return inst exists for a block. return_inst->DropAllOperands(); + if (bb->Instructions().size() == 1) { + bb->Instructions().clear(); + return; + } } } +static void RemoveLoopBody(LoopInst* loop_inst) { + RemoveBody(loop_inst->GetBody()); +} + +static void RemoveIfBody(IfInst* if_inst) { + RemoveBody(if_inst->GetThenBranch()); + RemoveBody(if_inst->GetElseBranch()); +} // For instructions with `undef` operands, they are unreachable except for // `tf_merge` and optional operands. 
@@ -85,6 +96,9 @@ bool DCE::RunOnBasicBlock(BasicBlock* bb) { if (inst->GetOpCode() == OpCode::LOOP) { RemoveLoopBody(DynCast(inst)); } + if (inst->GetOpCode() == OpCode::IF) { + RemoveIfBody(DynCast(inst)); + } it = bb->Instructions().erase(it); } else { it = std::next(it); diff --git a/lib/transforms/input_legalizer.cc b/lib/transforms/input_legalizer.cc index 34fa2b655..c93126095 100644 --- a/lib/transforms/input_legalizer.cc +++ b/lib/transforms/input_legalizer.cc @@ -119,6 +119,7 @@ bool InputLegalizer::RunOnFunction(Function* func) { : it->second.GetDimSizes(); arg->GetResultsTypes()[0] = halo::Type(dt, dims); specified_shapes.erase(it); + changed = true; } auto dims = ty.GetDimSizes(); diff --git a/lib/transforms/tfextension_legalizer.cc b/lib/transforms/tfextension_legalizer.cc index 9f37fe634..912fd8d27 100644 --- a/lib/transforms/tfextension_legalizer.cc +++ b/lib/transforms/tfextension_legalizer.cc @@ -489,7 +489,9 @@ static std::vector ConvertStridedSlice(const TFExtensionInst* ext, static std::vector ConvertSwitch(const TFExtensionInst* ext, IRBuilder* builder) { const auto& data = ext->GetOperand(0); - if (const Constant* pred = DynCast(ext->GetOperand(1)); + const auto& cond = ext->GetOperand(1); +#if 0 + if (const Constant* pred = DynCast(cond); pred != nullptr) { HLCHECK(pred->GetResultType().GetTotalNumOfElements() == 1); bool cond = pred->GetDataAsInt64(0) != 0; @@ -497,7 +499,33 @@ static std::vector ConvertSwitch(const TFExtensionInst* ext, std::vector ret_false{data, Def::GetUndefined()}; return cond ? ret_true : ret_false; } +#endif +// TODO(unknown): move to separate pass? +#if 1 + builder->SetInsertAfter(ext); + BasicBlockBuilder bb_builder(ext->GetParent()->GetParent()); + const auto& name = ext->GetName(); + auto if_inst = builder->CreateIf(ext->GetName(), cond); + if_inst->AddOneOperand(data); + + BasicBlock* bb_t = bb_builder.CreateBasicBlock(name + "_true"); + if_inst->SetThenBranch(bb_t); + IRBuilder builder_t(bb_t); + auto arg_builder_t = std::make_unique(bb_t); + auto arg_t = arg_builder_t->CreateArgument(name + "_t", data.GetType()); + builder_t.CreateReturn(name + "ret_t", *arg_t); + + BasicBlock* bb_f = bb_builder.CreateBasicBlock(name + "_false"); + IRBuilder builder_f(bb_f); + if_inst->SetElseBranch(bb_f); + auto arg_builder_f = std::make_unique(bb_f); + auto arg_f = arg_builder_f->CreateArgument(name + "_f", data.GetType()); + builder_f.CreateReturn(name + "ret_f", *arg_f); + if_inst->SetNumOfResults(2); + return {Def(if_inst, 0), Def(if_inst, 1)}; +#else return {}; +#endif } static std::vector ConvertMerge(const TFExtensionInst* ext, diff --git a/lib/transforms/type_legalizer.cc b/lib/transforms/type_legalizer.cc index 848b3c5b3..560f21e71 100644 --- a/lib/transforms/type_legalizer.cc +++ b/lib/transforms/type_legalizer.cc @@ -1531,6 +1531,25 @@ static void RunOnInstruction(BitcastInst* inst) { inst->GetResultsTypes()[0] = result_type; } +static void RunOnInstruction(TFExtensionInst* inst) { + if (inst->GetExtOpCode() == TFExtOpCode::MERGE) { + for (auto& op : inst->GetOperands()) { + if (op.GetType().IsValid()) { + inst->GetResultsTypes()[0] = op.GetType(); + return; + } + } + return; + } + if (inst->GetExtOpCode() == TFExtOpCode::SWITCH) { + const auto& ty = inst->GetOperand(0).GetType(); + if (ty.IsValid()) { + inst->GetResultsTypes() = {ty, ty}; + } + return; + } +} + static void RunOnInstruction(UniqueInst* inst) { const auto& type0 = inst->GetOperand(0).GetType(); if (!type0.IsValid()) { @@ -1569,6 +1588,11 @@ bool 
TypeLegalizer::RunOnBasicBlock(BasicBlock* bb) { #define GET_INST_DOWNCAST_SWITCH #include "halo/lib/ir/instructions_info.def" #undef GET_INST_DOWNCAST_SWITCH + case OpCode::EXTENSION: { + TFExtensionInst* ext = DynCast(inst); + RunOnInstruction(ext); + break; + } default: { if (!relaxed_) { // HLCHECK(0 && "Unreachable"); diff --git a/utils/docker/Dockerfile b/utils/docker/Dockerfile index 2727d73f7..fac780edc 100644 --- a/utils/docker/Dockerfile +++ b/utils/docker/Dockerfile @@ -1,6 +1,103 @@ +# syntax=docker/dockerfile:experimental # Build this image: docker build -t halo:[version] . ARG BASE_IMAGE + +# cmake +FROM ubuntu:18.04 AS cmake +ARG CMAKE_VERSION=3.14.5 +RUN apt-get update && apt-get install -y curl gcc g++ make zlib1g zlib1g-dev libcurl4-openssl-dev git && \ + gcc --version && \ + mkdir /tmp/cmake && \ + cd /tmp/cmake && \ + curl -L https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz -o cmake.tar.gz && \ + tar zxf cmake.tar.gz && \ + cd cmake-${CMAKE_VERSION} && \ + ./bootstrap --system-curl --parallel=48 && \ + make -j all && \ + make install DESTDIR=/tmp/cmake/install && \ + make install && \ + tar -C /tmp/cmake/install -cf /tmp/cmake.tar usr && \ + rm -fr /tmp/cmake + +# binutils +FROM cmake AS binutils +ARG BINUTILS_VERSION=2.35 +RUN mkdir /tmp/binutils && \ + cd /tmp/binutils && \ + curl -s http://ftp.gnu.org/gnu/binutils/binutils-${BINUTILS_VERSION}.tar.gz | tar xvz && \ + cd binutils-${BINUTILS_VERSION} && \ + ./configure && \ + make -j all && \ + make install DESTDIR=/tmp/binutils/install && \ + tar -C /tmp/binutils/install -cf /tmp/binutils.tar usr && \ + rm -rf /tmp/binutils + +# valgrind +FROM cmake AS valgrind +ARG VALGRIND_VERSION=3.18.1 +RUN mkdir /tmp/valgrind && \ + cd /tmp/valgrind && \ + curl -o valgrind.tar.bz2 ftp://sourceware.org/pub/valgrind/valgrind-${VALGRIND_VERSION}.tar.bz2 && \ + tar jxf valgrind.tar.bz2 && \ + cd valgrind-${VALGRIND_VERSION} && \ + ./configure && \ + make -j all && \ + make install DESTDIR=/tmp/valgrind/install && \ + tar -C /tmp/valgrind/install -cf /tmp/valgrid.tar usr && \ + rm -rf /tmp/valgrind + +# Build Protobuf (static) +FROM cmake AS pb +RUN git -C /tmp clone --depth=1 https://github.com/protocolbuffers/protobuf.git -b v3.9.1 && \ + cd /tmp/protobuf/cmake && \ + cmake -G "Unix Makefiles" -Dprotobuf_BUILD_TESTS=OFF \ + -Dprotobuf_BUILD_SHARED_LIBS=OFF \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON . && \ + make -j && make install DESTDIR=/tmp/protobuf/install && \ + tar -C /tmp/protobuf/install -cf /tmp/protobuf.tar usr && \ + rm -fr /tmp/protobuf + +# Build Flatbuffer +FROM cmake as fb +RUN git -C /tmp clone --depth=1 https://github.com/google/flatbuffers.git -b v1.12.0 && \ + cd /tmp/flatbuffers && \ + cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DFLATBUFFERS_BUILD_SHAREDLIB=OFF && \ + make -j && make install DESTDIR=/tmp/flatbuffers/install && \ + tar -C /tmp/flatbuffers/install -cf /tmp/flatbuffers.tar usr && \ + rm -fr /tmp/flatbuffers + +# Build glog +FROM cmake AS glog +RUN git -C /tmp clone --depth=1 https://github.com/google/glog.git -b v0.4.0 && \ + cd /tmp/glog && \ + cmake -H. 
-Bbuild -G "Unix Makefiles" && \ + cd build && \ + make -j && make install DESTDIR=/tmp/glog/install && \ + tar -C /tmp/glog/install -cf /tmp/glog.tar usr && \ + rm -fr /tmp/glog + +# Build DNNL +FROM cmake as dnnl +RUN git -C /tmp clone --depth=1 https://github.com/oneapi-src/oneDNN.git --branch v1.7 && \ + cd /tmp/oneDNN && \ + cmake -G "Unix Makefiles" -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF -DDNNL_ENABLE_PRIMITIVE_CACHE=ON -DCMAKE_INSTALL_PREFIX=/opt/dnnl && \ + make -j && make install DESTDIR=/tmp/oneDNN/install && \ + tar -C /tmp/oneDNN/install -cf /tmp/dnnl.tar opt && \ + rm -fr /tmp/oneDNN + +# Build XNNPack +FROM cmake as xnnpack +RUN git -C /tmp clone https://github.com/google/XNNPACK.git && \ + cd /tmp/XNNPACK && git checkout -b tmp 90db69f681ea9abd1ced813c17c00007f14ce58b && \ + mkdir /tmp/xnn_build && cd /tmp/xnn_build && \ + cmake -G "Unix Makefiles" ../XNNPACK -DXNNPACK_LIBRARY_TYPE=static -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DXNNPACK_BUILD_TESTS=OFF -DXNNPACK_BUILD_BENCHMARKS=OFF -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/XNNPACK && \ + make -j && make install DESTDIR=/tmp/XNNPACK/install && \ + tar -C /tmp/XNNPACK/install -cf /tmp/xnnpack.tar opt && \ + rm -fr /tmp/XNNPACK /mp/xnn_build + FROM ${BASE_IMAGE} # Redeclare the argument @@ -125,58 +222,19 @@ RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - && \ RUN pip3 install wheel numpy six jupyter enum34 mock h5py pillow # Update binutils -ARG BINUTILS_VERSION=2.35 -RUN mkdir /tmp/binutils && \ - cd /tmp/binutils && \ - wget http://ftp.gnu.org/gnu/binutils/binutils-${BINUTILS_VERSION}.tar.gz && \ - tar zxf binutils-${BINUTILS_VERSION}.tar.gz && \ - cd binutils-${BINUTILS_VERSION} && \ - ./configure && \ - make -j all && \ - make install && \ - rm -rf /tmp/binutils +RUN --mount=from=binutils,target=/pkg,source=/tmp tar -C / -xf /pkg/binutils.tar # Install cmake -ARG CMAKE_VERSION=3.14.5 -RUN mkdir /tmp/cmake && \ - cd /tmp/cmake && \ - wget https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}.tar.gz && \ - tar zxf cmake-${CMAKE_VERSION}.tar.gz && \ - cd cmake-${CMAKE_VERSION} && \ - ./bootstrap --system-curl --parallel=48 && \ - make -j all && \ - make install && \ - rm -rf /tmp/cmake +RUN --mount=from=cmake,target=/pkg,source=/tmp tar -C / -xf /pkg/cmake.tar # Install valgrind -ARG VALGRIND_VERSION=3.13.0 -RUN mkdir /tmp/valgrind && \ - cd /tmp/valgrind && \ - wget ftp://sourceware.org/pub/valgrind/valgrind-${VALGRIND_VERSION}.tar.bz2 && \ - tar jxf valgrind-${VALGRIND_VERSION}.tar.bz2 && \ - cd valgrind-${VALGRIND_VERSION} && \ - ./configure && \ - make -j all && \ - make install && \ - rm -rf /tmp/valgrind +RUN --mount=from=valgrind,target=/pkg,source=/tmp tar -C / -xf /pkg/valgrid.tar # INSTALL Protobuf (static) -RUN cd /tmp && \ - git clone --depth=1 https://github.com/protocolbuffers/protobuf.git -b v3.9.1 && \ - cd protobuf/cmake && \ - cmake -G Ninja . -Dprotobuf_BUILD_TESTS=OFF \ - -Dprotobuf_BUILD_SHARED_LIBS=OFF \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON && \ - ninja install && \ - rm -fr /tmp/protobuf +RUN --mount=from=pb,target=/pkg,source=/tmp tar -C / -xf /pkg/protobuf.tar # INSTALL glog -RUN cd /tmp && \ - git clone --depth=1 https://github.com/google/glog.git -b v0.4.0 && \ - cd glog && \ - cmake -H. 
-Bbuild -G "Unix Makefiles" && cmake --build build && \ - cmake --build build --target install && ldconfig && \ - rm -fr /tmp/glog +RUN --mount=from=glog,target=/pkg,source=/tmp tar -C / -xf /pkg/glog.tar # Install GCC-10 RUN add-apt-repository ppa:ubuntu-toolchain-r/test && \ @@ -185,10 +243,7 @@ RUN add-apt-repository ppa:ubuntu-toolchain-r/test && \ apt-get clean && apt-get purge && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Build & Install DNNL (MKLDNN) -RUN cd /tmp && git clone --depth=1 https://github.com/oneapi-src/oneDNN.git --branch v1.7 && \ - cd /tmp/oneDNN && \ - cmake -G Ninja -DDNNL_BUILD_EXAMPLES=OFF -DDNNL_BUILD_TESTS=OFF -DDNNL_ENABLE_PRIMITIVE_CACHE=ON -DCMAKE_INSTALL_PREFIX=/opt/dnnl && \ - ninja install +RUN --mount=from=dnnl,target=/pkg,source=/tmp tar -C / -xf /pkg/dnnl.tar # Install Parallel RUN apt-get update && \ @@ -196,24 +251,13 @@ RUN apt-get update && \ apt-get clean && apt-get purge && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Install Eigen -RUN cd /tmp && wget https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.bz2 && \ - tar jxvf eigen-3.4.0.tar.bz2 && mv eigen-3.4.0 /opt +RUN curl -s https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.bz2 | tar -C /opt -xvj # Install XNNPack -RUN cd /tmp && git clone https://github.com/google/XNNPACK.git && \ - cd /tmp/XNNPACK && git checkout -b tmp 90db69f681ea9abd1ced813c17c00007f14ce58b && \ - mkdir /tmp/xnn_build_static && cd /tmp/xnn_build_static && \ - cmake -G Ninja ../XNNPACK -DXNNPACK_LIBRARY_TYPE=static -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DXNNPACK_BUILD_TESTS=OFF -DXNNPACK_BUILD_BENCHMARKS=OFF -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=/opt/XNNPACK && \ - ninja install +RUN --mount=from=xnnpack,target=/pkg,source=/tmp tar -C / -xf /pkg/xnnpack.tar # Install Flatbuffer -RUN cd /tmp && \ - git clone --depth=1 https://github.com/google/flatbuffers.git -b v1.12.0 && \ - cd flatbuffers && \ - cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release -DFLATBUFFERS_BUILD_SHAREDLIB=ON && make -j && make install && \ - rm -fr /tmp/flatbuffers +RUN --mount=from=fb,target=/pkg,source=/tmp tar -C / -xf /pkg/flatbuffers.tar # INSATLL ONEAPI RUN if [[ ! "${BASE_IMAGE}" =~ "nvidia" ]]; then wget https://registrationcenter-download.intel.com/akdlm/irc_nas/17769/l_BaseKit_p_2021.2.0.2883_offline.sh && \ diff --git a/utils/docker/build_image.sh b/utils/docker/build_image.sh index fe0bc99be..512ee2683 100755 --- a/utils/docker/build_image.sh +++ b/utils/docker/build_image.sh @@ -1,6 +1,6 @@ #!/bin/bash -xe -VER="0.7.6" +VER="0.7.7" FLAVOR="devel" NAMESPACE="registry-intl.us-west-1.aliyuncs.com/computation" @@ -10,6 +10,6 @@ base_image_cpu="ubuntu:18.04" #docker build --build-arg BASE_IMAGE=${base_image_cpu} \ # -t $NAMESPACE/halo:$VER-$FLAVOR-x86_64-ubuntu18.04 . -docker build --build-arg BASE_IMAGE=${base_image_gpu} \ +DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE=${base_image_gpu} \ -t $NAMESPACE/halo:$VER-$FLAVOR-cuda11.4.2-cudnn8-ubuntu18.04 . 
diff --git a/utils/docker/start_docker_cpu.sh b/utils/docker/start_docker_cpu.sh index 5267a1e04..7d6a53715 100755 --- a/utils/docker/start_docker_cpu.sh +++ b/utils/docker/start_docker_cpu.sh @@ -1,6 +1,6 @@ #!/bin/bash -xe -VER="0.7.6" +VER="0.7.7" FLAVOR="devel" NAMESPACE="registry-intl.us-west-1.aliyuncs.com/computation" From b3d5cc57648aba3ad9eccd507291151fe8cee894 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Sat, 11 Dec 2021 07:15:13 +0000 Subject: [PATCH 02/12] Add shape op to support dynamic shapes --- ODLA/include/ODLA/ops/odla_ops_process.h | 2 +- include/halo/lib/ir/common_instructions.td | 6 +++++ include/halo/lib/ir/tf_convert.td | 5 +++- .../target/generic_cxx/generic_cxx_codegen.h | 1 + include/halo/lib/transforms/inst_simplify.h | 1 + lib/target/generic_cpp/reshape.cc | 12 +++++++++ lib/transforms/inst_simplify.cc | 25 +++++++++++++++++++ lib/transforms/tfextension_legalizer.cc | 22 ---------------- lib/transforms/type_legalizer.cc | 10 ++++++++ 9 files changed, 60 insertions(+), 24 deletions(-) diff --git a/ODLA/include/ODLA/ops/odla_ops_process.h b/ODLA/include/ODLA/ops/odla_ops_process.h index 9306faddd..1f5b9c618 100644 --- a/ODLA/include/ODLA/ops/odla_ops_process.h +++ b/ODLA/include/ODLA/ops/odla_ops_process.h @@ -259,8 +259,8 @@ odla_Resize(odla_value input, odla_interpolation_mode interpolation, the result value is implementation determined. \param input the input value - \param value_id a unique value id (can be NULL) \param output_dims the optional output shape (can be undefined) + \param value_id a unique value id (can be NULL) \return odla_value */ diff --git a/include/halo/lib/ir/common_instructions.td b/include/halo/lib/ir/common_instructions.td index 626d93086..0e65d8de4 100644 --- a/include/halo/lib/ir/common_instructions.td +++ b/include/halo/lib/ir/common_instructions.td @@ -78,6 +78,12 @@ let cat_ = cat_common in { let outs_ = [Arg<"The result.", MatchArgType<0> >]; } + def Shape : Inst<"Compute the shape of input."> { + let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >]; + let attrs_ = [Attr<"The output data type", EnumDataType, "data_type", "INT64">]; + let outs_ = [Arg<"The result.", ArgType<[I64, I32]> >]; + } + def Reshape : Inst<"Reshape the input X1 to create the result with the same" " number of elements and the shape specified by X2."> { let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >, diff --git a/include/halo/lib/ir/tf_convert.td b/include/halo/lib/ir/tf_convert.td index 140016948..32996b3cb 100644 --- a/include/halo/lib/ir/tf_convert.td +++ b/include/halo/lib/ir/tf_convert.td @@ -44,7 +44,10 @@ def TF_Reshape : TFExtension<"Reshape"> { let extension_attr_ = [ ExtensionAttr<"shape", IntegerList, "{}"> ]; } -def TF_Shape : TFExtension<"Shape">; +def TF_Shape : OpMapping<"Shape", Shape> { + let attr_mapping_ = [ + AttributeMapping<"", "data_type", "INT32">]; +} def TF_SquaredDifference : TFExtension<"SquaredDifference">; diff --git a/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h b/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h index c1aeb4f7d..54864844c 100644 --- a/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h +++ b/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h @@ -177,6 +177,7 @@ class GenericCXXCodeGen : public CodeGen { virtual void RunOnInstruction(ReturnInst*) override; virtual void RunOnInstruction(RNNInst*) override; virtual void RunOnInstruction(SelectInst*) override; + virtual void RunOnInstruction(ShapeInst*) override; virtual void 
RunOnInstruction(ShiftInst*) override; virtual void RunOnInstruction(ShrinkInst*) override; virtual void RunOnInstruction(SItoFPInst*) override; diff --git a/include/halo/lib/transforms/inst_simplify.h b/include/halo/lib/transforms/inst_simplify.h index 3456fd621..2518a2979 100644 --- a/include/halo/lib/transforms/inst_simplify.h +++ b/include/halo/lib/transforms/inst_simplify.h @@ -65,6 +65,7 @@ class InstSimplify final : public BasicBlockPass { static std::pair RunOnInstruction(ResizeInst* inst); static std::pair RunOnInstruction(SelectInst* inst); static std::pair RunOnInstruction(SetDiff1DInst* inst); + static std::pair RunOnInstruction(ShapeInst* inst); static std::pair RunOnInstruction(SigmoidInst* inst); static std::pair RunOnInstruction(SItoFPInst* inst); static std::pair RunOnInstruction(FPtoSIInst* inst); diff --git a/lib/target/generic_cpp/reshape.cc b/lib/target/generic_cpp/reshape.cc index 06cd83fba..fdf4a6961 100644 --- a/lib/target/generic_cpp/reshape.cc +++ b/lib/target/generic_cpp/reshape.cc @@ -32,4 +32,16 @@ void GenericCXXCodeGen::RunOnInstruction(ReshapeInst* inst) { ir_mapping_[*inst] = ret; } +void GenericCXXCodeGen::RunOnInstruction(ShapeInst* inst) { + const Def& input = inst->GetOperand(0); + + CXXValue op0 = ir_mapping_[input]; + + const auto& ret_type = inst->GetResultType(); + CXXValue ret(inst->GetName(), op0.type); + EmitODLACall(ret, "odla_Shape", op0, EmitShape(ret_type)); + + ir_mapping_[*inst] = ret; +} + } // namespace halo diff --git a/lib/transforms/inst_simplify.cc b/lib/transforms/inst_simplify.cc index 3b3e75e68..62ddf0a79 100644 --- a/lib/transforms/inst_simplify.cc +++ b/lib/transforms/inst_simplify.cc @@ -873,6 +873,31 @@ std::pair InstSimplify::RunOnInstruction(Relu6Inst* inst) { }); } +std::pair InstSimplify::RunOnInstruction(ShapeInst* inst) { + const auto& type = inst->GetOperand(0).GetType(); + + Def orig_def{inst, 0}; + if (!type.IsValid() || type.IsDynamicShape() || type.IsDynamicBatch()) { + return {orig_def, orig_def}; + } + + DataType dt = inst->GetDataType(); + ConstantBuilder cb(inst->GetParent()->GetParent()); + int64_t rank = type.GetNumOfDims(); + if (dt == DataType::INT32) { + std::vector shape; + for (int64_t i : type.GetDimSizes()) { + shape.push_back(static_cast(i)); + } + Constant* c = cb.CreateConstant(inst->GetName(), halo::Type{dt, {rank}}, + shape.data()); + return {orig_def, *c}; + } + HLCHECK(dt == DataType::INT64); + Constant* c = cb.CreateConstant(inst->GetName(), halo::Type{dt, {rank}}, + type.GetDimSizes()); + return {orig_def, *c}; +} std::pair InstSimplify::RunOnInstruction(SigmoidInst* inst) { return SinkTranspose( *inst, [](IRBuilder& builder, const std::string& name, const Def& op) { diff --git a/lib/transforms/tfextension_legalizer.cc b/lib/transforms/tfextension_legalizer.cc index 912fd8d27..59a7050ae 100644 --- a/lib/transforms/tfextension_legalizer.cc +++ b/lib/transforms/tfextension_legalizer.cc @@ -261,25 +261,6 @@ static std::vector ConvertFill(const TFExtensionInst* ext, return {}; } -static std::vector ConvertShape(const TFExtensionInst* ext, - IRBuilder* builder) { - auto input = ext->GetOperand(0); - const Type& input_type = input.GetType(); - if (!input_type.IsValid()) { - return {}; - } - std::vector shape; - for (int64_t i : input_type.GetDimSizes()) { - shape.push_back(static_cast(i)); - } - ConstantBuilder cb(ext->GetParent()->GetParent()); - Constant* c = cb.CreateConstant( - ext->GetName() + "_shape", - Type{DataType::INT32, {static_cast(input_type.GetNumOfDims())}}, - shape.data()); - return 
{*c}; -} - static std::vector ConvertSize(const TFExtensionInst* ext, IRBuilder* builder) { const auto& type = ext->GetOperand(0).GetType(); @@ -1089,9 +1070,6 @@ static std::vector ConvertTFExtension(const TFExtensionInst* tf_inst, case TFExtOpCode::SIZE: { return ConvertSize(tf_inst, builder); } - case TFExtOpCode::SHAPE: { - return ConvertShape(tf_inst, builder); - } case TFExtOpCode::SPLIT: { return ConvertSplit(tf_inst, builder); } diff --git a/lib/transforms/type_legalizer.cc b/lib/transforms/type_legalizer.cc index 560f21e71..c1d23ba8f 100644 --- a/lib/transforms/type_legalizer.cc +++ b/lib/transforms/type_legalizer.cc @@ -659,6 +659,16 @@ static void RunOnCommonReductionInstruction(T* inst, std::vector axis, inst->GetResultsTypes()[0] = halo::Type{dt, ret_shape}; } +static void RunOnInstruction(ShapeInst* inst) { + const Type& input_type = inst->GetOperand(0).GetType(); + + if (!input_type.IsValid()) { + return; + } + int rank = input_type.GetNumOfDims(); + inst->GetResultsTypes()[0] = halo::Type{inst->GetDataType(), {rank}}; +} + static void RunOnInstruction(ReduceL1Inst* inst) { RunOnCommonReductionInstruction(inst, inst->GetAxis(), inst->GetKeepDims()); } From a4bd8fb031fed8a2cd3396cc6863097762a24869 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Sat, 11 Dec 2021 07:15:50 +0000 Subject: [PATCH 03/12] wip --- lib/transforms/tfextension_legalizer.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/transforms/tfextension_legalizer.cc b/lib/transforms/tfextension_legalizer.cc index 59a7050ae..39a95154a 100644 --- a/lib/transforms/tfextension_legalizer.cc +++ b/lib/transforms/tfextension_legalizer.cc @@ -456,11 +456,15 @@ static std::vector ConvertStridedSlice(const TFExtensionInst* ext, } new_axis_mask >>= 1; } - - Constant* c_shape = cb.CreateConstant( - new_slice_inst->GetName() + "_shape", - Type{DataType::INT32, {static_cast(new_dims.size())}}, - new_dims.data()); + std::vector new_shape; + if (!new_dims.empty()) { + new_shape.push_back(new_dims.size()); + } else { + new_dims.push_back(1); // slice and reshape to a single scalar. 
+ } Constant* c_shape = + cb.CreateConstant(new_slice_inst->GetName() + "_shape", + Type{DataType::INT32, new_shape}, new_dims.data()); new_slice_inst = builder->CreateReshape(ext->GetName() + "_reshape", {Def{new_slice_inst, 0}, *c_shape}); } From dfecadff1d928bc52aeb9b982ecb27b51b2190bb Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Sat, 11 Dec 2021 07:16:01 +0000 Subject: [PATCH 04/12] debug output --- lib/target/generic_cpp/generic_cxx_codegen.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/target/generic_cpp/generic_cxx_codegen.cc b/lib/target/generic_cpp/generic_cxx_codegen.cc index ada9de59f..d07ce937d 100644 --- a/lib/target/generic_cpp/generic_cxx_codegen.cc +++ b/lib/target/generic_cpp/generic_cxx_codegen.cc @@ -125,6 +125,7 @@ static std::string GetBF16Mode(BF16Mode mode) { } bool GenericCXXCodeGen::RunOnModule(Module* module) { + module->Dump(); memory_analyzer_ = std::make_unique<MemoryAnalyzer>(*module); Function* entry_func = nullptr; EmitBanner(&os_, &header_os_, GetAPI()); From 6f23c127ba063af6e30fa9d47cd80a8bd49e8ae6 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Sat, 11 Dec 2021 07:16:20 +0000 Subject: [PATCH 05/12] add missing files --- include/halo/lib/transforms/convert_tf_cfg.h | 37 +++ lib/transforms/convert_tf_cfg.cc | 329 +++++++++++++++++++ 2 files changed, 366 insertions(+) create mode 100644 include/halo/lib/transforms/convert_tf_cfg.h create mode 100644 lib/transforms/convert_tf_cfg.cc diff --git a/include/halo/lib/transforms/convert_tf_cfg.h b/include/halo/lib/transforms/convert_tf_cfg.h new file mode 100644 index 000000000..6264dd738 --- /dev/null +++ b/include/halo/lib/transforms/convert_tf_cfg.h @@ -0,0 +1,37 @@ +//===- convert_tf_cfg.h ---------------------------------------------------===// +// +// Copyright (C) 2019-2020 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_ +#define HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_ +#include "halo/lib/pass/pass.h" +namespace halo { +/// This pass converts TF control-flow ops (tf_switch/tf_merge) into +/// structured If instructions. +class ConvertTFCFG final : public FunctionPass { + public: + ConvertTFCFG() : FunctionPass("Convert TF CFG"), converted_(false) {} + bool RunOnFunction(Function* func) override; + private: + bool converted_; +}; +} // end namespace halo. +#endif // HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_ \ No newline at end of file diff --git a/lib/transforms/convert_tf_cfg.cc b/lib/transforms/convert_tf_cfg.cc new file mode 100644 index 000000000..960882e05 --- /dev/null +++ b/lib/transforms/convert_tf_cfg.cc @@ -0,0 +1,329 @@ +//===- convert_tf_cfg.cc --------------------------------------------------===// +// +// Copyright (C) 2019-2021 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "halo/lib/transforms/convert_tf_cfg.h" + +#include +#include + +#include "halo/api/halo_data.h" +#include "halo/lib/ir/controlflow_instructions.h" +#include "halo/lib/ir/extension_instructions.h" +#include "halo/lib/ir/instruction.h" +#include "halo/lib/ir/ir_builder.h" + +namespace halo { + +static bool MergeIfs(BasicBlock* bb) { + bool changed = false; + std::unordered_map ifs; + for (const auto& it : *bb) { + IfInst* inst = DynCast(it.get()); + if (inst == nullptr) { + continue; + } + const Def& cond = inst->GetOperand(0); + if (ifs.count(cond) == 0) { + ifs[cond] = inst; + continue; + } + + IfInst* dst = ifs[cond]; + for (int i = 1, e = inst->GetNumOfOperands(); i < e; ++i) { + dst->AddOneOperand(inst->GetOperand(i)); + } + inst->DropAllOperands(); + for (unsigned i = 0, dst_idx = dst->GetNumOfResults(); + i < inst->GetNumOfResults(); ++i) { + const auto& ty = inst->GetResultsTypes()[i]; + dst->GetResultsTypes().push_back(ty); + inst->ReplaceAllUsesWith(i, Def{dst, static_cast(dst_idx++)}); + } + inst->GetThenBranch()->MoveTo(dst->GetThenBranch()); + inst->GetElseBranch()->MoveTo(dst->GetElseBranch()); + // merge arguments + auto merge_args = [](BasicBlock* bb) { + if (bb->GetNumOfOperands() == 0) { // who removed args? + return; + } + HLCHECK(bb->GetNumOfOperands() == 2); + Def arg0{bb->arg_front(), 0}; + bb->arg_back()->ReplaceAllUsesWith({arg0}); + bb->Args().pop_back(); + }; + merge_args(inst->GetThenBranch()); + merge_args(inst->GetElseBranch()); + + changed = true; + } + return changed; +} + +static void RewriteOutput(IfInst* if_inst, const std::vector& ops, + bool is_taken) { + auto bb = is_taken ? if_inst->GetThenBranch() : if_inst->GetElseBranch(); + ReturnInst* ret = bb->GetReturnInst(); + HLCHECK(ret != nullptr); + ret->DropAllOperands(); + HLCHECK(if_inst->GetNumOfResults() == (if_inst->GetNumOfOperands() - 1) * 2); + for (auto op : ops) { + // if's output: [v1_f, v1_t, v2_f, v2_t, ...], inputs: [cond, v1, v2, v3] + // Branch bb's args: [arg1, arg2, arg3] + if (op.GetOwner() == if_inst) { + op = Def{std::next(bb->arg_begin(), op.GetIdx() / 2)->get(), 0}; + } + ret->AddOneOperand(op); + } +} + +static bool RunOnBasicBlock(BasicBlock* bb) { + // run on main bb only. Fixme: need to deal with nested if. + if (bb != bb->GetParent()->begin()->get()) { + return false; + } + bool changed = false; + changed |= MergeIfs(bb); + std::unordered_map branch_bbs; + for (const auto& it : *bb) { + IfInst* inst = DynCast(it.get()); + if (inst != nullptr) { + branch_bbs[inst->GetThenBranch()] = inst; + branch_bbs[inst->GetElseBranch()] = inst; + } + } + + for (const auto& it : *bb) { + Instruction* inst = it.get(); + // tf_merge will be handled later. + if (auto ext = DynCast(inst); + ext != nullptr && ext->GetExtOpCode() == TFExtOpCode::MERGE) { + continue; + } + + /* + if (auto ext = DynCast(inst); + ext != nullptr && ext->GetExtOpCode() == TFExtOpCode::MERGE) { + // Handle tf_Merge: all the operands should come from if's output or + some + // branch's output. 
+ IfInst* if_inst = nullptr; + int idx = 0; + for (auto& op : ext->GetOperands()) { + Instruction* op_inst = DynCast(op); + if (op_inst->GetOpCode() == OpCode::IF) { + // some branch is empty. nested if? + if (if_inst != nullptr && if_inst != op_inst) { + HLCHECK(0); + if_inst = nullptr; // merge inputs are from different "if" + break; + } + if_inst = DynCast(op_inst); + idx = op.GetIdx(); + } else { + BasicBlock* bb = op_inst->GetParent(); + auto it = branch_bbs.find(bb); + HLCHECK(it != branch_bbs.end()); + HLCHECK(if_inst == nullptr || if_inst == it->second); + if_inst = it->second; + } + } + if (if_inst != nullptr) { + // FIXME: + std::cout << "Replace with " << idx << "\n"; + // ext->ReplaceAllUsesWith(0, Def{if_inst, idx});// work as a barrier + } + continue; + } + */ + BasicBlock* new_parent = nullptr; + for (int i = 0, e = inst->GetNumOfOperands(); i < e; ++i) { + const auto& op = inst->GetOperand(i); + auto if_inst = DynCast(op); + if (if_inst != nullptr) { + int idx = op.GetIdx(); + auto bb = (idx & 1) == 0 ? if_inst->GetElseBranch() + : if_inst->GetThenBranch(); + if (new_parent == nullptr) { + new_parent = bb; + } else { + HLCHECK(new_parent == bb); + } + } else { + Instruction* op_inst = DynCast(op); + BasicBlock* op_bb = op_inst == nullptr ? nullptr : op_inst->GetParent(); + if (branch_bbs.count(op_bb) > 0) { + if (new_parent == nullptr) { + new_parent = op_bb; + } else { + HLCHECK(new_parent == op_bb); + } + } + } + } + if (new_parent != nullptr) { + IfInst* if_inst = branch_bbs[new_parent]; + HLCHECK(if_inst != nullptr); + IRBuilder new_builder(new_parent); + new_builder.SetInsertBefore(new_parent->GetReturnInst()); + std::vector operands = inst->GetOperands(); + for (auto& op : operands) { + if (op.GetOwner() == if_inst) { + op = Def{std::next(new_parent->arg_begin(), op.GetIdx() / 2)->get(), + 0}; + } + } + auto new_inst = new_builder.Clone(*inst, operands); + new_inst->GetResultsTypes() = inst->GetResultsTypes(); + HLCHECK(new_inst->GetOpCode() != OpCode::RETURN); + for (int i = 0, e = inst->GetNumOfResults(); i < e; ++i) { + inst->ReplaceAllUsesWith(i, Def{new_inst, i}); + } + changed |= true; + } /*else { + std::vector operands = inst->GetOperands(); + BasicBlock* new_parent = nullptr; + for (auto& op : operands) { + Instruction* op_inst = DynCast(op); + BasicBlock* op_bb = op_inst == nullptr ? nullptr : op_inst->GetParent(); + if (op_bb != nullptr && op_bb != bb) { + if (new_parent != nullptr && new_parent != op_bb) { + std::cerr << "unexpected parent\n"; + } + new_parent = op_bb; + } + } + if (new_parent != nullptr) { + IRBuilder new_builder(new_parent); + new_builder.SetInsertBefore(new_parent->GetReturnInst()); + auto new_inst = new_builder.Clone(*inst, operands); + new_inst->GetResultsTypes() = inst->GetResultsTypes(); + + HLCHECK(new_inst->GetOpCode() != OpCode::RETURN); + // inst->Dump(); + for (int i = 0, e = inst->GetNumOfResults(); i < e; ++i) { + inst->ReplaceAllUsesWith(i, Def{new_inst, i}); + } + changed |= true; + } + } + }*/ + } + + // Merge multiple tf_merge that associates with same if. + // All the inputs of tf_merge should associate with the same if. + std::unordered_map> if2merge; + + for (const auto& it : *bb) { + TFExtensionInst* inst = DynCast(it.get()); + if (inst == nullptr || inst->GetExtOpCode() != TFExtOpCode::MERGE) { + continue; + } + IfInst* if_inst = nullptr; + // int idx = 0; + for (auto& op : inst->GetOperands()) { + Instruction* op_inst = DynCast(op); + if (op_inst->GetOpCode() == OpCode::IF) { + // some branch is empty. nested if? 
+ HLCHECK(if_inst == nullptr || if_inst == op_inst); + if_inst = DynCast(op_inst); + // idx = op.GetIdx(); + } else { + BasicBlock* bb = op_inst->GetParent(); + auto it = branch_bbs.find(bb); + HLCHECK(it != branch_bbs.end()); + HLCHECK(if_inst == nullptr || if_inst == it->second); + if_inst = it->second; + // Make it be the output of if. + } + } + HLCHECK(if_inst != nullptr); + if2merge[if_inst].push_back(inst); + } + for (auto& if_merge : if2merge) { + std::vector true_ops; + std::vector false_ops; + IfInst* if_inst = if_merge.first; + std::set op_indices; + for (Instruction* merge : if_merge.second) { + for (auto& op : merge->GetOperands()) { + bool is_taken = false; + if (op.GetOwner() == if_inst) { + is_taken = (op.GetIdx() & 1) == 1; + } else { + const Instruction* inst = DynCast(op.GetOwner()); + HLCHECK(inst != nullptr); + const auto& bb = inst->GetParent(); + auto it = branch_bbs.find(bb); + HLCHECK(branch_bbs.end() != it); + HLCHECK(bb == it->second->GetThenBranch() || + bb == it->second->GetElseBranch()); + is_taken = bb == it->second->GetThenBranch(); + } + if (is_taken) { + true_ops.push_back(op); + } else { + false_ops.push_back(op); + } + } + } + HLCHECK(true_ops.size() == false_ops.size()); + + RewriteOutput(if_inst, true_ops, true); + RewriteOutput(if_inst, false_ops, false); + std::vector rets; + rets.reserve(true_ops.size()); + for (int i = 0, e = true_ops.size(); i < e; ++i) { + const auto& true_ty = true_ops[i].GetType(); + const auto& false_ty = false_ops[i].GetType(); + // The output type is dynamic. Here we just pick a valid one. + rets.push_back(true_ty.IsValid() ? true_ty : false_ty); + } + if_inst->GetResultsTypes() = rets; + for (int i = 0, e = if_merge.second.size(); i < e; ++i) { + if_merge.second[i]->ReplaceAllUsesWith({Def{if_inst, i}}); + } + } + // Modify TF_Merge and associated "if": + // Before: + // if_results(true_val, false_val) = if (...) + // out = merge(if_results) + // After: + // if_result(val) = if(...) 
+ // out = val + + return changed; +} // namespace halo + +bool ConvertTFCFG::RunOnFunction(Function* func) { + bool changed = false; + if (converted_) { + return false; + } + for (auto it = func->begin(), e = func->end(); it != e;) { + BasicBlock* bb = it->get(); + if (bb->Instructions().empty()) { + it = func->BasicBlocks().erase(it); + continue; + } + changed |= RunOnBasicBlock(bb); + it = std::next(it); + } + converted_ = true; + return changed; +} + +} // end namespace halo From 9eb878d3ed5701feb678574897f9e6c36b29354a Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Mon, 13 Dec 2021 20:51:44 +0000 Subject: [PATCH 06/12] fix bug --- lib/transforms/type_legalizer.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/transforms/type_legalizer.cc b/lib/transforms/type_legalizer.cc index c1d23ba8f..e73f3916d 100644 --- a/lib/transforms/type_legalizer.cc +++ b/lib/transforms/type_legalizer.cc @@ -1600,7 +1600,9 @@ bool TypeLegalizer::RunOnBasicBlock(BasicBlock* bb) { #undef GET_INST_DOWNCAST_SWITCH case OpCode::EXTENSION: { TFExtensionInst* ext = DynCast(inst); - RunOnInstruction(ext); + if (ext != nullptr) { + RunOnInstruction(ext); + } break; } default: { From c8ca9428f05d99583af82a0499b7e7a4674f5169 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Tue, 14 Dec 2021 06:37:09 +0000 Subject: [PATCH 07/12] add Stack Op --- .../target/generic_cxx/generic_cxx_codegen.h | 3 ++- lib/target/generic_cpp/concat.cc | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h b/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h index 54864844c..ed74f8862 100644 --- a/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h +++ b/include/halo/lib/target/generic_cxx/generic_cxx_codegen.h @@ -180,12 +180,13 @@ class GenericCXXCodeGen : public CodeGen { virtual void RunOnInstruction(ShapeInst*) override; virtual void RunOnInstruction(ShiftInst*) override; virtual void RunOnInstruction(ShrinkInst*) override; + virtual void RunOnInstruction(SigmoidInst*) override; virtual void RunOnInstruction(SItoFPInst*) override; virtual void RunOnInstruction(SliceInst*) override; virtual void RunOnInstruction(SoftmaxInst*) override; virtual void RunOnInstruction(SoftplusInst*) override; virtual void RunOnInstruction(SoftsignInst*) override; - virtual void RunOnInstruction(SigmoidInst*) override; + virtual void RunOnInstruction(StackInst*) override; virtual void RunOnInstruction(HardSigmoidInst*) override; virtual void RunOnInstruction(SinInst*) override; virtual void RunOnInstruction(SinhInst*) override; diff --git a/lib/target/generic_cpp/concat.cc b/lib/target/generic_cpp/concat.cc index e50b242de..bb095231f 100644 --- a/lib/target/generic_cpp/concat.cc +++ b/lib/target/generic_cpp/concat.cc @@ -38,4 +38,22 @@ void GenericCXXCodeGen::RunOnInstruction(ConcatInst* inst) { EmitODLACall(ret, "odla_Concat", inputs, axis, EmitShape(ret_shape)); } +void GenericCXXCodeGen::RunOnInstruction(StackInst* inst) { + const auto& axis = inst->GetAxis(); + + CXXValue op0 = ir_mapping_[inst->GetOperand(0)]; + CXXValue ret(inst->GetName(), op0.type); + + ir_mapping_[*inst] = ret; + const halo::Type& ret_shape = inst->GetResultType(); + const auto num = inst->GetNumOfOperands(); + std::vector inputs; + for (size_t i = 0; i < num; ++i) { + const Def& op = inst->GetOperand(i); + CXXValue op_v = ir_mapping_[op]; + inputs.push_back(op_v); + } + EmitODLACall(ret, "odla_Stack", inputs, axis, EmitShape(ret_shape)); +} + } // 
namespace halo From b599da04f1d9debbc48a0446929152655f6dedf6 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Tue, 14 Dec 2021 06:40:08 +0000 Subject: [PATCH 08/12] Constant Folding for Slice: handle special case If the input is not constant but partially constant, we can still convert it to Constant if the data to be sliced happen to be constants. --- lib/transforms/inst_simplify.cc | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/lib/transforms/inst_simplify.cc b/lib/transforms/inst_simplify.cc index 62ddf0a79..658688e0d 100644 --- a/lib/transforms/inst_simplify.cc +++ b/lib/transforms/inst_simplify.cc @@ -2429,6 +2429,41 @@ std::pair InstSimplify::RunOnInstruction(SliceInst* inst) { return {orig_def, *new_inst}; } } + // Some special cases. + // If input is a partial constant (sliced data is constant), we can still + // convert the result to contant. + const Constant* c_start = DynCast(inst->GetOperand(1)); + const Constant* c_len = DynCast(inst->GetOperand(2)); + + if (const auto shape_inst = DynCast(inst->GetOperand(0)); + shape_inst != nullptr && shape_inst->GetResultType().IsValid() && + c_start != nullptr && c_len != nullptr && + inst->GetResultType().IsValid()) { + const auto& ret_type = inst->GetResultType(); + HLCHECK(ret_type.GetNumOfDims() <= 1); + // rank must be 1 or 0 (scalar) and axis must be 0. + const auto& value_type = shape_inst->GetOperand(0).GetType(); + HLCHECK(value_type.IsValid()); + const auto& shape = value_type.GetDimSizes(); + auto from = c_start->GetDataAsInt64(0); + auto len = c_len->GetDataAsInt64(0); + bool is_constant = true; + std::vector data(shape.begin() + from, shape.begin() + from + len); + HLCHECK(static_cast(data.size()) == + ret_type.GetTotalNumOfElements()); + for (auto x : data) { + if (x < 0) { + is_constant = false; + break; + } + } + if (is_constant) { + ConstantBuilder cb(inst->GetParent()->GetParent()); + auto c = cb.CreateConstant(inst->GetName(), inst->GetResultType(), + data.data()); + return {orig_def, *c}; + } + } return {orig_def, orig_def}; } From 6aad491531ddcf405d4b79ca34ecc297a9d65af4 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Tue, 14 Dec 2021 06:43:30 +0000 Subject: [PATCH 09/12] TypeLegalization: support partial shape inference for Reshape and Slice --- lib/transforms/type_legalizer.cc | 71 +++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/lib/transforms/type_legalizer.cc b/lib/transforms/type_legalizer.cc index e73f3916d..cc6339913 100644 --- a/lib/transforms/type_legalizer.cc +++ b/lib/transforms/type_legalizer.cc @@ -264,9 +264,49 @@ static void RunOnInstruction(CompressInst* inst) { inst->GetResultsTypes()[0] = Type{dt, ret_dims}; } +// TODO: move to util +static int GetConstantValue(const StackInst& inst, int idx) { + int elems = 0; + for (auto& op : inst.GetOperands()) { + const auto& ty = op.GetType(); + if (!ty.IsValid()) { + return -1; + } + if (idx >= elems && idx < elems + ty.GetTotalNumOfElements()) { + auto c = DynCast(op); + if (c == nullptr) { + return -1; + } + return c->GetDataAsInt64(idx - elems); + } + elems += ty.GetTotalNumOfElements(); + } + return -1; +} + static void RunOnInstruction(ReshapeInst* inst) { auto& op0_type = inst->GetOperand(0).GetType(); Def op1 = inst->GetOperand(1); + if (!IsA(op1) && op1.GetType().IsValid() && + (op0_type.IsDynamicBatch() || op0_type.IsDynamicShape())) { + int rank = op1.GetType().GetTotalNumOfElements(); + std::vector new_shape(rank); + unsigned unknown_dims = 0; + if (IsA(op1)) { + 
const auto& stack = *DynCast(op1); + for (int i = 0; i < rank; ++i) { + new_shape[i] = GetConstantValue(stack, i); + } + /* + unknown_dims = std::count_if(new_shape.begin(), new_shape.end(), + [](int64_t x) { return x < 0; });*/ + if (unknown_dims <= 1) { // FIXME: remove + inst->GetResultsTypes()[0] = + halo::Type{op0_type.GetDataType(), new_shape}; + return; + } + } + } if (!IsA(op1)) { return; } @@ -928,11 +968,40 @@ static void RunOnInstruction(SliceInst* inst) { auto op_len = inst->GetOperand(2); auto& input_type = op0.GetType(); - if (!input_type.IsValid() || !IsA(op_len)) { + if (!input_type.IsValid()) { return; } auto dims = input_type.GetNumOfDims(); + if ((input_type.IsDynamicShape() || input_type.IsDynamicBatch()) && + !IsA(op_len)) { + auto ret_shape = input_type.GetDimSizes(); + bool is_constant_len = true; + if (IsA(op_len) && op_len.GetType().IsValid()) { + for (unsigned i = 0; i < dims; ++i) { + if (ret_shape[i] == kDynamicBatchSize || + ret_shape[i] == kDynamicShapeSize) { + continue; + } + int64_t v = GetConstantValue(*DynCast(op_len), (int)i); + if (v < 0) { + is_constant_len = false; + break; + } + ret_shape[i] = v; + } + if (is_constant_len) { + inst->GetResultsTypes()[0] = + halo::Type{input_type.GetDataType(), ret_shape}; + } + } + return; + } + + if (!IsA(op_len)) { + return; + } + std::unordered_set axes; if (inst->GetNumOfOperands() > 4) { auto op_axes = inst->GetOperand(4); From 05f052c87323b36d16ed9da65456e5a78cc91e27 Mon Sep 17 00:00:00 2001 From: Weiming Zhao Date: Tue, 14 Dec 2021 06:44:36 +0000 Subject: [PATCH 10/12] Add new ODLA for slice; Temporary --- lib/target/generic_cpp/slice.cc | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/lib/target/generic_cpp/slice.cc b/lib/target/generic_cpp/slice.cc index 9b8a89b9d..bfaaf1df4 100644 --- a/lib/target/generic_cpp/slice.cc +++ b/lib/target/generic_cpp/slice.cc @@ -50,7 +50,24 @@ static void NormalizerOperands(const Constant& operand, } // end namespace void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { - const Def input = inst->GetOperand(0); + const Def& input = inst->GetOperand(0); + const Def& start = inst->GetOperand(1); + const Def& size = inst->GetOperand(2); + // auto strides = inst->GetOperand(3); //TODO + + CXXValue op0 = ir_mapping_[input]; + CXXValue ret(inst->GetName(), op0.type); + ir_mapping_[*inst] = ret; + + if (!IsA(start) || !IsA(size)) { + auto op1 = ir_mapping_[start]; + auto op2 = ir_mapping_[size]; + // auto op3 = ir_mapping_[strides]; // FIXME + EmitODLACall(ret, "odla_SliceDynamic", op0, op1, op2, /*op3,*/ + EmitShape(inst->GetResultType())); + + return; + } size_t dims = input.GetType().GetNumOfDims(); std::unordered_set axes; if (inst->GetNumOfOperands() > 4) { @@ -75,7 +92,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { } std::vector start_v(dims, 0); - const Def& start = inst->GetOperand(1); HLCHECK(start.GetType().GetTotalNumOfElements() == static_cast(axes.size())); HLCHECK(IsA(start)); @@ -88,7 +104,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { std::vector size_v(input.GetType().GetDimSizes().begin(), input.GetType().GetDimSizes().end()); - const Def& size = inst->GetOperand(2); HLCHECK(size.GetType().GetTotalNumOfElements() == static_cast(axes.size())); HLCHECK(IsA(size)); @@ -125,12 +140,8 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { size_v.begin(), std::plus()); } - CXXValue op0 = ir_mapping_[input]; - CXXValue ret(inst->GetName(), op0.type); - EmitODLACall(ret, 
"odla_Slice", op0, start_v, size_v, strides_v, EmitShape(inst->GetResultType())); - ir_mapping_[*inst] = ret; } } // namespace halo From 94e363554dc6d3dec4edb9f65a3b6612d9801189 Mon Sep 17 00:00:00 2001 From: "maruiyan.mry" Date: Thu, 16 Dec 2021 01:42:12 -0800 Subject: [PATCH 11/12] add SliceDynamic & Stack odla_api --- ODLA/include/ODLA/ops/odla_ops_process.h | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/ODLA/include/ODLA/ops/odla_ops_process.h b/ODLA/include/ODLA/ops/odla_ops_process.h index 1f5b9c618..3a3b11105 100644 --- a/ODLA/include/ODLA/ops/odla_ops_process.h +++ b/ODLA/include/ODLA/ops/odla_ops_process.h @@ -286,6 +286,22 @@ odla_Slice(odla_value input, const odla_uint32* start, const odla_uint32* end, const odla_uint32* stride, odla_value_shape output_dims, const odla_value_id value_id); +//! \brief Extract a dynamic slice +/*! + SliceDynamic extracts a dynamic slice from \p input. + + \param input the input value + \param start the offets at each slicing dimension + \param size the number of elements at each slicing dimension + \param output_dims the optional output shape (can be undefined) + \param value_id a unique value id (can be NULL) + + \return odla_value +*/ +extern ODLA_API_EXPORT odla_value ODLA_API_CALL +odla_SliceDynamic(odla_value input, odla_value start, odla_value size, + odla_value_shape output_dims, const odla_value_id value_id); + //! \brief Remove dimensions of size 1 /*! Squeeze removes dimensions of size 1 from the shape of \p input. @@ -303,6 +319,22 @@ extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Squeeze(odla_value input, odla_size_t num_of_axes, const odla_uint32* axes, odla_value_shape output_dims, const odla_value_id value_id); +//! \brief Join a sequence of Values along a new axis. +/*! + Stack joins multiple values into single one along a new axis. All inputs + must have the same dimension. + + \param inputs the input values + \param axis the index of the new axis in the dimensions of the result + \param output_shape the result shape + \param value_id a unique value id (can be NULL) + + \return odla_value +*/ +extern ODLA_API_EXPORT odla_value ODLA_API_CALL +odla_Stack(odla_values inputs, odla_int32 axis, odla_value_shape output_shape, + const odla_value_id value_id); + //! \brief Transpose the input /*! Transpose returns a transposed value based on the \p permutation. 
From a6f8250f6fd52e747458f720e766b269d5031d32 Mon Sep 17 00:00:00 2001 From: "maruiyan.mry" Date: Thu, 16 Dec 2021 05:23:22 -0800 Subject: [PATCH 12/12] fix:slice_op parameters' type-uint2int --- ODLA/include/ODLA/ops/odla_ops_process.h | 4 ++-- ODLA/platforms/dnnl/odla_dnnl.cc | 8 ++++---- ODLA/platforms/odla_popart/odla_ops.cc | 4 ++-- ODLA/platforms/odla_profiler.cc | 4 ++-- ODLA/platforms/tensorrt/odla_tensorrt.cc | 4 ++-- lib/target/generic_cpp/slice.cc | 21 ++++++++++----------- 6 files changed, 22 insertions(+), 23 deletions(-) diff --git a/ODLA/include/ODLA/ops/odla_ops_process.h b/ODLA/include/ODLA/ops/odla_ops_process.h index 3a3b11105..c89ae676e 100644 --- a/ODLA/include/ODLA/ops/odla_ops_process.h +++ b/ODLA/include/ODLA/ops/odla_ops_process.h @@ -282,8 +282,8 @@ odla_Shape(odla_value input, odla_value_shape output_dims, \return odla_value */ extern ODLA_API_EXPORT odla_value ODLA_API_CALL -odla_Slice(odla_value input, const odla_uint32* start, const odla_uint32* end, - const odla_uint32* stride, odla_value_shape output_dims, +odla_Slice(odla_value input, const odla_int32* start, const odla_int32* end, + const odla_int32* stride, odla_value_shape output_dims, const odla_value_id value_id); //! \brief Extract a dynamic slice diff --git a/ODLA/platforms/dnnl/odla_dnnl.cc b/ODLA/platforms/dnnl/odla_dnnl.cc index 9d5a3f029..9f5b2337f 100644 --- a/ODLA/platforms/dnnl/odla_dnnl.cc +++ b/ODLA/platforms/dnnl/odla_dnnl.cc @@ -1834,8 +1834,8 @@ odla_value odla_Erf(odla_value input, const odla_value_id value_id) { static void strided_slice(const void* src, int elem_size, const odla_value_shape& input_dims, - const odla_uint32* start, const odla_uint32* end, - const odla_uint32* strides, void* dst, + const odla_int32* start, const odla_int32* end, + const odla_int32* strides, void* dst, const odla_value_shape& output_dims) { int64_t dst_elems = GetTotalElements(output_dims); int dims = input_dims.size; @@ -1871,8 +1871,8 @@ static void strided_slice(const void* src, int elem_size, } } -odla_value odla_Slice(odla_value input, const odla_uint32* start, - const odla_uint32* end, const odla_uint32* strides, +odla_value odla_Slice(odla_value input, const odla_int32* start, + const odla_int32* end, const odla_int32* strides, odla_value_shape output_dims, const odla_value_id id) { const auto& input_dims = input->shape; int dims = input_dims.size; diff --git a/ODLA/platforms/odla_popart/odla_ops.cc b/ODLA/platforms/odla_popart/odla_ops.cc index cf26acce1..458a92174 100644 --- a/ODLA/platforms/odla_popart/odla_ops.cc +++ b/ODLA/platforms/odla_popart/odla_ops.cc @@ -359,8 +359,8 @@ odla_value odla_ReduceMean(odla_value input, odla_size_t num_of_axes, return result; } -odla_value odla_Slice(odla_value input, const odla_uint32* start, - const odla_uint32* end, const odla_uint32* stride, +odla_value odla_Slice(odla_value input, const odla_int32* start, + const odla_int32* end, const odla_int32* stride, odla_value_shape output_dims, const odla_value_id id) { const auto& name = id ? 
std::string(reinterpret_cast(id)) : ""; diff --git a/ODLA/platforms/odla_profiler.cc b/ODLA/platforms/odla_profiler.cc index 39b3e789a..53d1ba8ba 100644 --- a/ODLA/platforms/odla_profiler.cc +++ b/ODLA/platforms/odla_profiler.cc @@ -522,8 +522,8 @@ odla_value odla_Rsqrt(odla_value input, const odla_value_id id) { } static constexpr const char fn_slice[] = "odla_Slice"; -odla_value odla_Slice(odla_value input, const odla_uint32* start, - const odla_uint32* end, const odla_uint32* strides, +odla_value odla_Slice(odla_value input, const odla_int32* start, + const odla_int32* end, const odla_int32* strides, odla_value_shape output_dims, const odla_value_id id) { return profile(input, start, end, strides, output_dims, id); } diff --git a/ODLA/platforms/tensorrt/odla_tensorrt.cc b/ODLA/platforms/tensorrt/odla_tensorrt.cc index 71d3c6edd..cadb58877 100644 --- a/ODLA/platforms/tensorrt/odla_tensorrt.cc +++ b/ODLA/platforms/tensorrt/odla_tensorrt.cc @@ -1838,8 +1838,8 @@ odla_value odla_Gather(odla_value input, const odla_value indices, return CreateValue(gather, {input->type.element_type, output_dims}, id); } -odla_value odla_Slice(odla_value input, const odla_uint32* start, - const odla_uint32* end, const odla_uint32* stride, +odla_value odla_Slice(odla_value input, const odla_int32* start, + const odla_int32* end, const odla_int32* stride, odla_value_shape output_dims, const odla_value_id id) { odla_value_shape start_dims, stride_dims; const auto& dims = input->type.shape; diff --git a/lib/target/generic_cpp/slice.cc b/lib/target/generic_cpp/slice.cc index bfaaf1df4..e371bce71 100644 --- a/lib/target/generic_cpp/slice.cc +++ b/lib/target/generic_cpp/slice.cc @@ -33,15 +33,14 @@ namespace { template static void NormalizerOperands(const Constant& operand, const std::unordered_set& axes, - const size_t dims, - std::vector* value) { + const size_t dims, std::vector* value) { bool onnx_mode = axes.size() != dims; for (size_t i = 0, j = 0; i < dims; ++i) { if (axes.count(i) != 0) { - (*value)[i] = static_cast(operand.GetData(j++)); + (*value)[i] = static_cast(operand.GetData(j++)); } else { if (!onnx_mode) { - (*value)[i] = static_cast(operand.GetData(i)); + (*value)[i] = static_cast(operand.GetData(i)); } } } @@ -91,7 +90,7 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { } } - std::vector start_v(dims, 0); + std::vector start_v(dims, 0); HLCHECK(start.GetType().GetTotalNumOfElements() == static_cast(axes.size())); HLCHECK(IsA(start)); @@ -102,8 +101,8 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { NormalizerOperands(*start_c, axes, dims, &start_v); } - std::vector size_v(input.GetType().GetDimSizes().begin(), - input.GetType().GetDimSizes().end()); + std::vector size_v(input.GetType().GetDimSizes().begin(), + input.GetType().GetDimSizes().end()); HLCHECK(size.GetType().GetTotalNumOfElements() == static_cast(axes.size())); HLCHECK(IsA(size)); @@ -114,7 +113,7 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { NormalizerOperands(*size_c, axes, dims, &size_v); } - std::vector strides_v(dims, 1); + std::vector strides_v(dims, 1); if (inst->GetNumOfOperands() > 3) { const Def& strides = inst->GetOperand(3); HLCHECK(IsA(strides)); @@ -129,15 +128,15 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) { // stride is provided, calculate ends = starts + sizes * strides std::for_each(strides_v.begin(), strides_v.end(), - [=](uint32_t& s) { s = s >= 0 ? s : dims + s; }); + [=](int32_t& s) { s = s >= 0 ? 
s : dims + s; }); std::transform(strides_v.begin(), strides_v.end(), size_v.begin(), size_v.begin(), std::multiplies()); std::transform(start_v.begin(), start_v.end(), size_v.begin(), - size_v.begin(), std::plus()); + size_v.begin(), std::plus()); } else { // stride is omitted, set to [1,1,...,1], calculate ends = starts + sizes std::transform(size_v.begin(), size_v.end(), start_v.begin(), - size_v.begin(), std::plus()); + size_v.begin(), std::plus()); } EmitODLACall(ret, "odla_Slice", op0, start_v, size_v, strides_v,