Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] dynamic shape & cfg from pr736 #743

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/build/build_in_docker.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -xe

REPO="registry-intl.us-west-1.aliyuncs.com/computation/halo"
VER="0.7.6"
VER="0.7.7"
FLAVOR="devel"

MOUNT_DIR="$PWD"
Expand Down
38 changes: 35 additions & 3 deletions ODLA/include/ODLA/ops/odla_ops_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ odla_Resize(odla_value input, odla_interpolation_mode interpolation,
the result value is implementation determined.

\param input the input value
\param value_id a unique value id (can be NULL)
\param output_dims the optional output shape (can be undefined)
\param value_id a unique value id (can be NULL)

\return odla_value
*/
Expand All @@ -282,10 +282,26 @@ odla_Shape(odla_value input, odla_value_shape output_dims,
\return odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_Slice(odla_value input, const odla_uint32* start, const odla_uint32* end,
const odla_uint32* stride, odla_value_shape output_dims,
odla_Slice(odla_value input, const odla_int32* start, const odla_int32* end,
const odla_int32* stride, odla_value_shape output_dims,
const odla_value_id value_id);

//! \brief Extract a dynamic slice
/*!
SliceDynamic extracts a dynamic slice from \p input. Unlike odla_Slice,
the start offsets and sizes are runtime values rather than compile-time
constants.

\param input the input value
\param start the offsets at each slicing dimension
\param size the number of elements at each slicing dimension
\param output_dims the optional output shape (can be undefined)
\param value_id a unique value id (can be NULL)

\return odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_SliceDynamic(odla_value input, odla_value start, odla_value size,
odla_value_shape output_dims, const odla_value_id value_id);

//! \brief Remove dimensions of size 1
/*!
Squeeze removes dimensions of size 1 from the shape of \p input.
Expand All @@ -303,6 +319,22 @@ extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_Squeeze(odla_value input, odla_size_t num_of_axes, const odla_uint32* axes,
odla_value_shape output_dims, const odla_value_id value_id);

//! \brief Join a sequence of Values along a new axis.
/*!
Stack joins multiple values into a single one along a new axis. All input
values must have the same shape.

\param inputs the input values
\param axis the index of the new axis in the dimensions of the result
\param output_shape the result shape
\param value_id a unique value id (can be NULL)

\return odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_Stack(odla_values inputs, odla_int32 axis, odla_value_shape output_shape,
const odla_value_id value_id);

//! \brief Transpose the input
/*!
Transpose returns a transposed value based on the \p permutation.
Expand Down
8 changes: 4 additions & 4 deletions ODLA/platforms/dnnl/odla_dnnl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1834,8 +1834,8 @@ odla_value odla_Erf(odla_value input, const odla_value_id value_id) {

static void strided_slice(const void* src, int elem_size,
const odla_value_shape& input_dims,
const odla_uint32* start, const odla_uint32* end,
const odla_uint32* strides, void* dst,
const odla_int32* start, const odla_int32* end,
const odla_int32* strides, void* dst,
const odla_value_shape& output_dims) {
int64_t dst_elems = GetTotalElements(output_dims);
int dims = input_dims.size;
Expand Down Expand Up @@ -1871,8 +1871,8 @@ static void strided_slice(const void* src, int elem_size,
}
}

odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* strides,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* strides,
odla_value_shape output_dims, const odla_value_id id) {
const auto& input_dims = input->shape;
int dims = input_dims.size;
Expand Down
4 changes: 2 additions & 2 deletions ODLA/platforms/odla_popart/odla_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ odla_value odla_ReduceMean(odla_value input, odla_size_t num_of_axes,
return result;
}

odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* stride,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* stride,
odla_value_shape output_dims, const odla_value_id id) {
const auto& name = id ? std::string(reinterpret_cast<const char*>(id)) : "";

Expand Down
4 changes: 2 additions & 2 deletions ODLA/platforms/odla_profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,8 @@ odla_value odla_Rsqrt(odla_value input, const odla_value_id id) {
}

static constexpr const char fn_slice[] = "odla_Slice";
odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* strides,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* strides,
odla_value_shape output_dims, const odla_value_id id) {
return profile<fn_slice>(input, start, end, strides, output_dims, id);
}
Expand Down
7 changes: 5 additions & 2 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ static odla_value CreateValue(T* t, const odla_value_type& type,
auto v = std::make_unique<_odla_value>(t, type, name);
auto ret = v.get();
g_comp->vals.push_back(std::move(v));
if (!g_comp->branchs.empty()) {
g_comp->branchs.top().branch->addInput(*ret);
}
return ret;
}

Expand Down Expand Up @@ -1835,8 +1838,8 @@ odla_value odla_Gather(odla_value input, const odla_value indices,
return CreateValue(gather, {input->type.element_type, output_dims}, id);
}

odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* stride,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* stride,
odla_value_shape output_dims, const odla_value_id id) {
odla_value_shape start_dims, stride_dims;
const auto& dims = input->type.shape;
Expand Down
6 changes: 6 additions & 0 deletions include/halo/lib/ir/common_instructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ let cat_ = cat_common in {
let outs_ = [Arg<"The result.", MatchArgType<0> >];
}

def Shape : Inst<"Compute the shape of input."> {
  let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >];
  // Fixed garbled description ("The output date type size" -> data type).
  let attrs_ = [Attr<"The output data type", EnumDataType, "data_type", "INT64">];
  let outs_ = [Arg<"The result.", ArgType<[I64, I32]> >];
}

def Reshape : Inst<"Reshape the input X1 to create the result with the same"
" number of elements and the shape specified by X2."> {
let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >,
Expand Down
5 changes: 4 additions & 1 deletion include/halo/lib/ir/tf_convert.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def TF_Reshape : TFExtension<"Reshape"> {
let extension_attr_ = [ ExtensionAttr<"shape", IntegerList, "{}"> ];
}

def TF_Shape : TFExtension<"Shape">;
def TF_Shape: OpMapping<"Shape", Shape> {
let attr_mapping_ = [
AttributeMapping<"", "data_type", "INT32">];
}

def TF_SquaredDifference : TFExtension<"SquaredDifference">;

Expand Down
1 change: 1 addition & 0 deletions include/halo/lib/pass/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class HL_API_EXPORT PassManager final {
Pass* AddCodeFormatterPass(std::ostringstream& code,
std::ostringstream& header,
const CXXCodeGenOpts& opts);
Pass* AddConvertTFCFGPass();
Pass* AddDCEPass();
Pass* AddDevicePlacementPass();
Pass* AddFusionPass(const FusionOptions& opts);
Expand Down
4 changes: 3 additions & 1 deletion include/halo/lib/target/generic_cxx/generic_cxx_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,16 @@ class GenericCXXCodeGen : public CodeGen {
virtual void RunOnInstruction(ReturnInst*) override;
virtual void RunOnInstruction(RNNInst*) override;
virtual void RunOnInstruction(SelectInst*) override;
virtual void RunOnInstruction(ShapeInst*) override;
virtual void RunOnInstruction(ShiftInst*) override;
virtual void RunOnInstruction(ShrinkInst*) override;
virtual void RunOnInstruction(SigmoidInst*) override;
virtual void RunOnInstruction(SItoFPInst*) override;
virtual void RunOnInstruction(SliceInst*) override;
virtual void RunOnInstruction(SoftmaxInst*) override;
virtual void RunOnInstruction(SoftplusInst*) override;
virtual void RunOnInstruction(SoftsignInst*) override;
virtual void RunOnInstruction(SigmoidInst*) override;
virtual void RunOnInstruction(StackInst*) override;
virtual void RunOnInstruction(HardSigmoidInst*) override;
virtual void RunOnInstruction(SinInst*) override;
virtual void RunOnInstruction(SinhInst*) override;
Expand Down
37 changes: 37 additions & 0 deletions include/halo/lib/transforms/convert_tf_cfg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- convert_tf_cfg.h ---------------------------------------------------===//
//
// Copyright (C) 2019-2020 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_
#define HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_

#include "halo/lib/pass/pass.h"

namespace halo {

/// Pass that converts TensorFlow control-flow (CFG) constructs in a function
/// into Halo's representation.
/// NOTE(review): purpose inferred from the pass name "Convert TF CFG"; the
/// previous comment ("eliminates dead IRs") appears copy-pasted from the DCE
/// pass -- confirm against RunOnFunction's implementation.
class ConvertTFCFG final : public FunctionPass {
public:
ConvertTFCFG() : FunctionPass("Convert TF CFG"), converted_(false) {}
bool RunOnFunction(Function* func) override;

private:
// Whether this pass instance has already performed a conversion.
bool converted_;
};

} // end namespace halo.

#endif // HALO_LIB_TRANSFORMS_CONVERT_TF_CFG_H_
1 change: 1 addition & 0 deletions include/halo/lib/transforms/inst_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class InstSimplify final : public BasicBlockPass {
static std::pair<Def, Def> RunOnInstruction(ResizeInst* inst);
static std::pair<Def, Def> RunOnInstruction(SelectInst* inst);
static std::pair<Def, Def> RunOnInstruction(SetDiff1DInst* inst);
static std::pair<Def, Def> RunOnInstruction(ShapeInst* inst);
static std::pair<Def, Def> RunOnInstruction(SigmoidInst* inst);
static std::pair<Def, Def> RunOnInstruction(SItoFPInst* inst);
static std::pair<Def, Def> RunOnInstruction(FPtoSIInst* inst);
Expand Down
3 changes: 3 additions & 0 deletions include/halo/utils/passes_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ static void PopulateOptPasses(PassManager* pm, const std::string& target,
if (opts.enable_type_cast) {
pm->AddTypeCastPass();
}
if (format == ModelFormat::TENSORFLOW) {
pm->AddConvertTFCFGPass();
}
if (opts.constant_decombine) {
pm->AddConstantDecombinePass();
}
Expand Down
3 changes: 3 additions & 0 deletions lib/pass/pass_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "halo/lib/transforms/analyzer.h"
#include "halo/lib/transforms/caffeextension_legalizer.h"
#include "halo/lib/transforms/constant_decombine.h"
#include "halo/lib/transforms/convert_tf_cfg.h"
#include "halo/lib/transforms/dce.h"
#include "halo/lib/transforms/device_placement.h"
#include "halo/lib/transforms/fusion.h"
Expand Down Expand Up @@ -265,6 +266,8 @@ Pass* PassManager::AddCodeFormatterPass(std::ostringstream& buf_code,
return AddPass<CodeFormatter>(buf_code, buf_header, opts);
}

// Registers a ConvertTFCFG pass (TensorFlow control-flow conversion) and
// returns the owned Pass pointer.
Pass* PassManager::AddConvertTFCFGPass() { return AddPass<ConvertTFCFG>(); }

Pass* PassManager::AddDCEPass() { return AddPass<DCE>(); }

Pass* PassManager::AddDevicePlacementPass() {
Expand Down
18 changes: 18 additions & 0 deletions lib/target/generic_cpp/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,22 @@ void GenericCXXCodeGen::RunOnInstruction(ConcatInst* inst) {
EmitODLACall(ret, "odla_Concat", inputs, axis, EmitShape(ret_shape));
}

// Lowers a StackInst to an odla_Stack call: joins all operand values along
// the instruction's axis, producing a value with the instruction's result
// shape. The result CXXValue is registered in ir_mapping_ before emission.
void GenericCXXCodeGen::RunOnInstruction(StackInst* inst) {
  const auto& axis = inst->GetAxis();

  CXXValue op0 = ir_mapping_[inst->GetOperand(0)];
  CXXValue ret(inst->GetName(), op0.type);
  ir_mapping_[*inst] = ret;

  const halo::Type& ret_shape = inst->GetResultType();
  // Collect the mapped CXX values of every operand; reserve to avoid
  // repeated reallocations.
  std::vector<CXXValue> inputs;
  inputs.reserve(inst->GetNumOfOperands());
  for (const Def& op : inst->GetOperands()) {
    inputs.push_back(ir_mapping_[op]);
  }
  EmitODLACall(ret, "odla_Stack", inputs, axis, EmitShape(ret_shape));
}

} // namespace halo
7 changes: 7 additions & 0 deletions lib/target/generic_cpp/generic_cxx_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ static std::string GetBF16Mode(BF16Mode mode) {
}

bool GenericCXXCodeGen::RunOnModule(Module* module) {
module->Dump();
memory_analyzer_ = std::make_unique<MemoryAnalyzer>(*module);
Function* entry_func = nullptr;
EmitBanner(&os_, &header_os_, GetAPI());
Expand Down Expand Up @@ -270,6 +271,9 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
auto nr_outputs = ret_inst.GetNumOfOperands();
model_info.num_outputs = nr_outputs;
for (const auto& out : ret_inst.GetOperands()) {
if (out.IsNull()) {
continue;
}
const auto& type = out.GetType();
if (ir_mapping_.find(out) == ir_mapping_.end()) {
CXXValue cv(out.GetDef()->GetName(), TensorTypeToCXXType(type, false));
Expand Down Expand Up @@ -840,6 +844,9 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
index = 0;
// Pre-launch binding.
for (auto& op : return_inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
auto& cv = ir_mapping_[op];
std::string arg_name = (opts_.emit_inference_func_sig || is_sub)
? (is_sub ? "outputs.values[" : "outputs[") +
Expand Down
12 changes: 12 additions & 0 deletions lib/target/generic_cpp/reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,16 @@ void GenericCXXCodeGen::RunOnInstruction(ReshapeInst* inst) {
ir_mapping_[*inst] = ret;
}

// Lowers a ShapeInst to an odla_Shape call on the mapped input value and
// records the resulting CXXValue in ir_mapping_.
void GenericCXXCodeGen::RunOnInstruction(ShapeInst* inst) {
  const Def& input = inst->GetOperand(0);
  CXXValue input_val = ir_mapping_[input];

  // NOTE(review): the result reuses the input's CXX type, matching the
  // pattern used by the other RunOnInstruction overloads in this file.
  CXXValue result(inst->GetName(), input_val.type);
  EmitODLACall(result, "odla_Shape", input_val,
               EmitShape(inst->GetResultType()));

  ir_mapping_[*inst] = result;
}

} // namespace halo
3 changes: 3 additions & 0 deletions lib/target/generic_cpp/return.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ namespace halo {
void GenericCXXCodeGen::RunOnInstruction(ReturnInst* inst) {
bool is_compile_mode = opts_.exec_mode == ExecMode::Compile;
for (auto& op : inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
const CXXValue& val = ir_mapping_[op];
if (is_compile_mode) {
bool is_entry_with_calls =
Expand Down
Loading