Skip to content

Commit

Permalink
fix comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
guyang3532 committed Nov 7, 2023
1 parent ba4fa1e commit 8c7ce15
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 29 deletions.
2 changes: 1 addition & 1 deletion orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef{"FlattenAndUnpad", kMSDomain, 1},
{GO(0), I(1)},
{GI(0), IA("No_use")})};
{GI(0), IA("Unflatten_dims")})};
}

IMPLEMENT_GRADIENT_BUILDER(GetFlattenAndUnpadGradient) {
Expand Down
6 changes: 3 additions & 3 deletions orttraining/orttraining/core/graph/training_op_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4773,11 +4773,11 @@ Return true if all elements are true and false otherwise.
.SinceVersion(1)
.SetDoc(
"FlattenAndUnpad operator flattens the first two dims of input tensor, and unpad according to given indices."
"This is used by padding elimination graph transformers.")
.Input(0, "input", "input data of rank N, shape is [M1, M2, d2, ..., dN]", "T")
"This is used by padding elimination graph transformer.")
.Input(0, "input", "input data of rank N + 1, shape is [M1, M2, d2, ..., dN]", "T")
.Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).",

Check warning on line 4778 in orttraining/orttraining/core/graph/training_op_defs.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/core/graph/training_op_defs.cc#L4778

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
orttraining/orttraining/core/graph/training_op_defs.cc:4778:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
"T_INT")
.Output(0, "output", "output data of rank N-1, [d1, d2, ..., dN]", "T")
.Output(0, "output", "output data of rank N, [d1, d2, ..., dN]", "T")
.Output(1, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
.TypeConstraint(
"T_INT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace test {

#if defined(USE_CUDA) || defined(USE_ROCM)

TEST(FlattenAndUnpadTest, Int32Type1D) {
TEST(FlattenAndUnpadTest, Int32Type2D) {
std::vector<int32_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
Expand All @@ -25,7 +25,7 @@ TEST(FlattenAndUnpadTest, Int32Type1D) {
test.Run();
}

TEST(FlattenAndUnpadTest, Int32Type2D) {
TEST(FlattenAndUnpadTest, Int32Type3D) {
std::vector<int32_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
4, 5, 6, 7, 8, 9, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 4};
Expand All @@ -41,7 +41,7 @@ TEST(FlattenAndUnpadTest, Int32Type2D) {
test.Run();
}

TEST(FlattenAndUnpadTest, Int64Type1D) {
TEST(FlattenAndUnpadTest, Int64Type2D) {
std::vector<int64_t> input = {1, 1, 3, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
Expand All @@ -57,7 +57,7 @@ TEST(FlattenAndUnpadTest, Int64Type1D) {
test.Run();
}

TEST(FlattenAndUnpadTest, Int64Type2D) {
TEST(FlattenAndUnpadTest, Int64Type3D) {
std::vector<int64_t> input = {0, 0, 0, 1, 2, 3, 0, 0, 0,
4, 5, 6, 7, 8, 9, 0, 0, 0};
std::vector<int64_t> indices = {1, 3, 4};
Expand All @@ -73,7 +73,7 @@ TEST(FlattenAndUnpadTest, Int64Type2D) {
test.Run();
}

TEST(FlattenAndUnpadTest, FloatType1D) {
TEST(FlattenAndUnpadTest, FloatType2D) {
std::vector<float> input = {1.0f, 1.0f, 3.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
Expand All @@ -89,7 +89,7 @@ TEST(FlattenAndUnpadTest, FloatType1D) {
test.Run();
}

TEST(FlattenAndUnpadTest, FloatType2D) {
TEST(FlattenAndUnpadTest, FloatType3D) {
std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 4};
Expand All @@ -105,7 +105,7 @@ TEST(FlattenAndUnpadTest, FloatType2D) {
test.Run();
}

TEST(FlattenAndUnpadTest, MLFloat16Type1D) {
TEST(FlattenAndUnpadTest, MLFloat16Type2D) {
std::vector<float> input = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f,
0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11};
Expand All @@ -128,7 +128,7 @@ TEST(FlattenAndUnpadTest, MLFloat16Type1D) {
test.Run();
}

TEST(FlattenAndUnpadTest, MLFloat16Type2D) {
TEST(FlattenAndUnpadTest, MLFloat16Type3D) {
std::vector<float> input = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
std::vector<int64_t> indices = {1, 3, 4};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ struct FlattenAndUnpadFunctor {
Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const {
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* indices_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(input_tensor->Shape().NumDimensions() >= 2,
"input_tensor tensor must have at least 2 dimensions.", input_tensor->Shape().NumDimensions());
ORT_ENFORCE(indices_tensor->Shape().NumDimensions() == 1,
"indices_tensor tensor must be 1-D.", indices_tensor->Shape().NumDimensions());

const auto& input_shape = input_tensor->Shape();
std::vector<int64_t> output_shape_vec;
output_shape_vec.reserve(input_shape.NumDimensions() - 1);
output_shape_vec.push_back(indices_tensor->Shape()[0]);
const auto& input_shape = input_tensor->Shape();
int64_t element_stride = 1;
for (size_t i = 2; i < input_shape.NumDimensions(); ++i) {
output_shape_vec.push_back(input_shape[i]);
Expand All @@ -62,6 +65,7 @@ Status FlattenAndUnpad::ComputeInternal(OpKernelContext* context) const {
Tensor* output_tensor = context->Output(0, output_shape);

std::vector<int64_t> unflatten_dims_vec;

Check warning on line 67 in orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc#L67

Add #include <vector> for vector<> [build/include_what_you_use] [4]
Raw output
orttraining/orttraining/training_ops/cuda/tensor/flatten_and_unpad.cc:67:  Add #include <vector> for vector<>  [build/include_what_you_use] [4]
unflatten_dims_vec.reserve(2);
unflatten_dims_vec.push_back(input_shape[0]);
unflatten_dims_vec.push_back(input_shape[1]);
const int64_t index_value_upper_bound = input_shape[0] * input_shape[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void FlattenAndUnpadImpl(cudaStream_t stream,
output_data);
}

#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \
#define FLATTEN_AND_UNPAD_FROM_MASK_IMPL(T) \
template void FlattenAndUnpadImpl<T>(cudaStream_t stream, \
const int64_t total_element_count, \
const fast_divmod output_element_stride_fdm, \
Expand All @@ -70,14 +70,14 @@ void FlattenAndUnpadImpl(cudaStream_t stream,
const int64_t* indices_data, \
T* output_data);

SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(int32_t)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(int64_t)
FLATTEN_AND_UNPAD_FROM_MASK_IMPL(float)
FLATTEN_AND_UNPAD_FROM_MASK_IMPL(double)
FLATTEN_AND_UNPAD_FROM_MASK_IMPL(half)
FLATTEN_AND_UNPAD_FROM_MASK_IMPL(BFloat16)
FLATTEN_AND_UNPAD_FROM_MASK_IMPL(int32_t)
FLATTEN_AND_UNPAD_FROM_MASK_IMPL(int64_t)

#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
#undef FLATTEN_AND_UNPAD_FROM_MASK_IMPL

} // namespace cuda
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void PadAndUnflattenImpl(cudaStream_t stream,
output_data);
}

#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T) \
#define PAD_AND_UNFLATTEN_FROM_MASK_IMPL(T) \
template void PadAndUnflattenImpl<T>(cudaStream_t stream, \
const int64_t total_element_count, \
const fast_divmod output_element_stride_fdm, \
Expand All @@ -70,12 +70,12 @@ void PadAndUnflattenImpl(cudaStream_t stream,
const int64_t* indices_data, \
T* output_data);

SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
PAD_AND_UNFLATTEN_FROM_MASK_IMPL(float)
PAD_AND_UNFLATTEN_FROM_MASK_IMPL(double)
PAD_AND_UNFLATTEN_FROM_MASK_IMPL(half)
PAD_AND_UNFLATTEN_FROM_MASK_IMPL(BFloat16)

#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
#undef PAD_AND_UNFLATTEN_FROM_MASK_IMPL

} // namespace cuda
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad);

#if defined(ORT_USE_NCCL) || defined(USE_MPI)
// P2P communication operators.
Expand Down
Expand Up @@ -391,10 +391,10 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, FlattenAndUnpad)>,

// P2P communication operators.
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
Expand Down

0 comments on commit 8c7ce15

Please sign in to comment.