Skip to content

Commit

Permalink
shape constraints
Browse files: browse the repository at this point in the history
  • Loading branch information
chenfucn committed Nov 20, 2023
1 parent 2a85532 commit baf6272
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
14 changes: 4 additions & 10 deletions onnxruntime/core/mickey/blk_q4/prepack_sm80.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@

#pragma once

// #include <cuda.h>
// #include <vector_types.h>
// #include "cutlass/cutlass.h"
// #include "cutlass/matrix_shape.h"
// #include "cutlass/util/host_tensor.h"
// #include "cutlass/util/reference/host/tensor_compare.h"
// #include "cutlass/util/reference/host/tensor_copy.h"
// #include "cutlass/util/reference/host/tensor_fill.h"

#include "core/common/common.h"
#include "core/util/matrix_layout.h"
Expand Down Expand Up @@ -66,7 +58,7 @@ struct BlockwiseQuantization {
* into one int8
*/
static inline auto get_quant_weights_shape(int rows, int columns) {
return make_Position((rows + 1) / 2, columns);
return make_Position(rows / 2, columns);
}

static inline auto get_quant_meta_shape(int rows, int columns) {
Expand Down Expand Up @@ -107,7 +99,9 @@ struct BlockwiseQuantization {
const gsl::span<uint8_t const>& weights, // <- int4 weights, column major
const gsl::span<uint8_t>& weights_prepacked // <- int4 prepacked weights tensor, same size buffer
) {
ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0,
ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 &&
(rows % QuantBlocking::kRow) == 0 &&
(columns % QuantBlocking::kColumn) == 0,
"Does not support odd number of rows or columns!");
ORT_ENFORCE(weights.size() == size_t(rows * columns / 2),
"Weight tensor shape mismatch!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ void testPrepack(int rows, int columns, bool has_offset = true) {
}

int q_rows, q_cols;
MlasBlockwiseQuantizedShape<ElementT>(
MlasBlockwiseQuantizedShape<ElementT, 4>(
block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols);
// To be exact, q_rows is padded to a multiple of block_size; handle that padding here once unusual shapes need to be supported.
EXPECT_EQ(q_rows, q_weight_shape[0]);
Expand Down

0 comments on commit baf6272

Please sign in to comment.