diff --git a/tensorflow/core/common_runtime/dml/dml_common.h b/tensorflow/core/common_runtime/dml/dml_common.h index 984b5146d4..34504170e7 100644 --- a/tensorflow/core/common_runtime/dml/dml_common.h +++ b/tensorflow/core/common_runtime/dml/dml_common.h @@ -131,6 +131,9 @@ static constexpr uint32_t kNchwSpatialDimensionCount = 2; static constexpr uint32_t kNcdhwDimensionCount = 5; static constexpr uint32_t kNcdhwSpatialDimensionCount = 3; +// 8 dimensions are supported for elementwise operators +static constexpr uint32_t kBinaryCwiseOpMaxDimCount = 8; + // The batch and channel dimensions of NCW, NCHW, NCDHW.... static constexpr uint32_t kNonspatialDimensionCount = 2; diff --git a/tensorflow/core/kernels/dml_cwise_ops.cc b/tensorflow/core/kernels/dml_cwise_ops.cc index 95a06afeb7..4c9f968581 100644 --- a/tensorflow/core/kernels/dml_cwise_ops.cc +++ b/tensorflow/core/kernels/dml_cwise_ops.cc @@ -450,8 +450,16 @@ REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(RealDiv, x / y, 8, true, Eigen::half, // cwise_op_floor_div.cc). REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(FloorDiv, dml::Floor(x / y), 8, true, Eigen::half, float) -REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(FloorMod, dml::ModulusFloor(x, y), 8, - true, Eigen::half, float, int64) +REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(FloorMod, dml::ModulusFloor(x, y), 8, + true, Eigen::half, float) +// TODO: Revisit this and consider having a native int64 alternative +// TFDML #41163316 +REGISTER_DML_COMPOSITE_BINARY_KERNEL_1( + FloorMod, + dml::Cast(dml::ModulusFloor(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32), + dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32)), + DML_TENSOR_DATA_TYPE_INT64), + 8, false, int64) REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(SigmoidGrad, (y * x * (1 - x)), 8, false, Eigen::half, float) REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(TanhGrad, (y * (1 - x * x)), 8, false, @@ -481,25 +489,53 @@ REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Minimum, dml::Min(x, y), 8, true, // cwise_op_maximum.cc). REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Maximum, dml::Max(x, y), 8, true, Eigen::half, float, int64) -REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(SquaredDifference, +REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(SquaredDifference, dml::DifferenceSquare(x, y), 8, true, - Eigen::half, float, int64) + Eigen::half, float) +// TODO: Revisit this and consider having a native int64 alternative +// TFDML #41163316 +REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(SquaredDifference, (x - y) * (x - y), 8, + false, int64) // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_mul1.cc). REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Mul, (x * y), 8, true, Eigen::half, float, int64) -REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Pow, dml::Pow(x, y), 8, true, - Eigen::half, float, int64) +REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(Pow, dml::Pow(x, y), 8, true, + Eigen::half, float) +// TODO: Revisit this and consider having a native int64 alternative +// TFDML #41163316 +REGISTER_DML_COMPOSITE_BINARY_KERNEL_1( + Pow, + dml::Cast(dml::Pow(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32), + dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32)), + DML_TENSOR_DATA_TYPE_INT64), + 8, false, int64) // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_add1.cc). REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Add, x + y, 8, true, Eigen::half, float, int64) // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_add1.cc). REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(AddV2, x + y, 8, true, Eigen::half, float, int64) -REGISTER_DML_COMPOSITE_BINARY_KERNEL_4(TruncateDiv, x / y, 8, true, uint8, - uint16, int16, int64) +REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(TruncateDiv, x / y, 8, true, uint8, + uint16, int16) +// TODO: Revisit this and consider having a native int64 alternative +// TFDML #41163316 +REGISTER_DML_COMPOSITE_BINARY_KERNEL_1( + TruncateDiv, + dml::Cast(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32) / + dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32), + DML_TENSOR_DATA_TYPE_INT64), + 8, false, int64) // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_div.cc). -REGISTER_DML_COMPOSITE_BINARY_KERNEL_6(Div, x / y, 8, true, Eigen::half, float, - uint8, uint16, int16, int64) +REGISTER_DML_COMPOSITE_BINARY_KERNEL_5(Div, x / y, 8, true, Eigen::half, float, + uint8, uint16, int16) +// TODO: Revisit this and consider having a native int64 alternative +// TFDML #41163316 +REGISTER_DML_COMPOSITE_BINARY_KERNEL_1( + Div, + dml::Cast(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32) / + dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32), + DML_TENSOR_DATA_TYPE_INT64), + 8, false, int64) // TODO(b/25387198): A special kernel exists for int32 (see // cwise_op_greater.cc). REGISTER_DML_COMPOSITE_BINARY_KERNEL_6(Greater, x > y, 8, false, Eigen::half, @@ -869,10 +905,40 @@ class DmlLeakyReluKernel : public DmlKernel { TF_CALL_DML_FLOAT_TYPES(DML_REGISTER_KERNEL); #undef DML_REGISTER_KERNEL +class ApproximateEqualInitHelper + : public ElementWiseInitHelper { + public: + struct Attributes + : public ElementWiseInitHelper::Attributes { + explicit Attributes(OpKernelConstruction* ctx) + : ElementWiseInitHelper::Attributes(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance)); + } + float tolerance; + }; + ApproximateEqualInitHelper(OpKernelContext* ctx, + std::shared_ptr attr) + : ElementWiseInitHelper(ctx, attr), + tolerance_(attr->tolerance) { + const Tensor& x_input = ctx->input(0); + const Tensor& y_input = ctx->input(1); + OP_REQUIRES( + ctx, x_input.shape() == y_input.shape(), + errors::InvalidArgument("x and y must be of the same shape. ", + "x shape: ", x_input.shape().DebugString(), + ". y shape: ", y_input.shape().DebugString())); + } + + float GetTolerance() const { return tolerance_; } + + private: + float tolerance_; +}; + template class DmlApproximateEqualKernel : public DmlKernel { public: - using InitHelper = ElementWiseInitHelper; + using InitHelper = ApproximateEqualInitHelper; explicit DmlApproximateEqualKernel(DmlKernelConstruction* ctx, const InitHelper* init_helper) { @@ -891,11 +957,9 @@ class DmlApproximateEqualKernel : public DmlKernel { auto x = dml::InputTensor(scope, 0, inputs[0]); auto y = dml::InputTensor(scope, 1, inputs[1]); - float tolerance; - TF_CHECK_OK(ctx->GetAttr("tolerance", &tolerance)); - auto tolerance_tensor = - dml::ScalarTensor(scope, TfTensorTypeTraits::FromFloat(tolerance), - x.GetOutputDesc().sizes); + auto tolerance_tensor = dml::ScalarTensor( + scope, TfTensorTypeTraits::FromFloat(init_helper->GetTolerance()), + x.GetOutputDesc().sizes); auto result = dml::Abs(x - y) < tolerance_tensor;