
Commit

Add a small int64 workaround for elementwise ops that don't support it (#391)

It turns out that DIVIDE, MODULUS_FLOOR, DIFFERENCE_SQUARE, MODULUS_TRUNCATE and POW don't yet support emulated int64 in DirectML, so we can't entirely get rid of int64 workarounds.
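
The workaround the diff applies for those ops is to downcast the int64 operands to int32, run the elementwise op, and cast the result back to int64. A minimal sketch of that pattern (not taken from this commit) follows, assuming `x` and `y` are DirectMLX expressions already bound to int64 inputs:

// Sketch only (not part of this commit): the int64 workaround pattern used in
// dml_cwise_ops.cc for Pow, Div, TruncateDiv and FloorMod. `x` and `y` are
// assumed to be DirectMLX expressions wrapping int64 tensors.
auto x_int32 = dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32);
auto y_int32 = dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32);
auto pow_int64 =
    dml::Cast(dml::Pow(x_int32, y_int32), DML_TENSOR_DATA_TYPE_INT64);

The round trip through int32 cannot represent values outside the int32 range, which is why each of these registrations carries a TODO to revisit once a native int64 path exists.
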
PatriceVignola authored Sep 6, 2022
1 parent 8e565a1 commit a4a0e27
Showing 2 changed files with 83 additions and 16 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/common_runtime/dml/dml_common.h
@@ -131,6 +131,9 @@ static constexpr uint32_t kNchwSpatialDimensionCount = 2;
 static constexpr uint32_t kNcdhwDimensionCount = 5;
 static constexpr uint32_t kNcdhwSpatialDimensionCount = 3;
 
+// 8 dimensions are supported for elementwise operators
+static constexpr uint32_t kBinaryCwiseOpMaxDimCount = 8;
+
 // The batch and channel dimensions of NCW, NCHW, NCDHW....
 static constexpr uint32_t kNonspatialDimensionCount = 2;
 
96 changes: 80 additions & 16 deletions tensorflow/core/kernels/dml_cwise_ops.cc
@@ -450,8 +450,16 @@ REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(RealDiv, x / y, 8, true, Eigen::half,
 // cwise_op_floor_div.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(FloorDiv, dml::Floor(x / y), 8, true,
                                        Eigen::half, float)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(FloorMod, dml::ModulusFloor(x, y), 8,
-                                       true, Eigen::half, float, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(FloorMod, dml::ModulusFloor(x, y), 8,
+                                       true, Eigen::half, float)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    FloorMod,
+    dml::Cast(dml::ModulusFloor(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32),
+                                dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32)),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(SigmoidGrad, (y * x * (1 - x)), 8, false,
                                        Eigen::half, float)
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(TanhGrad, (y * (1 - x * x)), 8, false,
@@ -481,25 +489,53 @@ REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Minimum, dml::Min(x, y), 8, true,
 // cwise_op_maximum.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Maximum, dml::Max(x, y), 8, true,
                                        Eigen::half, float, int64)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(SquaredDifference,
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(SquaredDifference,
                                        dml::DifferenceSquare(x, y), 8, true,
-                                       Eigen::half, float, int64)
+                                       Eigen::half, float)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(SquaredDifference, (x - y) * (x - y), 8,
+                                       false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_mul1.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Mul, (x * y), 8, true, Eigen::half,
                                        float, int64)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Pow, dml::Pow(x, y), 8, true,
-                                       Eigen::half, float, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(Pow, dml::Pow(x, y), 8, true,
+                                       Eigen::half, float)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    Pow,
+    dml::Cast(dml::Pow(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32),
+                       dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32)),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_add1.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Add, x + y, 8, true, Eigen::half, float,
                                        int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_add1.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(AddV2, x + y, 8, true, Eigen::half,
                                        float, int64)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_4(TruncateDiv, x / y, 8, true, uint8,
-                                       uint16, int16, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(TruncateDiv, x / y, 8, true, uint8,
+                                       uint16, int16)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    TruncateDiv,
+    dml::Cast(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32) /
+                  dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_div.cc).
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_6(Div, x / y, 8, true, Eigen::half, float,
-                                       uint8, uint16, int16, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_5(Div, x / y, 8, true, Eigen::half, float,
+                                       uint8, uint16, int16)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    Div,
+    dml::Cast(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32) /
+                  dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see
 // cwise_op_greater.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_6(Greater, x > y, 8, false, Eigen::half,
@@ -869,10 +905,40 @@ class DmlLeakyReluKernel : public DmlKernel {
 TF_CALL_DML_FLOAT_TYPES(DML_REGISTER_KERNEL);
 #undef DML_REGISTER_KERNEL
 
+class ApproximateEqualInitHelper
+    : public ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount> {
+ public:
+  struct Attributes
+      : public ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount>::Attributes {
+    explicit Attributes(OpKernelConstruction* ctx)
+        : ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount>::Attributes(ctx) {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance));
+    }
+    float tolerance;
+  };
+  ApproximateEqualInitHelper(OpKernelContext* ctx,
+                             std::shared_ptr<const Attributes> attr)
+      : ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount>(ctx, attr),
+        tolerance_(attr->tolerance) {
+    const Tensor& x_input = ctx->input(0);
+    const Tensor& y_input = ctx->input(1);
+    OP_REQUIRES(
+        ctx, x_input.shape() == y_input.shape(),
+        errors::InvalidArgument("x and y must be of the same shape. ",
+                                "x shape: ", x_input.shape().DebugString(),
+                                ". y shape: ", y_input.shape().DebugString()));
+  }
+
+  float GetTolerance() const { return tolerance_; }
+
+ private:
+  float tolerance_;
+};
+
 template <typename T>
 class DmlApproximateEqualKernel : public DmlKernel {
  public:
-  using InitHelper = ElementWiseInitHelper<kNchwDimensionCount>;
+  using InitHelper = ApproximateEqualInitHelper;
 
   explicit DmlApproximateEqualKernel(DmlKernelConstruction* ctx,
                                      const InitHelper* init_helper) {
@@ -891,11 +957,9 @@ class DmlApproximateEqualKernel : public DmlKernel {
     auto x = dml::InputTensor(scope, 0, inputs[0]);
     auto y = dml::InputTensor(scope, 1, inputs[1]);
 
-    float tolerance;
-    TF_CHECK_OK(ctx->GetAttr("tolerance", &tolerance));
-    auto tolerance_tensor =
-        dml::ScalarTensor<T>(scope, TfTensorTypeTraits<T>::FromFloat(tolerance),
-                             x.GetOutputDesc().sizes);
+    auto tolerance_tensor = dml::ScalarTensor<T>(
+        scope, TfTensorTypeTraits<T>::FromFloat(init_helper->GetTolerance()),
+        x.GetOutputDesc().sizes);
 
     auto result = dml::Abs(x - y) < tolerance_tensor;
 
