Skip to content

Commit

Permalink
backend: compiler: core: support MHA padding and fix thread local buf…
Browse files Browse the repository at this point in the history
…fer release
  • Loading branch information
ZhennanQin authored and TaoLv committed Dec 28, 2021
1 parent e762d77 commit ef99a0c
Show file tree
Hide file tree
Showing 13 changed files with 316 additions and 90 deletions.
5 changes: 3 additions & 2 deletions src/backend/graph_compiler/core/src/compiler/ir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,10 @@ expr make_ceil(const expr_c &v) {
std::vector<expr> {v.remove_const()}, any_map_t());
}

expr make_exp(const expr_c &v) {
expr make_exp(const expr_c &v, int mask_count) {
return make_expr<intrin_call_node>(intrin_type::exp,
std::vector<expr> {v.remove_const()}, any_map_t());
std::vector<expr> {v.remove_const()},
any_map_t({{"mask_count", mask_count}}));
}

expr make_sqrt(const expr_c &v) {
Expand Down
3 changes: 2 additions & 1 deletion src/backend/graph_compiler/core/src/compiler/ir/builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,10 @@ expr make_ceil(const expr_c &v);
/**
* Makes an exp node
* @param v the input value
* @param mask the mask count
* @return the created node
* */
expr make_exp(const expr_c &v);
expr make_exp(const expr_c &v, int mask = -1);

/**
* Makes an sqrt node
Expand Down
170 changes: 156 additions & 14 deletions src/backend/graph_compiler/core/src/compiler/ir/graph/fusible_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,114 @@ static inline uint32_t vectorize_step(
return std::min(16U, ctx->get_max_vector_lanes(detype));
}

// todo use uint64_t instead of mask count
expr make_select_by_mask(expr lhs_vec, int mask_count, uint32_t vector_lanes) {
if (mask_count == -1) { return lhs_vec; }
expr rhs_vec = make_expr<constant_node>(
std::vector<union_val>(vector_lanes, UINT64_C(0)),
sc_data_type_t(lhs_vec->dtype_.type_code_, vector_lanes));
if (mask_count == 0) { return rhs_vec; }
std::vector<union_val> mask_const(vector_lanes, 1.f);
for (uint32_t i = mask_count; i < vector_lanes; i++) {
mask_const[i] = 0.f;
}
expr mask_vec = make_expr<constant_node>(
mask_const, sc_data_type_t::f32(vector_lanes));
expr zero_vec = make_expr<constant_node>(
std::vector<union_val>(vector_lanes, 0.f),
sc_data_type_t::f32(vector_lanes));
return builder::make_select(mask_vec > zero_vec, lhs_vec, rhs_vec);
}

struct mask_compute_func_t {
mask_compute_func_t(const std::function<stmt(const std::vector<expr> &,
std::vector<expr::lvalue_proxy_t> &, int, float)> &func)
: impl_(func) {}
stmt operator()(const std::vector<expr> &in,
std::vector<expr::lvalue_proxy_t> &out, int mask_count = -1,
float mask_value = 0.f) const {
return impl_(in, out, mask_count, mask_value);
}
std::function<stmt(const std::vector<expr> &,
std::vector<expr::lvalue_proxy_t> &, int, float)>
impl_;
};

/** Determine whether masks are needed during elementwise computation and
* generate conditional expressions for the mask
* @param src input slice
* @param plain_dims plain shapes
* @param format input format
* @param iter_vars input loop vars
* @param lanes simd lanes
* @param condition key is related iter var, value is two conditions:first means
* in the condition, all elements should be all computed,second means only
* `mask_count` elements will be computed
* @param last_axis_mask mask count, how many elements should be computed in
* this time. -1 means all.
* */
static void compute_mask_and_generate_condition(
const std::vector<const tensor_slice *> &src, const sc_dims &plain_dims,
sc_data_format_t format, const std::vector<expr> &iter_vars, int lanes,
std::unordered_map<expr, std::pair<expr, expr>> &conditions,
int &last_axis_mask) {
auto blocking_dims
= sc_data_format_t::get_blocking_shapes(plain_dims, format);
auto padded_dims
= sc_data_format_t::get_padded_plain_shapes(blocking_dims, format);
auto &format_code = format.format_code_;
if (plain_dims == padded_dims) { return; }
auto offset = src[0]->get_offset();
auto shapes = src[0]->get_shape();
size_t ndims = format_code.ndims();
assert(offset.size() == ndims && shapes.size() == ndims
&& iter_vars.size() == ndims);
auto plain2block = format_code.collect_p2b_mapping();
for (size_t i = 0; i < plain2block.size(); i++) {
auto &orig_dim = i;
if (plain_dims[orig_dim] == padded_dims[orig_dim]
|| plain2block[i].size() == 1) {
continue;
}
auto &block_dim = plain2block[i][plain2block[i].size() - 1];
auto blocks = format_code.collect_blocking_index(orig_dim);
int padding_count = 0;
conditions[iter_vars[block_dim]].first
= lanes + (iter_vars[block_dim] + offset[block_dim]);
conditions[iter_vars[block_dim]].second
= iter_vars[block_dim] + offset[block_dim];
for (int b = static_cast<int>(blocks.size()) - 1; b >= 0; b--) {
if (b > 0 && blocks[b - 1] % blocks[b] != 0) { padding_count++; }
conditions[iter_vars[block_dim]].first
= conditions[iter_vars[block_dim]].first
+ (iter_vars[plain2block[i][b]] + offset[plain2block[i][b]])
* format.blocks_[blocks[b]];
conditions[iter_vars[block_dim]].second
= conditions[iter_vars[block_dim]].second
+ (iter_vars[plain2block[i][b]] + offset[plain2block[i][b]])
* format.blocks_[blocks[b]];
}
conditions[iter_vars[block_dim]].first
= conditions[iter_vars[block_dim]].first
< dim2unsigned(plain_dims[orig_dim]);
conditions[iter_vars[block_dim]].second
= conditions[iter_vars[block_dim]].second
< dim2unsigned(plain_dims[orig_dim]);
COMPILE_ASSERT(padding_count < 2,
"Currently we don't support multi-level padding mask.");
if (block_dim == format_code.ndims() - 1) {
assert(lanes > 1);
last_axis_mask = plain_dims[orig_dim] % lanes;
}
}
}

void compute_vectorized_op(const std::vector<const tensor_slice *> &src,
const tensor_slice &dst, sc_op_info_t &info,
const vectorized_info_t &vx_info,
const std::function<stmt(const std::vector<expr> &,
std::vector<expr::lvalue_proxy_t> &)> &compute_lanes,
const std::function<stmt(const std::vector<expr> &,
std::vector<expr::lvalue_proxy_t> &)> &compute_scalar,
size_t wkld = 0UL) {
const mask_compute_func_t &compute_lanes,
const mask_compute_func_t &compute_scalar, size_t wkld = 0UL,
bool use_mask = false) {
// nested loop vars
std::vector<expr> iter_vars;
// the indices for multiple inputs. First dim: the input, Second dim:
Expand Down Expand Up @@ -263,6 +363,19 @@ void compute_vectorized_op(const std::vector<const tensor_slice *> &src,
dst.get_shape().at(vx_info.axis).static_as<constant>());
int floor = slice_len / vx_info.lanes * vx_info.lanes;
int tail = slice_len % vx_info.lanes;
int last_axis_mask = -1;
std::unordered_map<expr, std::pair<expr, expr>> conditions;
if (use_mask) {
compute_mask_and_generate_condition(src,
info.inputs_[0]->details_.get_plain_dims(),
info.inputs_[0]->details_.get_format(), iter_vars,
vx_info.lanes, conditions, last_axis_mask);
}
if (last_axis_mask != -1) {
COMPILE_ASSERT(tail == 0,
"Currently we only support mask in vectorize compute not "
"tail.");
}
std::vector<stmt> tcur;
stmt cur;
// recover schedule loop
Expand All @@ -273,7 +386,25 @@ void compute_vectorized_op(const std::vector<const tensor_slice *> &src,
&& i == vx_info.axis) {
if (floor) {
bld->push_scope();
cur = compute_lanes(indexed_input_floor, target_floor);
if (conditions.find(iter_vars[i]) != conditions.end()) {
assert(last_axis_mask != -1);
stmt no_mask = builder::make_stmts_unattached(
{compute_lanes(indexed_input_floor, target_floor)});
stmt semi_mask = builder::make_stmts_unattached(
{compute_lanes(indexed_input_floor, target_floor,
last_axis_mask)});
stmt all_mask
= builder::make_stmts_unattached({compute_lanes(
indexed_input_floor, target_floor, 0)});
cur = builder::make_if_else_unattached(
conditions[iter_vars[i]].first, no_mask,
builder::make_stmts_unattached(
{builder::make_if_else_unattached(
conditions[iter_vars[i]].second,
semi_mask, all_mask)}));
} else {
cur = compute_lanes(indexed_input_floor, target_floor);
}
cur->attr()[op_traits::workload_computable_t::workload_number]
= wkld;
bld->emit(cur);
Expand Down Expand Up @@ -1411,7 +1542,8 @@ void binary_elementwise_op_t::compute_block(context_ptr ctx,
info_.outputs_[0]->details_.dtype_, wkld);
} else {
auto func = [&](const std::vector<expr> &in,
std::vector<expr::lvalue_proxy_t> &out) -> stmt {
std::vector<expr::lvalue_proxy_t> &out,
int mask_count, float mask_value) -> stmt {
switch (elt_op_) {
case elt_operator::ADD:
return builder::make_assign_unattached(
Expand All @@ -1423,8 +1555,9 @@ void binary_elementwise_op_t::compute_block(context_ptr ctx,
return builder::make_assign_unattached(
out[0], in[0] * in[1]);
case elt_operator::DIV:
return builder::make_assign_unattached(
out[0], in[0] / in[1]);
return builder::make_assign_unattached(out[0],
make_select_by_mask(
in[0] / in[1], mask_count, vector_lanes));
case elt_operator::MIN:
return builder::make_assign_unattached(
out[0], builder::make_min(in[0], in[1]));
Expand All @@ -1441,8 +1574,11 @@ void binary_elementwise_op_t::compute_block(context_ptr ctx,
return stmt();
}
};
compute_vectorized_op(
inputs, *dst[0], info_, vx_info_, func, func, wkld);
// todo: currently we only support mask for div.
bool use_mask = elt_op_ == elt_operator::DIV;
compute_vectorized_op(inputs, *dst[0], info_, vx_info_,
mask_compute_func_t(func), mask_compute_func_t(func), wkld,
use_mask);
}
}

Expand Down Expand Up @@ -1600,10 +1736,16 @@ void unary_elementwise_op_t::compute_block(context_ptr ctx,
vx_info.lanes
= vectorize_step(ctx, info_.inputs_[0]->details_.dtype_.type_code_);
auto func = [&](const std::vector<expr> &in,
std::vector<expr::lvalue_proxy_t> &out) -> stmt {
return builder::make_assign_unattached(out[0], compute_element(in[0]));
std::vector<expr::lvalue_proxy_t> &out, int mask_count,
float mask_value) -> stmt {
return builder::make_assign_unattached(
out[0], compute_element(in[0], mask_count, mask_value));
};
compute_vectorized_op(inputs, *dst[0], info_, vx_info, func, func, wkld);
// Currenly only support for exp
bool use_mask = op_name_ == "exp";
compute_vectorized_op(inputs, *dst[0], info_, vx_info,
mask_compute_func_t(func), mask_compute_func_t(func), wkld,
use_mask);
}

void unary_elementwise_op_t::prepare_fusion_data(context_ptr ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,20 @@ class constant_op_t : public fusible_op_t, public op_traits::auto_copyable_t {
bool compare_contents(const sc_op *other) const override;
size_t hash_contents() const override;

// if necessary, reset const_values according possible `var` from attrs
void reset_const_values() {
if (attrs_.has_key("temp.var") && attrs_.has_key("temp.val/var")) {
int K = static_cast<int>(
attrs_.get<std::shared_ptr<VConst>>("temp.var")->var_);
int base_val = attrs_.get<int>("temp.val/var");
// update private member
const_values_ = std::make_shared<static_data_t>(
std::vector<int> {base_val * K});
// update attr
attrs_.set("values", const_values_);
}
}

private:
std::shared_ptr<static_data_t> const_values_;
};
Expand Down Expand Up @@ -448,7 +462,7 @@ class unary_elementwise_op_t : public fusible_op_t,
const std::vector<graph_tensor_ptr> &outs, const any_map_t &attrs);
vectorized_info_t &get_vx_info() { return vx_info_; }

virtual expr compute_element(expr in) = 0;
virtual expr compute_element(expr in, int mask_count, float mask_value) = 0;

private:
vectorized_info_t vx_info_;
Expand Down
11 changes: 11 additions & 0 deletions src/backend/graph_compiler/core/src/compiler/ir/graph/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ struct fusion_anchor_data;
struct tensor_slice;
struct fusion_data;

/** VConst struct record possible varible in constant value, e.g.
*
* const int a = k * b;
*
* where `k` maybe variable related on other factor such as blocking dims.
* */
struct VConst {
// variable value
int64_t var_;
};

// a weak pointer which always asserts the object exists
struct sc_op_weak_ptr_t : public std::weak_ptr<sc_op> {
using parent = std::weak_ptr<sc_op>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,5 +205,15 @@ SC_INTERNAL_API void layout_propagation(
}
}
});

// it should be refactor to one standalone pass to finally fix constant
// value
auto vis2 = op_visitor_t::bfs();
vis2.visit_graph(graph, [&](const sc_op_ptr &node) {
if (node->isa<constant_op_t>() && node->attrs_.has_key("temp.var")) {
auto const_op = node->dyn_cast<constant_op_t>();
const_op->reset_const_values();
}
});
}
} // namespace sc
Loading

0 comments on commit ef99a0c

Please sign in to comment.