From 59db1de68b7c1cf82d0628717f9972c65a47dd52 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 2 Dec 2024 14:24:31 -0800 Subject: [PATCH] address yulong comments --- onnxruntime/core/providers/webgpu/tensor/flatten.h | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.h b/onnxruntime/core/providers/webgpu/tensor/flatten.h index 9d8c2ab6f6a42..5fc49a844b404 100644 --- a/onnxruntime/core/providers/webgpu/tensor/flatten.h +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.h @@ -13,20 +13,17 @@ namespace webgpu { class Flatten final : public OpKernel { public: explicit Flatten(const OpKernelInfo& info) : OpKernel{info} { - ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK()); + axis_ = info.GetAttrOrDefault("axis", 1); } Status Compute(OpKernelContext* context) const override { const Tensor* input_tensor = context->Input(0); - if (input_tensor == nullptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "Input tensor is not set"); - } const TensorShape& input_shape = input_tensor->Shape(); int64_t input_rank = input_shape.NumDimensions(); // Handle negative axis int64_t axis = axis_; - if (axis_ < 0) { + if (axis < 0) { axis += input_rank; } @@ -34,8 +31,6 @@ class Flatten final : public OpKernel { return Status(common::ONNXRUNTIME, common::FAIL, "Invalid value for axis, must be less than or equal to input_rank"); } - std::initializer_list output_dims; - int64_t first_dim = 1; for (int64_t i = 0; i < axis; i++) { first_dim *= input_shape[i]; @@ -45,9 +40,8 @@ class Flatten final : public OpKernel { for (int64_t i = axis; i < input_rank; i++) { second_dim *= input_shape[i]; } - output_dims = {first_dim, second_dim}; - TensorShape output_shape(output_dims); + TensorShape output_shape({first_dim, second_dim}); Tensor* output_tensor = context->Output(0, output_shape); const void* source = input_tensor->DataRaw();