Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tt #94

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

Tt #94

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <stdio.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cuh"
template <typename T>
__global__ void TensorAddV2Kernel(const size_t element_num, const T* x1, const T* x2, T* y) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num; i += blockDim.x * gridDim.x) {
y[i] = x1[i] + x2[i];
}
}

template <typename T>
void TensorAddV2(const size_t &element_num, const T* x1, const T* x2, T* y, cudaStream_t cuda_stream) {
size_t thread_per_block = 256;
size_t block_per_grid = (element_num + thread_per_block -1)/thread_per_block;
TensorAddV2Kernel<<<block_per_grid, thread_per_block, 0, cuda_stream>>>(element_num,x1,x2,y);
return;
}

template void TensorAddV2(const size_t &element_num, const float* x1, const float* x2, float* y, cudaStream_t cuda_stream);
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TENSOR_ADD_V2_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TENSOR_ADD_V2_H_

#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void TensorAddV2(const size_t &element_num, const T* x1, const T* x2, T* y, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TENSOR_ADD_V2_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
TensorAddV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorAddV2GpuKernel, float)
} // namespace kernel
} // namespace mindspore
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under thea License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_ADD_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TENSOR_ADD_V2_GPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

#include "backend/kernel_compiler/gpu/cuda_impl/tensor_add_v2_impl.cuh"


namespace mindspore {
namespace kernel {
template <typename T>
class TensorAddV2GpuKernel : public GpuKernel {
public:
TensorAddV2GpuKernel()
: element_num_(1) {}
~TensorAddV2GpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *x1 = GetDeviceAddress<T>(inputs, 0);
T *x2 = GetDeviceAddress<T>(inputs, 1);
T *y = GetDeviceAddress<T>(outputs, 0);
TensorAddV2(element_num_, x1, x2, y, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for(size_t i = 0; i < shape.size(); i++){
element_num_ *= shape[i];
}
InitSizeLists();
return true;
}

protected:
void InitSizeLists() override {
input_size_list_.push_back(element_num_ * sizeof(T));
input_size_list_.push_back(element_num_ * sizeof(T));
output_size_list_.push_back(element_num_ * sizeof(T));
}

private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

size_t element_num_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ADDN_GPU_KERNEL_H_
3 changes: 2 additions & 1 deletion mindspore/ops/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, IFMR,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan, TensorAddV2)

from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial)
Expand Down Expand Up @@ -382,6 +382,7 @@
"Pull",
"ReLUV2",
'SparseToDense',
"TensorAddV2",
]

__all__.sort()
20 changes: 20 additions & 0 deletions mindspore/ops/operations/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def infer_dtype(self, x1_type, x2_type):
return _BitwiseBinaryOp._check_bitwise_op_input_type(x1_type, x2_type, self.name)


class TensorAddV2(PrimitiveWithInfer):
"""
add two input tensors
"""
@prim_attr_register
def __init__(self):
"""init tensoraddv2"""
self.init_prim_io_names(inputs=["x1","x2"],outputs=["y"])

def infer_shape(self, x1_shape, x2_shape):
validator.check_int(len(x1_shape), len(x2_shape), Rel.EQ, "input dims", self.name)
for i in range(len(x1_shape)):
validator.check_int(x1_shape[i], x2_shape[i], Rel.EQ, "input_shape", self.name)
return x1_shape

def infer_dtype(self, x1_dtype, x2_dtype):
validator.check_tensor_type_same({"x1_dtype":x1_dtype}, [mstype.float32], self.name)
validator.check_tensor_type_same({"x2_dtype":x2_dtype}, [mstype.float32], self.name)
return x1_dtype

class TensorAdd(_MathBinaryOp):
"""
Adds two input tensors element-wise.
Expand Down