Skip to content

Commit

Permalink
Refactored unary operator, and provid support for logic NOT operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
tianboh authored and weimingzha0 committed Sep 30, 2021
1 parent 1f424d3 commit 54e80f3
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 281 deletions.
1 change: 1 addition & 0 deletions ODLA/platforms/dnnl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ set(ODLA_DNNL_SRC
odla_dnnl_cast.cc
odla_dnnl_loss.cc
odla_dnnl_rnn.cc
odla_dnnl_unary.cc
odla_dnnl_binary.cc
odla_dnnl_statistics.cc
)
Expand Down
274 changes: 0 additions & 274 deletions ODLA/platforms/dnnl/odla_dnnl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,6 @@
#error This library requires minimum ODLA version 0.5
#endif

enum class alg_unary_eltwise {
isnan,
isinf,
isinf_pos,
isinf_neg,
abs,
acos,
asin,
atan,
ceil,
cos,
cosh,
sin,
sinh,
log,
tan,
tanh,
sqrt,
neg,
acosh,
asinh,
atanh,
reciprocal,
sign,
};

struct _odla_context {
odla_computation comp;
std::unique_ptr<dnnl::stream> stream;
Expand Down Expand Up @@ -528,23 +502,6 @@ odla_value odla_GatherElements(odla_value data, const odla_value indices,
return CreateValue(ret_mem, output_dims, id);
}

static odla_value unary_eltwise_op(
dnnl::algorithm algo, odla_value input, odla_float32 alpha,
odla_float32 beta, const odla_value_id id,
dnnl::primitive_attr attr = dnnl::primitive_attr()) {
auto eltwise_d =
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo,
input->mem.get_desc(), alpha, beta);
auto pd = dnnl::eltwise_forward::primitive_desc(eltwise_d, attr, g_comp->eng);

dnnl::primitive prim = dnnl::eltwise_forward(pd);
auto ret_mem = dnnl::memory(input->mem.get_desc(), g_comp->eng);
odla_value v = CreateValue(ret_mem, input->shape, id);
add_op(prim, {{DNNL_ARG_SRC, input->mem}, {DNNL_ARG_DST, ret_mem}});
InterpretIfNeeded();
return v;
}

static odla_value binary_eltwise_s32(dnnl::algorithm alg, dnnl::memory lhs_mem,
dnnl::memory rhs_mem,
odla_value_shape shape,
Expand Down Expand Up @@ -590,16 +547,6 @@ static odla_value binary_eltwise(dnnl::algorithm algo, odla_value lhs,
return v;
}

odla_value odla_Abs(odla_value input, const odla_value_id value_id) {
return unary_eltwise_op(dnnl::algorithm::eltwise_abs, input, 0.f, 0.f,
value_id);
}

odla_value odla_Tanh(odla_value input, const odla_value_id value_id) {
return unary_eltwise_op(dnnl::algorithm::eltwise_tanh, input, 0.f, 0.f,
value_id);
}

odla_value odla_Add(odla_value lhs, odla_value rhs, const odla_value_id id) {
return binary_eltwise(dnnl::algorithm::binary_add, lhs, rhs, id);
}
Expand Down Expand Up @@ -903,227 +850,6 @@ odla_value odla_Shift(odla_value input, odla_value shift_amount,
return v;
}

template <typename T>
static void unary_eltwise_T(alg_unary_eltwise alg, void* dst, const void* input,
int n) {
const T* input_t = static_cast<const T*>(input);
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> in(input_t, n);
T* dst_t = static_cast<T*>(dst);
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> out(dst_t, n);
switch (alg) {
case alg_unary_eltwise::abs:
out = in.abs();
break;
case alg_unary_eltwise::neg:
out = -in;
break;
case alg_unary_eltwise::sign:
out = (0 < in).select(1, in);
out = (0 > out).select(-1, out);
break;
case alg_unary_eltwise::ceil:
out = in.ceil();
break;
case alg_unary_eltwise::log:
out = in.log();
break;
case alg_unary_eltwise::sqrt:
out = in.sqrt();
break;
case alg_unary_eltwise::reciprocal:
out = in.pow(-1);
break;
case alg_unary_eltwise::sin:
out = in.sin();
break;
case alg_unary_eltwise::cos:
out = in.cos();
break;
case alg_unary_eltwise::tan:
out = in.tan();
break;
case alg_unary_eltwise::acos:
out = in.acos();
break;
case alg_unary_eltwise::asin:
out = in.asin();
break;
case alg_unary_eltwise::asinh:
out = in.asinh();
break;
case alg_unary_eltwise::atan:
out = in.atan();
break;
case alg_unary_eltwise::atanh:
out = in.atanh();
break;
case alg_unary_eltwise::sinh:
out = in.sinh();
break;
case alg_unary_eltwise::tanh:
out = in.tanh();
break;
case alg_unary_eltwise::cosh:
out = in.cosh();
break;
case alg_unary_eltwise::acosh:
out = in.acosh();
break;
default:
assert(0);
}
}

template <typename T>
static void unary_eltwise_bool(alg_unary_eltwise alg, void* dst,
const void* input, int n) {
const T* input_t = static_cast<const T*>(input);
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> in(input_t, n);
bool* dst_t = static_cast<bool*>(dst);
Eigen::Map<Eigen::Array<bool, Eigen::Dynamic, 1>> out(dst_t, n);
switch (alg) {
case alg_unary_eltwise::isnan:
out = in.isNaN();
break;
case alg_unary_eltwise::isinf:
out = in.isInf();
break;
case alg_unary_eltwise::isinf_neg:
out = in.isInf() && (in < 0);
break;
case alg_unary_eltwise::isinf_pos:
out = in.isInf() && (in > 0);
break;
default:
assert(0);
}
}

static odla_value odla_unary_eltwise(alg_unary_eltwise alg, odla_value input,
const odla_value_id value_id) {
// Extract type and size
auto elem_type = input->elem_type;
bool ret_bool =
(alg == alg_unary_eltwise::isnan || alg == alg_unary_eltwise::isinf ||
alg == alg_unary_eltwise::isinf_neg ||
alg == alg_unary_eltwise::isinf_pos);
if (ret_bool) {
elem_type = ODLA_BOOL;
}
int n = GetTotalElements(input->shape);
// Prepare destination memory
dnnl::memory dst_mem;
dnnl::memory::desc dst_md = getMemoryDesc({elem_type, input->shape});
dst_mem = dnnl::memory(dst_md, g_comp->eng);
auto v = CreateValue(dst_mem, input->shape, value_id);
v->elem_type = elem_type;
// Create lambda operation
auto op = [alg, ret_bool, input, dst_mem, n] {
void* dst = dst_mem.get_data_handle();
const void* data = input->mem.get_data_handle();
if (input->elem_type == ODLA_FLOAT32) {
ret_bool ? unary_eltwise_bool<float>(alg, dst, data, n)
: unary_eltwise_T<float>(alg, dst, data, n);
} else if (input->elem_type == ODLA_FLOAT64) {
ret_bool ? unary_eltwise_bool<double>(alg, dst, data, n)
: unary_eltwise_T<double>(alg, dst, data, n);
} else if (input->elem_type == ODLA_UINT8) {
ret_bool ? unary_eltwise_bool<uint8_t>(alg, dst, data, n)
: unary_eltwise_T<uint8_t>(alg, dst, data, n);
} else if (input->elem_type == ODLA_UINT16) {
ret_bool ? unary_eltwise_bool<uint16_t>(alg, dst, data, n)
: unary_eltwise_T<uint16_t>(alg, dst, data, n);
} else if (input->elem_type == ODLA_UINT32) {
ret_bool ? unary_eltwise_bool<uint32_t>(alg, dst, data, n)
: unary_eltwise_T<uint32_t>(alg, dst, data, n);
} else if (input->elem_type == ODLA_UINT64) {
ret_bool ? unary_eltwise_bool<uint64_t>(alg, dst, data, n)
: unary_eltwise_T<uint64_t>(alg, dst, data, n);
} else {
assert(0);
}
};
// Postprocess
add_op(op);
InterpretIfNeeded();
return v;
}

odla_value odla_IsNaN(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::isnan, input, value_id);
}

odla_value odla_IsInf(odla_value input, odla_bool detect_pos,
odla_bool detect_neg, const odla_value_id value_id) {
if (detect_pos != 0 && detect_neg != 0) {
return odla_unary_eltwise(alg_unary_eltwise::isinf, input, value_id);
}
if (detect_pos != 0) {
return odla_unary_eltwise(alg_unary_eltwise::isinf_pos, input, value_id);
}
return odla_unary_eltwise(alg_unary_eltwise::isinf_neg, input, value_id);
}

odla_value odla_Cos(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::cos, input, value_id);
}

odla_value odla_Sin(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::sin, input, value_id);
}

odla_value odla_Tan(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::tan, input, value_id);
}

odla_value odla_ACos(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::acos, input, value_id);
}

odla_value odla_ACosh(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::acosh, input, value_id);
}

odla_value odla_ASin(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::asin, input, value_id);
}

odla_value odla_ASinh(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::asinh, input, value_id);
}

odla_value odla_ATan(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::atan, input, value_id);
}

odla_value odla_ATanh(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::atanh, input, value_id);
}

odla_value odla_Sinh(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::sinh, input, value_id);
}

odla_value odla_Cosh(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::cosh, input, value_id);
}

odla_value odla_Ceil(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::ceil, input, value_id);
}

odla_value odla_Neg(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::neg, input, value_id);
}

odla_value odla_Reciprocal(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::reciprocal, input, value_id);
}

odla_value odla_Sign(odla_value input, const odla_value_id value_id) {
return odla_unary_eltwise(alg_unary_eltwise::sign, input, value_id);
}

odla_value odla_Conv(odla_value input, odla_memory_layout input_layout,
odla_uint32 group, odla_value kernel,
odla_memory_layout kernel_layout,
Expand Down
17 changes: 17 additions & 0 deletions ODLA/platforms/dnnl/odla_dnnl.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,4 +375,21 @@ static inline std::pair<dnnl::memory, dnnl::memory> broadcast_operands(
};
}

static inline odla_value unary_eltwise_op(
dnnl::algorithm algo, odla_value input, odla_float32 alpha,
odla_float32 beta, const odla_value_id id,
dnnl::primitive_attr attr = dnnl::primitive_attr()) {
auto eltwise_d =
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo,
input->mem.get_desc(), alpha, beta);
auto pd = dnnl::eltwise_forward::primitive_desc(eltwise_d, attr, g_comp->eng);

dnnl::primitive prim = dnnl::eltwise_forward(pd);
auto ret_mem = dnnl::memory(input->mem.get_desc(), g_comp->eng);
odla_value v = CreateValue(ret_mem, input->shape, id);
add_op(prim, {{DNNL_ARG_SRC, input->mem}, {DNNL_ARG_DST, ret_mem}});
InterpretIfNeeded();
return v;
}

#endif // ODLA_DNNL_H_
2 changes: 0 additions & 2 deletions ODLA/platforms/dnnl/odla_dnnl_binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
enum class alg_binary_eltwise {
logic_or,
logic_and,
logic_not,
logic_xor,
cmp_equal,
cmp_less,
Expand Down Expand Up @@ -114,7 +113,6 @@ static void binary_eltwise_T(alg_binary_eltwise alg, void* dst,
bool binary_ret_bool(alg_binary_eltwise alg) {
return (alg == alg_binary_eltwise::logic_or) ||
(alg == alg_binary_eltwise::logic_and) ||
(alg == alg_binary_eltwise::logic_not) ||
(alg == alg_binary_eltwise::logic_xor) ||
(alg == alg_binary_eltwise::cmp_equal) ||
(alg == alg_binary_eltwise::cmp_less) ||
Expand Down
Loading

0 comments on commit 54e80f3

Please sign in to comment.