Skip to content

Commit

Permalink
Dtype compliance: split_copy
Browse files · Browse the repository at this point in the history
Reviewed By: SS-JIA

Differential Revision: D48318690

fbshipit-source-id: 935759301c9a4deada20550bb4a2de549646f9ac
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Aug 15, 2023
1 parent 88eda7b commit 76e2e12
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 27 deletions.
54 changes: 30 additions & 24 deletions kernels/portable/cpu/op_split_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ void check_args(

// Validate each output.
for (size_t i = 0; i < out.size(); ++i) {
// All output dtypes must match the input type.
// All output dtypes must be the same.
ET_CHECK_MSG(
out[i].scalar_type() == input.scalar_type(),
"out[%zu] dtype %hhd != input dtype %hhd",
out[i].scalar_type() == out[0].scalar_type(),
"out[%zu] dtype %hhd != out[0] dtype %hhd",
i,
out[i].scalar_type(),
input.scalar_type());
out[0].scalar_type());

// All outputs must have the same number of dimensions as the input.
ET_CHECK_MSG(
Expand Down Expand Up @@ -170,26 +170,32 @@ void split_copy_Tensor_out(

const size_t leading_dims = getLeadingDims(input, dim);
const size_t trailing_dims = getTrailingDims(input, dim);

const size_t element_size = input.element_size();
const size_t step = input.size(dim) * trailing_dims * element_size;

const char* input_data = input.const_data_ptr<char>();
for (size_t i = 0, e = out.size(); i < e; ++i) {
size_t num_bytes = out[i].size(dim) * trailing_dims * element_size;
if (num_bytes == 0) {
continue;
}

const char* src = input_data;
char* dest = out[i].mutable_data_ptr<char>();
for (size_t j = 0; j < leading_dims; ++j) {
memcpy(dest, src, num_bytes);
src += step;
dest += num_bytes;
}
input_data += num_bytes;
}
const size_t step = input.size(dim) * trailing_dims;

ScalarType in_type = input.scalar_type();
ScalarType out_type = out[0].scalar_type();

ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
const CTYPE_IN* input_data = input.const_data_ptr<CTYPE_IN>();
for (size_t i = 0, e = out.size(); i < e; ++i) {
size_t out_step = out[i].size(dim) * trailing_dims;
if (out_step == 0) {
continue;
}
const CTYPE_IN* src = input_data;
CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
for (size_t j = 0; j < leading_dims; ++j) {
for (size_t k = 0; k < out_step; ++k) {
dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
}
src += step;
dest += out_step;
}
input_data += out_step;
}
});
});
}

} // namespace native
Expand Down
4 changes: 1 addition & 3 deletions kernels/test/op_split_copy_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,7 @@ TEST(OpSplitCopyTensorOutTest, OutOfRangeDimsDie) {
}

TEST(OpSplitCopyTensorOutTest, DtypeMismatchDies) {
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
GTEST_SKIP() << "ATen kernel can handle dtype mismatch";
}
GTEST_SKIP() << "ATen kernel can handle dtype mismatch";
TensorFactory<ScalarType::Int> tf_int;
TensorListFactory<ScalarType::Int> tlf_int;
TensorListFactory<ScalarType::Float> tlf_float;
Expand Down

0 comments on commit 76e2e12

Please sign in to comment.