diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cu
new file mode 100644
index 00000000000..04bb270f16e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cu
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// NOTE(review): the names of the three system headers below were lost when this patch was
+// extracted (only bare "#include" tokens survived); these are a best guess - confirm
+// against the original commit.
+#include <stdio.h>
+#include <stdint.h>
+#include <cuda_runtime.h>
+#include "backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cuh"
+
+// Element-wise addition kernel: y[i] = x1[i] + x2[i] for every i in [0, element_num).
+// Each thread starts at its global 1-D index and advances by the total number of launched
+// threads, so the whole range is covered even when the grid is smaller than element_num.
+template <typename T>
+__global__ void TensorAddV2Kernel(const size_t element_num, const T* x1, const T* x2, T* y) {
+  // blockIdx/blockDim/gridDim are 32-bit; widen before multiplying so the start index and
+  // the stride cannot wrap around for very large tensors.
+  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
+  for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; i < element_num; i += stride) {
+    y[i] = x1[i] + x2[i];
+  }
+}
+
+// Enqueues TensorAddV2Kernel on cuda_stream with 256 threads per block and one thread per
+// element (capped at the largest 1-D grid). The launch is asynchronous; x1, x2 and y are
+// dereferenced on the device and must each hold at least element_num values of T.
+template <typename T>
+void TensorAddV2(const size_t &element_num, const T* x1, const T* x2, T* y, cudaStream_t cuda_stream) {
+  if (element_num == 0) {
+    return;  // nothing to add, and a grid of zero blocks is an invalid launch configuration
+  }
+  constexpr size_t kThreadPerBlock = 256;
+  constexpr size_t kMaxBlockPerGrid = 2147483647;  // gridDim.x limit (2^31 - 1) on compute capability 3.0+
+  size_t block_per_grid = (element_num + kThreadPerBlock - 1) / kThreadPerBlock;
+  if (block_per_grid > kMaxBlockPerGrid) {
+    block_per_grid = kMaxBlockPerGrid;  // the kernel's strided loop picks up the remainder
+  }
+  TensorAddV2Kernel<<<block_per_grid, kThreadPerBlock, 0, cuda_stream>>>(element_num, x1, x2, y);
+}
+
+template void TensorAddV2(const size_t &element_num, const float* x1, const float* x2, float* y, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cuh
new file mode 100644
index 00000000000..867cd7c651f
--- /dev/null
+++ 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cuh
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TENSOR_ADD_V2_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TENSOR_ADD_V2_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+// Adds two device arrays element by element: y[i] = x1[i] + x2[i] for i in [0, element_num).
+// The work is enqueued on cuda_stream and the call does not wait for it to finish.
+// x1, x2 and y must be device-accessible and hold at least element_num values of T.
+// Defined in tensor_add_v2_impl.cu, which explicitly instantiates it for float only.
+template <typename T>
+void TensorAddV2(const size_t &element_num, const T* x1, const T* x2, T* y, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TENSOR_ADD_V2_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.cc
new file mode 100644
index 00000000000..9d2719040ff
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.cc
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  TensorAddV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  TensorAddV2GpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.h
new file mode 100644
index 00000000000..7a9c6f94bfb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.h
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_ADD_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_ADD_V2_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+#include "backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+// GPU kernel for the TensorAddV2 operator: element-wise sum of two inputs that hold the
+// same number of elements. The element count is taken from the inferred shape of input 0
+// in Init(), and Launch() forwards the device buffers to the TensorAddV2 CUDA launcher.
+template <typename T>
+class TensorAddV2GpuKernel : public GpuKernel {
+ public:
+  TensorAddV2GpuKernel() : element_num_(1) {}
+  ~TensorAddV2GpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  // Enqueues y = x1 + x2 on the stream behind stream_ptr. inputs[0] and inputs[1] are the
+  // operands and outputs[0] receives the result; the workspace list is not used.
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *x1 = GetDeviceAddress<T>(inputs, 0);
+    T *x2 = GetDeviceAddress<T>(inputs, 1);
+    T *y = GetDeviceAddress<T>(outputs, 0);
+    TensorAddV2(element_num_, x1, x2, y, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  // Computes the element count as the product of the dimensions of input 0 and fills the
+  // size lists. State is reset first so that a repeated Init() does not keep multiplying
+  // into the previous count or append duplicate entries to the size lists.
+  bool Init(const CNodePtr &kernel_node) override {
+    element_num_ = 1;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (const auto &dim : shape) {
+      element_num_ *= dim;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  // Registers the byte size of both inputs and of the output (element count * sizeof(T)).
+  void InitSizeLists() override {
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    input_size_list_.push_back(element_num_ * sizeof(T));
+    output_size_list_.push_back(element_num_ * sizeof(T));
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  size_t element_num_;  // number of elements in each input and in the output
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_ADD_V2_GPU_KERNEL_H_
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 
1e9a0b5713d..1f015026e4f 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -54,7 +54,7 @@
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR,
-                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
+                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorAddV2)
 
 from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
                          RandomCategorical, StandardLaplace, Multinomial)
@@ -382,6 +382,7 @@
     "Pull",
     "ReLUV2",
     'SparseToDense',
+    "TensorAddV2",
 ]
 
 __all__.sort()
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index e294bb92e46..9f2c8b48894 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -113,6 +113,40 @@ def infer_dtype(self, x1_type, x2_type):
         return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name)
 
 
+class TensorAddV2(PrimitiveWithInfer):
+    """
+    Adds two input tensors element-wise.
+
+    The two inputs must have exactly the same shape (no broadcasting) and both must be
+    float32 tensors.
+
+    Inputs:
+        - **x1** (Tensor) - The first operand, of type float32.
+        - **x2** (Tensor) - The second operand, with the same shape and type as `x1`.
+
+    Outputs:
+        Tensor, with the same shape and data type as `x1`.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize TensorAddV2 and name its inputs and output."""
+        self.init_prim_io_names(inputs=["x1", "x2"], outputs=["y"])
+
+    def infer_shape(self, x1_shape, x2_shape):
+        """Require identical input shapes and return that shape for the output."""
+        validator.check_int(len(x1_shape), len(x2_shape), Rel.EQ, "input dims", self.name)
+        for x1_dim, x2_dim in zip(x1_shape, x2_shape):
+            validator.check_int(x1_dim, x2_dim, Rel.EQ, "input_shape", self.name)
+        return x1_shape
+
+    def infer_dtype(self, x1_dtype, x2_dtype):
+        """Require float32 tensors for both inputs and return the dtype of `x1`."""
+        for arg_name, arg_dtype in (("x1_dtype", x1_dtype), ("x2_dtype", x2_dtype)):
+            validator.check_tensor_type_same({arg_name: arg_dtype}, [mstype.float32], self.name)
+        return x1_dtype
+
+
 class TensorAdd(_MathBinaryOp):
     """
     Adds two input tensors element-wise.