[WIP] dynamic shape & cfg #736

Draft: wants to merge 10 commits into base: master

2 changes: 1 addition & 1 deletion .github/actions/build/build_in_docker.sh
@@ -1,7 +1,7 @@
#!/bin/bash -xe

REPO="registry-intl.us-west-1.aliyuncs.com/computation/halo"
VER="0.7.6"
VER="0.7.7"
FLAVOR="devel"

MOUNT_DIR="$PWD"
2 changes: 1 addition & 1 deletion ODLA/include/ODLA/ops/odla_ops_process.h
@@ -259,8 +259,8 @@ odla_Resize(odla_value input, odla_interpolation_mode interpolation,
the result value is implementation determined.

\param input the input value
\param value_id a unique value id (can be NULL)
\param output_dims the optional output shape (can be undefined)
\param value_id a unique value id (can be NULL)

\return odla_value
*/
3 changes: 3 additions & 0 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
@@ -456,6 +456,9 @@ static odla_value CreateValue(T* t, const odla_value_type& type,
auto v = std::make_unique<_odla_value>(t, type, name);
auto ret = v.get();
g_comp->vals.push_back(std::move(v));
if (!g_comp->branchs.empty()) {
g_comp->branchs.top().branch->addInput(*ret);
}
return ret;
}

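A note on the odla_tensorrt.cc change above: any value created while a conditional branch is still being built is now also registered as an input of that branch. A self-contained sketch of the bookkeeping pattern this implies; the branchs/branch names mirror the diff, everything else here is illustrative and not the PR's code:

#include <stack>

// Illustrative stand-ins; the real backend tracks its own branch objects.
struct Branch {
  void addInput(int value_id) { /* record value_id as an input of this branch */ }
};
struct BranchScope {
  Branch* branch;
};

static std::stack<BranchScope> branchs;  // pushed when a conditional opens, popped when it closes

static void OnValueCreated(int value_id) {
  // Mirrors the new logic in CreateValue(): values materialized while a
  // conditional is open become explicit inputs of the enclosing branch.
  if (!branchs.empty()) {
    branchs.top().branch->addInput(value_id);
  }
}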
6 changes: 6 additions & 0 deletions include/halo/lib/ir/common_instructions.td
@@ -78,6 +78,12 @@ let cat_ = cat_common in {
let outs_ = [Arg<"The result.", MatchArgType<0> >];
}

def Shape : Inst<"Compute the shape of input."> {
let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >];
let attrs_ = [Attr<"The output data type", EnumDataType, "data_type", "INT64">];
let outs_ = [Arg<"The result.", ArgType<[I64, I32]> >];
}

def Reshape : Inst<"Reshape the input X1 to create the result with the same"
" number of elements and the shape specified by X2."> {
let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >,
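The Shape instruction added above materializes the dimensions of its input as a 1-D integer tensor; the data_type attribute picks the element width (INT64 by default, INT32 when a frontend mapping overrides it). A minimal illustration of that semantics with made-up dimensions; this is not code from the PR:

#include <cstdint>
#include <vector>

// For an input of type float32[2, 3, 4], Shape yields the 1-D tensor {2, 3, 4}.
static std::vector<int64_t> ShapeOf(const std::vector<int64_t>& input_dims) {
  return input_dims;  // only the dimension list is copied; no tensor data is read
}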
5 changes: 4 additions & 1 deletion include/halo/lib/ir/tf_convert.td
@@ -44,7 +44,10 @@ def TF_Reshape : TFExtension<"Reshape"> {
let extension_attr_ = [ ExtensionAttr<"shape", IntegerList, "{}"> ];
}

def TF_Shape : TFExtension<"Shape">;
def TF_Shape : OpMapping<"Shape", Shape> {
let attr_mapping_ = [
AttributeMapping<"", "data_type", "INT32">];
}

def TF_SquaredDifference : TFExtension<"SquaredDifference">;

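For context on the INT32 choice in the TF_Shape mapping above: TensorFlow's Shape op returns int32 by default, while the Halo-side attribute defaults to INT64, so the mapping narrows the result accordingly. A tiny hypothetical helper showing that narrowing; not part of the PR:

#include <cstdint>
#include <vector>

static std::vector<int32_t> NarrowDimsToInt32(const std::vector<int64_t>& dims) {
  return std::vector<int32_t>(dims.begin(), dims.end());  // e.g. {1, 224, 224, 3}
}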
1 change: 1 addition & 0 deletions include/halo/lib/pass/pass_manager.h
@@ -67,6 +67,7 @@ class HL_API_EXPORT PassManager final {
Pass* AddCodeFormatterPass(std::ostringstream& code,
std::ostringstream& header,
const CXXCodeGenOpts& opts);
Pass* AddConvertTFCFGPass();
Pass* AddDCEPass();
Pass* AddDevicePlacementPass();
Pass* AddFusionPass(const FusionOptions& opts);
4 changes: 3 additions & 1 deletion include/halo/lib/target/generic_cxx/generic_cxx_codegen.h
@@ -177,14 +177,16 @@ class GenericCXXCodeGen : public CodeGen {
virtual void RunOnInstruction(ReturnInst*) override;
virtual void RunOnInstruction(RNNInst*) override;
virtual void RunOnInstruction(SelectInst*) override;
virtual void RunOnInstruction(ShapeInst*) override;
virtual void RunOnInstruction(ShiftInst*) override;
virtual void RunOnInstruction(ShrinkInst*) override;
virtual void RunOnInstruction(SigmoidInst*) override;
virtual void RunOnInstruction(SItoFPInst*) override;
virtual void RunOnInstruction(SliceInst*) override;
virtual void RunOnInstruction(SoftmaxInst*) override;
virtual void RunOnInstruction(SoftplusInst*) override;
virtual void RunOnInstruction(SoftsignInst*) override;
virtual void RunOnInstruction(SigmoidInst*) override;
virtual void RunOnInstruction(StackInst*) override;
virtual void RunOnInstruction(HardSigmoidInst*) override;
virtual void RunOnInstruction(SinInst*) override;
virtual void RunOnInstruction(SinhInst*) override;
37 changes: 37 additions & 0 deletions include/halo/lib/transforms/convert_tf_cfg.h
@@ -0,0 +1,37 @@
//===- convert_tf_cfg.h ---------------------------------------------------===//
//
// Copyright (C) 2019-2020 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_
#define HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_

#include "halo/lib/pass/pass.h"

namespace halo {

/// This pass converts TensorFlow control-flow (CFG) constructs in the IR.
class ConvertTFCFG final : public FunctionPass {
public:
ConvertTFCFG() : FunctionPass("Convert TF CFG"), converted_(false) {}
bool RunOnFunction(Function* func) override;

private:
bool converted_;
};

} // end namespace halo.

#endif // HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_
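The pass body itself (lib/transforms/convert_tf_cfg.cc) is not included in the diff shown here. Purely as orientation, a minimal skeleton under the usual Halo FunctionPass contract (return true only when the IR changed); the converted_ flag above suggests the conversion is attempted once:

#include "halo/lib/transforms/convert_tf_cfg.h"

namespace halo {

// Hypothetical skeleton only; the PR's real implementation is not visible here.
bool ConvertTFCFG::RunOnFunction(Function* func) {
  if (converted_) {
    return false;  // nothing left to do on later runs
  }
  bool changed = false;
  // ... rewrite TensorFlow control-flow constructs found in `func` here ...
  converted_ = true;
  return changed;
}

} // end namespace halo.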
1 change: 1 addition & 0 deletions include/halo/lib/transforms/inst_simplify.h
@@ -65,6 +65,7 @@ class InstSimplify final : public BasicBlockPass {
static std::pair<Def, Def> RunOnInstruction(ResizeInst* inst);
static std::pair<Def, Def> RunOnInstruction(SelectInst* inst);
static std::pair<Def, Def> RunOnInstruction(SetDiff1DInst* inst);
static std::pair<Def, Def> RunOnInstruction(ShapeInst* inst);
static std::pair<Def, Def> RunOnInstruction(SigmoidInst* inst);
static std::pair<Def, Def> RunOnInstruction(SItoFPInst* inst);
static std::pair<Def, Def> RunOnInstruction(FPtoSIInst* inst);
3 changes: 3 additions & 0 deletions include/halo/utils/passes_helper.h
@@ -162,6 +162,9 @@ static void PopulateOptPasses(PassManager* pm, const std::string& target,
if (opts.enable_type_cast) {
pm->AddTypeCastPass();
}
if (format == ModelFormat::TENSORFLOW) {
pm->AddConvertTFCFGPass();
}
if (opts.constant_decombine) {
pm->AddConstantDecombinePass();
}
3 changes: 3 additions & 0 deletions lib/pass/pass_manager.cc
@@ -31,6 +31,7 @@
#include "halo/lib/transforms/analyzer.h"
#include "halo/lib/transforms/caffeextension_legalizer.h"
#include "halo/lib/transforms/constant_decombine.h"
#include "halo/lib/transforms/convert_tf_cfg.h"
#include "halo/lib/transforms/dce.h"
#include "halo/lib/transforms/device_placement.h"
#include "halo/lib/transforms/fusion.h"
@@ -265,6 +266,8 @@ Pass* PassManager::AddCodeFormatterPass(std::ostringstream& buf_code,
return AddPass<CodeFormatter>(buf_code, buf_header, opts);
}

Pass* PassManager::AddConvertTFCFGPass() { return AddPass<ConvertTFCFG>(); }

Pass* PassManager::AddDCEPass() { return AddPass<DCE>(); }

Pass* PassManager::AddDevicePlacementPass() {
18 changes: 18 additions & 0 deletions lib/target/generic_cpp/concat.cc
@@ -38,4 +38,22 @@ void GenericCXXCodeGen::RunOnInstruction(ConcatInst* inst) {
EmitODLACall(ret, "odla_Concat", inputs, axis, EmitShape(ret_shape));
}

void GenericCXXCodeGen::RunOnInstruction(StackInst* inst) {
const auto& axis = inst->GetAxis();

CXXValue op0 = ir_mapping_[inst->GetOperand(0)];
CXXValue ret(inst->GetName(), op0.type);

ir_mapping_[*inst] = ret;
const halo::Type& ret_shape = inst->GetResultType();
const auto num = inst->GetNumOfOperands();
std::vector<CXXValue> inputs;
for (size_t i = 0; i < num; ++i) {
const Def& op = inst->GetOperand(i);
CXXValue op_v = ir_mapping_[op];
inputs.push_back(op_v);
}
EmitODLACall(ret, "odla_Stack", inputs, axis, EmitShape(ret_shape));
}

} // namespace halo
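The new StackInst handler above gathers all operands and emits a single odla_Stack call with the precomputed result shape. The shape rule it relies on: stacking N operands of identical shape inserts a new dimension of size N at the chosen axis. A small hypothetical helper spelling that out (not code from the PR):

#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<int64_t> StackResultShape(std::vector<int64_t> operand_shape,
                                             size_t num_operands, size_t axis) {
  operand_shape.insert(operand_shape.begin() + axis,
                       static_cast<int64_t>(num_operands));
  return operand_shape;  // e.g. three [4] operands stacked at axis 0 give [3, 4]
}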
7 changes: 7 additions & 0 deletions lib/target/generic_cpp/generic_cxx_codegen.cc
@@ -125,6 +125,7 @@ static std::string GetBF16Mode(BF16Mode mode) {
}

bool GenericCXXCodeGen::RunOnModule(Module* module) {
module->Dump();
memory_analyzer_ = std::make_unique<MemoryAnalyzer>(*module);
Function* entry_func = nullptr;
EmitBanner(&os_, &header_os_, GetAPI());
@@ -270,6 +271,9 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
auto nr_outputs = ret_inst.GetNumOfOperands();
model_info.num_outputs = nr_outputs;
for (const auto& out : ret_inst.GetOperands()) {
if (out.IsNull()) {
continue;
}
const auto& type = out.GetType();
if (ir_mapping_.find(out) == ir_mapping_.end()) {
CXXValue cv(out.GetDef()->GetName(), TensorTypeToCXXType(type, false));
@@ -840,6 +844,9 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
index = 0;
// Pre-launch binding.
for (auto& op : return_inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
auto& cv = ir_mapping_[op];
std::string arg_name = (opts_.emit_inference_func_sig || is_sub)
? (is_sub ? "outputs.values[" : "outputs[") +
12 changes: 12 additions & 0 deletions lib/target/generic_cpp/reshape.cc
@@ -32,4 +32,16 @@ void GenericCXXCodeGen::RunOnInstruction(ReshapeInst* inst) {
ir_mapping_[*inst] = ret;
}

void GenericCXXCodeGen::RunOnInstruction(ShapeInst* inst) {
const Def& input = inst->GetOperand(0);

CXXValue op0 = ir_mapping_[input];

const auto& ret_type = inst->GetResultType();
CXXValue ret(inst->GetName(), op0.type);
EmitODLACall(ret, "odla_Shape", op0, EmitShape(ret_type));

ir_mapping_[*inst] = ret;
}

} // namespace halo
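In the ShapeInst handler above, EmitShape(ret_type) supplies the result dimensions for the emitted odla_Shape call; for a Shape node that result is always rank 1 with length equal to the input's rank. A one-line illustration of that relationship (hypothetical helper, not PR code):

#include <cstdint>
#include <vector>

static std::vector<int64_t> ShapeResultDims(const std::vector<int64_t>& input_dims) {
  return {static_cast<int64_t>(input_dims.size())};  // float32[1, 3, 224, 224] gives a 4-element result
}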
3 changes: 3 additions & 0 deletions lib/target/generic_cpp/return.cc
@@ -25,6 +25,9 @@ namespace halo {
void GenericCXXCodeGen::RunOnInstruction(ReturnInst* inst) {
bool is_compile_mode = opts_.exec_mode == ExecMode::Compile;
for (auto& op : inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
const CXXValue& val = ir_mapping_[op];
if (is_compile_mode) {
bool is_entry_with_calls =
25 changes: 18 additions & 7 deletions lib/target/generic_cpp/slice.cc
@@ -50,7 +50,24 @@ static void NormalizerOperands(const Constant& operand,
} // end namespace

void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
const Def input = inst->GetOperand(0);
const Def& input = inst->GetOperand(0);
const Def& start = inst->GetOperand(1);
const Def& size = inst->GetOperand(2);
// auto strides = inst->GetOperand(3); //TODO

CXXValue op0 = ir_mapping_[input];
CXXValue ret(inst->GetName(), op0.type);
ir_mapping_[*inst] = ret;

if (!IsA<Constant>(start) || !IsA<Constant>(size)) {
auto op1 = ir_mapping_[start];
auto op2 = ir_mapping_[size];
// auto op3 = ir_mapping_[strides]; // FIXME
EmitODLACall(ret, "odla_SliceDynamic", op0, op1, op2, /*op3,*/
EmitShape(inst->GetResultType()));

return;
}
size_t dims = input.GetType().GetNumOfDims();
std::unordered_set<int32_t> axes;
if (inst->GetNumOfOperands() > 4) {
@@ -75,7 +92,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
}

std::vector<uint32_t> start_v(dims, 0);
const Def& start = inst->GetOperand(1);
HLCHECK(start.GetType().GetTotalNumOfElements() ==
static_cast<int64_t>(axes.size()));
HLCHECK(IsA<Constant>(start));
Expand All @@ -88,7 +104,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {

std::vector<uint32_t> size_v(input.GetType().GetDimSizes().begin(),
input.GetType().GetDimSizes().end());
const Def& size = inst->GetOperand(2);
HLCHECK(size.GetType().GetTotalNumOfElements() ==
static_cast<int64_t>(axes.size()));
HLCHECK(IsA<Constant>(size));
@@ -125,12 +140,8 @@
size_v.begin(), std::plus<uint32_t>());
}

CXXValue op0 = ir_mapping_[input];
CXXValue ret(inst->GetName(), op0.type);

EmitODLACall(ret, "odla_Slice", op0, start_v, size_v, strides_v,
EmitShape(inst->GetResultType()));
ir_mapping_[*inst] = ret;
}

} // namespace halo
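The key change in the SliceInst handler above: the start and size operands are folded into literal arrays for odla_Slice only when both are Constants; if either is produced at runtime (the dynamic-shape case this PR targets, e.g. offsets computed from a Shape), the handler now emits odla_SliceDynamic with those operands passed as values, with strides still marked as a TODO. A conceptual mirror of that dispatch, using hypothetical names:

#include <cstdint>
#include <vector>

struct SliceOperandInfo {
  bool is_constant;              // known at compile time?
  std::vector<uint32_t> folded;  // literal values, meaningful only when is_constant
};

// Mirrors the new check: a runtime (dynamic) slice is required as soon as
// either the start or the size operand is not a Constant.
static bool CanFoldSliceStatically(const SliceOperandInfo& start,
                                   const SliceOperandInfo& size) {
  return start.is_constant && size.is_constant;
}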
1 change: 1 addition & 0 deletions lib/transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ set(SRCS
analyzer.cc
caffeextension_legalizer.cc
constant_decombine.cc
convert_tf_cfg.cc
dce.cc
device_placement.cc
fusion.cc