From 18bcb6cbed3702eeaddc6479f9f8c03ef64b8bcc Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Mon, 3 Oct 2022 14:43:20 -0700
Subject: [PATCH] Add InTopKV2 (#393)

---
 tensorflow/core/kernels/BUILD             |   1 +
 tensorflow/core/kernels/dml_in_topk_op.cc | 210 ++++++++++++++++++++++
 2 files changed, 211 insertions(+)
 create mode 100644 tensorflow/core/kernels/dml_in_topk_op.cc

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ce2c7bfc4a..86a068d658 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -8079,6 +8079,7 @@ tf_kernel_library(
         "dml_ops_common.cc",
         "dml_cwise_ops.cc",
         "dml_image_ops.cc",
+        "dml_in_topk_op.cc",
         "dml_relu_op.cc",
         "dml_random_ops.cc",
         "dml_batch_norm_ops.cc",
diff --git a/tensorflow/core/kernels/dml_in_topk_op.cc b/tensorflow/core/kernels/dml_in_topk_op.cc
new file mode 100644
index 0000000000..a7bfdf24af
--- /dev/null
+++ b/tensorflow/core/kernels/dml_in_topk_op.cc
@@ -0,0 +1,210 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+Portions Copyright (c) Microsoft Corporation.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/dml/dml_operator_helper.h"
+#include "tensorflow/core/common_runtime/dml/dml_util.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/dml_kernel_wrapper.h"
+#include "tensorflow/core/kernels/dml_ops_common.h"
+
+namespace tensorflow {
+
+class DmlInTopKInitHelper : public InitializationHelper {
+ public:
+  using Attributes = EmptyAttributes;
+
+  DmlInTopKInitHelper(OpKernelContext* ctx,
+                      std::shared_ptr<const Attributes> attr) {
+    const auto& predictions_in = ctx->input(0);
+    const auto& targets_in = ctx->input(1);
+    const auto& k_in = ctx->input(2);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(k_in.shape()),
+                errors::InvalidArgument("k must be 0-D, got shape ",
+                                        k_in.shape().DebugString()));
+    OP_REQUIRES(ctx, predictions_in.dims() == 2,
+                errors::InvalidArgument("predictions must be 2-dimensional"));
+    OP_REQUIRES(ctx, targets_in.dims() == 1,
+                errors::InvalidArgument("targets must be 1-dimensional"));
+    OP_REQUIRES(ctx, predictions_in.dim_size(0) == targets_in.dim_size(0),
+                errors::InvalidArgument("First dimension of predictions ",
+                                        predictions_in.dim_size(0),
+                                        " must match length of targets ",
+                                        targets_in.dim_size(0)));
+
+    // TODO: Remove once K is moved into device memory
+    if (k_in.dtype() == DT_INT32) {
+      k_ = k_in.scalar<int32>()();
+    } else {
+      assert(k_in.dtype() == DT_INT64);
+      k_ = k_in.scalar<int64>()();
+    }
+  }
+
+  int64_t GetK() const { return k_; }
+
+ private:
+  int64_t k_ = -1;
+};
+
+class DmlInTopKKernel : public DmlKernel {
+ public:
+  using InitHelper = DmlInTopKInitHelper;
+
+  explicit DmlInTopKKernel(DmlKernelConstruction* ctx,
+                           const InitHelper* init_helper) {
+    const auto predictions_shape = ctx->GetInputTensorShape(0);
+
+    DmlTensorInfo predictions_info;
+    predictions_info.kernel_index = 0;
+    predictions_info.desc = DmlTensorDesc::Create(
+        ctx->GetInputDataType(0), predictions_shape, predictions_shape);
+
+    DmlTensorInfo targets_info;
+    targets_info.kernel_index = 1;
+    targets_info.desc = DmlTensorDesc::Create(ctx->GetInputDataType(1),
+                                              ctx->GetInputTensorShape(1),
+                                              ctx->GetInputTensorShape(1));
+
+    DmlTensorInfo output_info;
+    output_info.kernel_index = 0;
+    output_info.desc = DmlTensorDesc::Create(ctx->GetOutputDataType(0),
+                                             ctx->GetOutputTensorShape(0),
+                                             ctx->GetOutputTensorShape(0));
+    output_info.desc.ForceUnsignedDataType();
+
+    DmlKernelTensors tensors;
+    tensors.inputs = {predictions_info, targets_info};
+    tensors.outputs = {output_info};
+
+    int64_t k = init_helper->GetK();
+
+    // DML doesn't support K values bigger than the size of the TopK axis,
+    // so clamp it to the maximum
+    k = std::min<int64_t>(
+        k, predictions_shape.dim_size(predictions_shape.dims() - 1));
+
+    // When K is smaller than 1, none of the targets are in the top K
+    if (k < 1) {
+      all_false_ = true;
+      InitializeAsNoOp(ctx);
+      return;
+    }
+
+    auto inputs = GetDmlTensorDescs(tensors.inputs);
+    auto scope = dml::Graph(ctx->GetDmlDevice());
+    auto predictions = dml::InputTensor(scope, 0, inputs[0]);
+    auto targets = dml::InputTensor(scope, 1, inputs[1]);
+
+    uint32_t axis = predictions_info.desc.GetDimensionCount() - 1;
+    dml::TopKOutputs topk_result =
+        dml::TopK(predictions, axis, k, DML_AXIS_DIRECTION_DECREASING);
+
+    uint32_t num_classes = ctx->GetInputTensorShape(0).dim_size(1);
+    dml::Expression classes;
+    dml::Expression num_classes_tensor;
+    dml::Expression zero_tensor;
+
+    if (ctx->GetInputDataType(1) == DT_INT32) {
+      classes = dml::Sequence<int32_t>(scope, 0, 1, {1, 1, 1, num_classes});
+      num_classes_tensor = dml::ScalarTensor<int32_t>(
+          scope, num_classes, targets.GetOutputDesc().sizes);
+      zero_tensor =
+          dml::ScalarTensor<int32_t>(scope, 0, targets.GetOutputDesc().sizes);
+    } else {
+      assert(ctx->GetInputDataType(1) == DT_INT64);
+      classes = dml::Sequence<int64_t>(scope, 0, 1, {1, 1, 1, num_classes});
+      num_classes_tensor = dml::ScalarTensor<int64_t>(
+          scope, num_classes, targets.GetOutputDesc().sizes);
+      zero_tensor =
+          dml::ScalarTensor<int64_t>(scope, 0, targets.GetOutputDesc().sizes);
+    }
+
+    // Broadcast the classes list to the number of targets
+    auto sparse_classes =
+        dml::Reinterpret(classes, predictions.GetOutputDesc().sizes,
+                         dml::TensorStrides({0, 0, 0, 1}));
+
+    auto zero_float_tensor = dml::ScalarTensor<float>(
+        scope, 0.0, predictions.GetOutputDesc().sizes);
+
+    auto sparse_targets =
+        dml::Reinterpret(targets, predictions.GetOutputDesc().sizes,
+                         dml::TensorStrides({0, 0, 1, 0}));
+
+    // Create a sparse predictions matrix where the value of the prediction
+    // is set to 0 when the column doesn't match the target
+    auto sparse_predictions = dml::If(sparse_targets == sparse_classes,
+                                      predictions, zero_float_tensor);
+
+    auto class_prediction_per_batch =
+        dml::Reduce(sparse_predictions, DML_REDUCE_FUNCTION_SUM, {3});
+
+    class_prediction_per_batch = dml::Reinterpret(
+        class_prediction_per_batch, targets.GetOutputDesc().sizes,
+        dml::TensorStrides({0, 0, 0, 1}));
+
+    // To handle ties, we check to see if the class prediction is greater or
+    // equal to the lowest prediction obtained from the TopK result
+    auto lowest_topk_value_per_batch =
+        k > 1 ? dml::Split(topk_result.value, 3,
+                           {static_cast<uint32_t>(k) - 1u, 1u})[1]
+              : topk_result.value;
+
+    // Transpose lowest_topk_value_per_batch from a column to a row in order
+    // to match the shape of class_prediction_per_batch
+    lowest_topk_value_per_batch =
+        dml::Reinterpret(lowest_topk_value_per_batch,
+                         class_prediction_per_batch.GetOutputDesc().sizes,
+                         dml::TensorStrides({0, 0, 0, 1}));
+
+    // Out-of-bounds and "inf" targets always yield "false"
+    auto result = targets >= zero_tensor && targets < num_classes_tensor &&
+                  class_prediction_per_batch >= lowest_topk_value_per_batch &&
+                  !dml::IsInfinity(class_prediction_per_batch);
+
+    Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiled_op =
+        scope.Compile(DML_EXECUTION_FLAG_NONE, {result});
+
+    Initialize(ctx, std::move(tensors), compiled_op.Get());
+  }
+
+  StatusOr<DmlGpuEvent> Compute(DmlKernelContext* ctx) const override {
+    if (all_false_) {
+      // K < 1: skip GPU execution and zero the boolean output instead
+      Tensor* output = ctx->GetOutputTensor(0);
+      return ctx->GetDmlDeviceContext()->ZeroBuffer(
+          ctx->GetDmlDeviceContext()->GetBufferForTensor(*output));
+    }
+
+    return DmlKernel::Compute(ctx);
+  }
+
+ private:
+  bool all_false_ = false;
+};
+
+// TODO: Remove k pinning to host once we have a native InTopKV2 DML operator
+// and K is moved into device memory
+REGISTER_KERNEL_BUILDER(
+    Name("InTopKV2")
+        .Device(DEVICE_DML)
+        .TypeConstraint<int32>("T")
+        .HostMemory("k"),
+    DmlKernelWrapper<DmlInTopKKernel, GetOutputShapeAsInputShapeHelper<1>>);
+REGISTER_KERNEL_BUILDER(
+    Name("InTopKV2")
+        .Device(DEVICE_DML)
+        .TypeConstraint<int64>("T")
+        .HostMemory("k"),
+    DmlKernelWrapper<DmlInTopKKernel, GetOutputShapeAsInputShapeHelper<1>>);
+
+}  // namespace tensorflow
\ No newline at end of file
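
For review purposes, below is a minimal CPU sketch of the semantics the DML graph above encodes. It is not part of the patch and assumes nothing from the tensorflow-directml codebase; InTopKReference is a hypothetical, standalone helper that mirrors the kernel's behavior: clamp k to the class-axis size, treat k < 1 as all false, and report a target as in the top k only when it is in bounds, its prediction is not infinite, and its prediction is greater than or equal to the k-th largest prediction in its row.

// Minimal CPU reference for the InTopKV2 semantics implemented above.
// Hypothetical validation helper only; not part of this patch.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Returns one bool per batch row: whether targets[row] is among the top k
// predictions of that row, with ties counting as "in".
std::vector<bool> InTopKReference(
    const std::vector<std::vector<float>>& predictions,  // [batch][class]
    const std::vector<int64_t>& targets,                 // [batch]
    int64_t k) {
  std::vector<bool> result(targets.size(), false);
  if (predictions.empty() || predictions[0].empty()) return result;

  const int64_t num_classes = static_cast<int64_t>(predictions[0].size());

  // The kernel clamps k to the size of the class axis because DML's TopK
  // rejects larger values.
  k = std::min(k, num_classes);

  // k < 1 maps to the kernel's ZeroBuffer fast path: every output is false.
  if (k < 1) return result;

  for (size_t row = 0; row < targets.size(); ++row) {
    const int64_t target = targets[row];

    // Out-of-bounds targets always yield false.
    if (target < 0 || target >= num_classes) continue;

    const float target_prediction = predictions[row][target];

    // Mirrors the dml::IsInfinity check: "inf" predictions yield false.
    if (std::isinf(target_prediction)) continue;

    // Partially sort so the k-th largest value lands at index k - 1; a
    // target that ties with it still counts, matching the kernel's >=
    // comparison against the lowest TopK value per batch.
    std::vector<float> row_values = predictions[row];
    std::nth_element(row_values.begin(), row_values.begin() + (k - 1),
                     row_values.end(), std::greater<float>());
    result[row] = target_prediction >= row_values[k - 1];
  }
  return result;
}

With predictions = {{0.1f, 0.9f, 0.5f}}, targets = {2}, and k = 2, the sketch returns {true}: 0.5 ties the second-largest row value, and ties count as in the top k. This is also why the kernel compares class_prediction_per_batch >= lowest_topk_value_per_batch rather than matching against TopK indices.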