Skip to content

Commit

Permalink
don't change shape on chunk quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
hseok-oh committed Aug 23, 2024
1 parent b782c53 commit f4ce5ba
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 32 deletions.
5 changes: 2 additions & 3 deletions runtime/onert/core/include/util/ShapeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,10 @@ ir::Shape inferExpandDimsShape(const ir::Shape &in_shape, int32_t axis);

template <typename T> ir::Shape inferFillShape(const ir::Shape &fill_shape, const T *shape_buf);

-ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
-                                   const bool chunk_kernel);
+ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape);

 ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
-                           int rank, const bool chunk_input);
+                           int rank);

ir::Shape inferOnehotShape(const ir::Shape &input_shape, const int depth, int axis);

Expand Down
12 changes: 3 additions & 9 deletions runtime/onert/core/src/compiler/StaticShapeInferer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,9 @@ void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
// get mutable output operand
const auto output_idx = op.getOutputs().at(0);
ir::Operand &output = operands.at(output_idx);
-  const auto ker_type = ker.typeInfo().type();
-  const bool chunk_ker = (ker_type == ir::DataType::QUANT_UINT4_SYMM_PER_CHUNK ||
-                          ker_type == ir::DataType::QUANT_INT8_SYMM_PER_CHUNK);
   // re-sizing output shape
   ir::Shape new_shape =
-    shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape(), chunk_ker);
+    shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
output.info().shape(new_shape);
}

Expand Down Expand Up @@ -608,11 +605,8 @@ void StaticShapeInferer::visit(const ir::operation::Gather &op)
assert(0 <= axis && axis < rank);

// re-sizing output shape
-  const auto input_type = input.typeInfo().type();
-  const bool chunk_input = (input_type == ir::DataType::QUANT_UINT4_SYMM_PER_CHUNK ||
-                            input_type == ir::DataType::QUANT_INT8_SYMM_PER_CHUNK);
-  ir::Shape new_shape = shape_inference::inferGatherShape(
-    input.info().shape(), indices.info().shape(), axis, rank, chunk_input);
+  ir::Shape new_shape =
+    shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
output.info().shape(new_shape);
}

Expand Down
12 changes: 2 additions & 10 deletions runtime/onert/core/src/exec/DynamicShapeInferer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,7 @@ void DynamicShapeInferer::visit(const ir::operation::FullyConnected &op)
auto input_shape = input->getShape();
auto ker_shape = ker->getShape();

-  const auto ker_type = ker->data_type();
-  const bool chunk_ker = (ker_type == ir::DataType::QUANT_UINT4_SYMM_PER_CHUNK ||
-                          ker_type == ir::DataType::QUANT_INT8_SYMM_PER_CHUNK);
-  ir::Shape new_shape =
-    shape_inference::inferFullyConnectedShape(input_shape, ker_shape, chunk_ker);
+  ir::Shape new_shape = shape_inference::inferFullyConnectedShape(input_shape, ker_shape);

auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
Expand Down Expand Up @@ -483,11 +479,7 @@ void DynamicShapeInferer::visit(const ir::operation::Gather &op)

assert(0 <= axis && axis < rank);

-  const auto input_type = input->data_type();
-  const bool chunk_input = (input_type == ir::DataType::QUANT_UINT4_SYMM_PER_CHUNK ||
-                            input_type == ir::DataType::QUANT_INT8_SYMM_PER_CHUNK);
-  ir::Shape new_shape =
-    shape_inference::inferGatherShape(input_shape, indices_shape, axis, rank, chunk_input);
+  ir::Shape new_shape = shape_inference::inferGatherShape(input_shape, indices_shape, axis, rank);

auto output_ind = op.getOutputs().at(0);
auto output = _tensor_registry->getITensor(output_ind);
Expand Down
11 changes: 4 additions & 7 deletions runtime/onert/core/src/util/ShapeInference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,15 +394,14 @@ template <typename T> ir::Shape inferFillShape(const ir::Shape &fill_shape, cons
template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int32_t *shape_buf);
template ir::Shape inferFillShape(const ir::Shape &fill_shape, const int64_t *shape_buf);

-ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
-                                   bool chunk_kernel)
+ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape)
{
assert(in_shape.rank() >= 2);
assert(ker_shape.rank() == 2);

const auto input_size_with_batch = in_shape.num_elements();
const auto num_units = ker_shape.dim(0);
const auto input_size = chunk_kernel ? ker_shape.dim(1) * 32 : ker_shape.dim(1);
const auto input_size = ker_shape.dim(1);
const auto batch_size = input_size_with_batch / input_size;
assert(input_size_with_batch % input_size == 0);

Expand Down Expand Up @@ -457,7 +456,7 @@ ir::Shape inferBCQGatherShape(const ir::Shape &indices_shape, const ir::Shape &c
}

 ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
-                           int rank, const bool chunk_input)
+                           int rank)
{
ir::Shape out_shape;

Expand All @@ -474,9 +473,7 @@ ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indice
}
else
{
-    auto output_dim =
-      (chunk_input && idx == rank - 1) ? input_shape.dim(idx) * 32 : input_shape.dim(idx);
-    out_shape.append(output_dim);
+    out_shape.append(input_shape.dim(idx));
}
}

Expand Down
5 changes: 2 additions & 3 deletions runtime/onert/core/src/util/ShapeInference.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,7 @@ TEST(ShapeInference, FullyConnected)
{
Shape in_shape{3, 4, 5, 6};
Shape ker_shape{3, 10};
-  auto infered_out_shape =
-    onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape, false);
+  auto infered_out_shape = onert::shape_inference::inferFullyConnectedShape(in_shape, ker_shape);

ASSERT_EQ(infered_out_shape.rank(), 2);
ASSERT_EQ(infered_out_shape.dim(0), 36);
Expand Down Expand Up @@ -421,7 +420,7 @@ TEST(ShapeInference, Gather)
{
auto check = [&](Shape &input, Shape &indices, Shape &expected, int32_t axis) {
int rank = input.rank();
-    auto actual = onert::shape_inference::inferGatherShape(input, indices, axis, rank, false);
+    auto actual = onert::shape_inference::inferGatherShape(input, indices, axis, rank);

ASSERT_EQ(actual.rank(), expected.rank());

Expand Down

0 comments on commit f4ce5ba

Please sign in to comment.