
Commit

lint, thread
Signed-off-by: Liqun Fu <[email protected]>
liqunfu committed Oct 3, 2023
1 parent 6634647 commit 6c6dc97
Showing 4 changed files with 85 additions and 75 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -25,6 +25,7 @@ Do not modify directly.*
|||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
+|AffineGrid|*in* theta:**T1**<br> *in* size:**T2**<br> *out* grid:**T1**|20+|**T1** = tensor(float)<br/> **T2** = tensor(int64)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
122 changes: 66 additions & 56 deletions onnxruntime/core/providers/cpu/tensor/affine_grid.cc
@@ -13,95 +13,99 @@

namespace onnxruntime {

-#define REGISTER_KERNEL_TYPED(T) \
-ONNX_CPU_OPERATOR_TYPED_KERNEL( \
-AffineGrid, \
-20, \
-T, \
-KernelDefBuilder() \
-.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
-.TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>()), \
+#define REGISTER_KERNEL_TYPED(T) \
+ONNX_CPU_OPERATOR_TYPED_KERNEL( \
+AffineGrid, \
+20, \
+T, \
+KernelDefBuilder() \
+.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
+.TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>()), \
AffineGrid<T>);

REGISTER_KERNEL_TYPED(float)

void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matrix<float, Eigen::Dynamic, 2>& base_grid) {

[cpplint warning] onnxruntime/core/providers/cpu/tensor/affine_grid.cc:28: Lines should be <= 120 characters long [whitespace/line_length] [2]
-Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(W, -1, 1);
+Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(W), -1, 1);
if (!align_corners) {
row_vec = row_vec * (W - 1) / W;
}
-Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(H, -1, 1);
+Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(H), -1, 1);
if (!align_corners) {
col_vec = col_vec * (H - 1) / H;
}

base_grid.resize(H * W, 2);
-for (int j = 0; j < H; j++) {
-for (int i = 0; i < W; i++) {
+for (Eigen::Index j = 0; j < H; j++) {
+for (Eigen::Index i = 0; i < W; i++) {
base_grid.row(j * W + i) << row_vec(i), col_vec(j);
}
}
}
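For reference, the base grid this helper builds is just the Cartesian product of two linspaces over [-1, 1], shrunk by a (size - 1) / size factor when align_corners is false. A NumPy sketch of the same construction (the name base_grid_2d is ours, not part of the commit):

    import numpy as np

    def base_grid_2d(H, W, align_corners):
        # W (resp. H) evenly spaced points from -1 to 1 inclusive,
        # mirroring Eigen::VectorXf::LinSpaced(W, -1, 1)
        x = np.linspace(-1.0, 1.0, W, dtype=np.float32)
        y = np.linspace(-1.0, 1.0, H, dtype=np.float32)
        if not align_corners:
            # pull the samples in from the corners toward pixel centers
            x *= (W - 1) / W
            y *= (H - 1) / H
        # row j * W + i holds (x_i, y_j): x varies fastest, as in the loop above
        return np.stack(np.meshgrid(x, y), axis=-1).reshape(H * W, 2)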

void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, Eigen::Matrix<float, Eigen::Dynamic, 3>& base_grid) {

[cpplint warning] onnxruntime/core/providers/cpu/tensor/affine_grid.cc:46: Lines should be <= 120 characters long [whitespace/line_length] [2]
-Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(W, -1, 1);
+Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(W), -1, 1);
if (!align_corners) {
row_vec = row_vec * (W - 1) / W;
}
-Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(H, -1, 1);
+Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(H), -1, 1);
if (!align_corners) {
col_vec = col_vec * (H - 1) / H;
}

-Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(D, -1, 1);
+Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(static_cast<Eigen::Index>(D), -1, 1);
if (!align_corners) {
slice_vec = slice_vec * (D - 1) / D;
}

base_grid.resize(D * H * W, 3);
-for (int k = 0; k < D; k++) {
-for (int j = 0; j < H; j++) {
-for (int i = 0; i < W; i++) {
+for (Eigen::Index k = 0; k < D; k++) {
+for (Eigen::Index j = 0; j < H; j++) {
+for (Eigen::Index i = 0; i < W; i++) {
base_grid.row(k * H * W + j * W + i) << row_vec(i), col_vec(j), slice_vec(k);
}
}
}
}

-void affine_grid_generator_2d(const Tensor* theta, const Eigen::Matrix<float, 2, Eigen::Dynamic>& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) {
-const Eigen::StorageOptions option = Eigen::RowMajor;
-auto theta_batch_offset = batch_num * 2 * 3;
-const float* theta_data = theta->Data<float>() + theta_batch_offset;
-const Eigen::Matrix<float, 2, 2, option> theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}};
-const Eigen::Array<float, 2, 1> theta_T(theta_data[2], theta_data[5]);
-
-auto grid_batch_offset = batch_num * H * W * 2;
-float* grid_data = grid->MutableData<float>() + grid_batch_offset;
-Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 2, option>> grid_matrix(grid_data, narrow<size_t>(H * W), 2);
-grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose();
-}
+template <typename T>
+struct AffineGridGenerator2D {
+void operator()(const Tensor* theta, const Eigen::Matrix<T, 2, Eigen::Dynamic>& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) const {

[cpplint warning] onnxruntime/core/providers/cpu/tensor/affine_grid.cc:72: Lines should be <= 120 characters long [whitespace/line_length] [2]
+const Eigen::StorageOptions option = Eigen::RowMajor;
+auto theta_batch_offset = batch_num * 2 * 3;
+const T* theta_data = theta->Data<T>() + theta_batch_offset;
+const Eigen::Matrix<T, 2, 2, option> theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}};
+const Eigen::Array<T, 2, 1> theta_T(theta_data[2], theta_data[5]);
+
+auto grid_batch_offset = batch_num * H * W * 2;
+T* grid_data = grid->MutableData<T>() + grid_batch_offset;
+Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 2, option>> grid_matrix(grid_data, narrow<size_t>(H * W), 2);
+grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose();
+}
+};
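Behaviorally the functor is unchanged from the free function it replaces: the 2x3 theta for one batch splits into a 2x2 linear block theta_R and a translation column theta_T, and the whole base grid is mapped in one product. The same step in NumPy (a sketch; affine_grid_2d_batch is a hypothetical name):

    def affine_grid_2d_batch(theta_n, base_grid):
        # theta_n: (2, 3) affine matrix for one batch; base_grid: (H*W, 2)
        R, T = theta_n[:, :2], theta_n[:, 2:]  # linear block and translation column
        # equivalent of ((theta_R * base_grid_transposed).colwise() + theta_T).transpose()
        return (R @ base_grid.T + T).T         # (H*W, 2)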

-void affine_grid_generator_3d(const Tensor* theta, const Eigen::Matrix<float, 3, Eigen::Dynamic>& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) {
-const Eigen::StorageOptions option = Eigen::RowMajor;
-auto theta_batch_offset = batch_num * 3 * 4;
-const float* theta_data = theta->Data<float>() + theta_batch_offset;
-const Eigen::Matrix<float, 3, 3, option> theta_R{
-{theta_data[0], theta_data[1], theta_data[2]},
-{theta_data[4], theta_data[5], theta_data[6]},
-{theta_data[8], theta_data[9], theta_data[10]}
-};
-const Eigen::Array<float, 3, 1> theta_T(theta_data[3], theta_data[7], theta_data[11]);
-
-auto grid_batch_offset = batch_num * D * H * W * 3;
-float* grid_data = grid->MutableData<float>() + grid_batch_offset;
-Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 3, option>> grid_matrix(grid_data, narrow<size_t>(D * H * W), 3);
-grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose();
-}
+template <typename T>
+struct AffineGridGenerator3D {
+void operator()(const Tensor* theta, const Eigen::Matrix<float, 3, Eigen::Dynamic>& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) {

[cpplint warning] onnxruntime/core/providers/cpu/tensor/affine_grid.cc:88: Lines should be <= 120 characters long [whitespace/line_length] [2]
+const Eigen::StorageOptions option = Eigen::RowMajor;
+auto theta_batch_offset = batch_num * 3 * 4;
+const float* theta_data = theta->Data<float>() + theta_batch_offset;
+const Eigen::Matrix<float, 3, 3, option> theta_R{
+{theta_data[0], theta_data[1], theta_data[2]},
+{theta_data[4], theta_data[5], theta_data[6]},
+{theta_data[8], theta_data[9], theta_data[10]}};
+const Eigen::Array<float, 3, 1> theta_T(theta_data[3], theta_data[7], theta_data[11]);
+
+auto grid_batch_offset = batch_num * D * H * W * 3;
+float* grid_data = grid->MutableData<float>() + grid_batch_offset;
+Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 3, option>> grid_matrix(grid_data, narrow<size_t>(D * H * W), 3);
+grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose();
+}
+};
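The 3D functor is the same pattern one dimension up: theta is 3x4, R the leading 3x3 block, T the last column. Sketch under the same assumptions:

    def affine_grid_3d_batch(theta_n, base_grid):
        # theta_n: (3, 4) affine matrix for one batch; base_grid: (D*H*W, 3)
        R, T = theta_n[:, :3], theta_n[:, 3:]
        return (R @ base_grid.T + T).T         # (D*H*W, 3)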

template <typename T>
Status AffineGrid<T>::Compute(OpKernelContext* context) const {
const Tensor* theta = context->Input<Tensor>(0);
-//const auto elem_type = theta.GetElementType();
+const auto elem_type = theta->GetElementType();
const TensorShape& theta_shape = theta->Shape();
if (theta_shape.NumDimensions() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Input theta tensor dimension is not 3");
@@ -113,17 +117,20 @@ Status AffineGrid<T>::Compute(OpKernelContext* context) const {

if (size_shape.GetDims()[0] == 4 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) {
int64_t N = size_data[0], H = size_data[2], W = size_data[3];

TensorShape grid_shape{N, H, W, 2};
auto grid = context->Output(0, grid_shape);

Eigen::Matrix<float, Eigen::Dynamic, 2> base_grid;
generate_base_grid_2d(H, W, align_corners_, base_grid);
Eigen::Matrix<float, 2, Eigen::Dynamic> base_grid_transposed = base_grid.transpose();

-for (int64_t batch_num = 0; batch_num < N; batch_num++) {
-affine_grid_generator_2d(theta, base_grid_transposed, batch_num, H, W, grid);
-}
+std::function<void(ptrdiff_t)> fn = [elem_type, theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) {
+utils::MLTypeCallDispatcher<float> t_disp(elem_type);
+t_disp.Invoke<AffineGridGenerator2D>(theta, base_grid_transposed, batch_num, H, W, grid);
+};
+
+concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow<size_t>(N), std::move(fn), 0);
} else if (size_shape.GetDims()[0] == 5 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) {

[cpplint warning] onnxruntime/core/providers/cpu/tensor/affine_grid.cc:134: Lines should be <= 120 characters long [whitespace/line_length] [2]
int64_t N = size_data[0], D = size_data[2], H = size_data[3], W = size_data[4];

Expand All @@ -134,9 +141,12 @@ Status AffineGrid<T>::Compute(OpKernelContext* context) const {
generate_base_grid_3d(D, H, W, align_corners_, base_grid);
Eigen::Matrix<float, 3, Eigen::Dynamic> base_grid_transposed = base_grid.transpose();

-for (int64_t batch_num = 0; batch_num < N; batch_num++) {
-affine_grid_generator_3d(theta, base_grid_transposed, batch_num, D, H, W, grid);
-}
+std::function<void(ptrdiff_t)> fn = [elem_type, theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) {
+utils::MLTypeCallDispatcher<float> t_disp(elem_type);
+t_disp.Invoke<AffineGridGenerator3D>(theta, base_grid_transposed, batch_num, D, H, W, grid);
+};
+
+concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow<size_t>(N), std::move(fn), 0);

[cpplint warning] onnxruntime/core/providers/cpu/tensor/affine_grid.cc:149: Add #include <utility> for move [build/include_what_you_use] [4]
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Invalidate size - length of size shall be 4 or 5.");

[cpplint warning] onnxruntime/core/providers/cpu/tensor/affine_grid.cc:151: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
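The substantive change in Compute is that the serial per-batch loop is now handed to the operator thread pool: TryBatchParallelFor invokes the lambda once per batch index, possibly concurrently, and MLTypeCallDispatcher (instantiated for float only at this point) selects the functor specialization from the runtime element type. A rough Python analogue of the batch-parallel pattern, purely illustrative and not ONNX Runtime API:

    from concurrent.futures import ThreadPoolExecutor

    def batch_parallel_for(n, fn):
        # analogue of concurrency::ThreadPool::TryBatchParallelFor(tp, n, fn, 0):
        # run fn(batch_num) for every batch_num in [0, n), possibly in parallel
        with ThreadPoolExecutor() as pool:
            list(pool.map(fn, range(n)))  # list() propagates any worker exception

    # usage sketch: each task fills one batch's slice of the output
    def compute_2d(theta, base_grid, N):
        out = [None] * N
        def fn(b):
            out[b] = affine_grid_2d_batch(theta[b], base_grid)
        batch_parallel_for(N, fn)
        return out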
20 changes: 10 additions & 10 deletions onnxruntime/core/providers/cpu/tensor/affine_grid.h
@@ -10,18 +10,18 @@ namespace onnxruntime {

template <typename T>
class AffineGrid final : public OpKernel {
-public:
-AffineGrid(const OpKernelInfo& info) : OpKernel(info) {
-int64_t align_corners = info.GetAttrOrDefault<int64_t>("align_corners", 0);
-align_corners_ = (align_corners != 0);
-}
+public:
+AffineGrid(const OpKernelInfo& info) : OpKernel(info) {
+int64_t align_corners = info.GetAttrOrDefault<int64_t>("align_corners", 0);
+align_corners_ = (align_corners != 0);
+}

-Status Compute(OpKernelContext* context) const override;
+Status Compute(OpKernelContext* context) const override;

-private:
-bool align_corners_;
-int64_t dtype_;
-int64_t k_;
+private:
+bool align_corners_;
+int64_t dtype_;
+int64_t k_;
};

} // namespace onnxruntime
17 changes: 8 additions & 9 deletions onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py
@@ -21,7 +21,7 @@
for angle, translation, scale in zip(angles, translations, scales):
for size in sizes:
theta = np.array([], dtype=np.float32)
-for _n in range(size[0]):
+for _ in range(size[0]):
angle_radian = (angle / 180.0) * np.pi
theta = np.append(
theta,
@@ -68,28 +68,27 @@
for angle, translation, scale in zip(angles, translations, scales):
for size in sizes:
theta = np.array([], dtype=np.float32)
-for _n in range(size[0]):
+for _ in range(size[0]):
angle_radian_x = (angle[0] / 180.0) * np.pi
angle_radian_y = (angle[1] / 180.0) * np.pi
-rotMatrix_x = np.array(
+rot_matrix_x = np.array(
[
[1, 0, 0],
[0, np.cos(angle_radian_x), -np.sin(angle_radian_x)],
[0, np.sin(angle_radian_x), np.cos(angle_radian_x)],
]
)
-rotMatrix_y = np.array(
+rot_matrix_y = np.array(
[
[np.cos(angle_radian_y), 0, np.sin(angle_radian_y)],
[0, 1, 0],
[-np.sin(angle_radian_y), 0, np.cos(angle_radian_y)],
]
)
-rotMatrix = np.matmul(rotMatrix_x, rotMatrix_y)
-rotMatrix = rotMatrix * scale.reshape(3, 1)
-translation = np.reshape(translation, (3, 1))
-rotMatrix = np.append(rotMatrix, translation, axis=1)
-theta = np.append(theta, rotMatrix.flatten())
+rot_matrix = np.matmul(rot_matrix_x, rot_matrix_y)
+rot_matrix = rot_matrix * scale.reshape(3, 1)
+rot_matrix = np.append(rot_matrix, np.reshape(translation, (3, 1)), axis=1)
+theta = np.append(theta, rot_matrix.flatten())
theta = theta.reshape(size[0], 3, 4)
theta = torch.Tensor(theta)
grid = affine_grid(theta, size, align_corners=align_corners)
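Since the generator scripts use torch's affine_grid as ground truth, the NumPy sketches above can be sanity-checked against it. The values below are arbitrary and only illustrate the comparison:

    import numpy as np
    import torch
    from torch.nn.functional import affine_grid

    theta = torch.tensor([[[1.0, 0.0, 0.1],
                           [0.0, 1.0, 0.2]]])  # N=1, one 2x3 affine matrix
    ref = affine_grid(theta, size=(1, 1, 4, 5), align_corners=False)  # (1, 4, 5, 2)
    ours = affine_grid_2d_batch(theta[0].numpy(), base_grid_2d(4, 5, False))
    assert np.allclose(ref.numpy().reshape(-1, 2), ours, atol=1e-6)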
