From b33216be4c02adfbbdeac2fd30ddc55f673eda3d Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Fri, 12 Apr 2024 11:03:05 -0700 Subject: [PATCH] [JS/WebGPU] Improve MatMulNBits perf (#19974) ### Description Improve performance using shared memory ### Motivation and Context --- js/web/lib/wasm/jsep/init.ts | 11 + .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 252 +++-- js/web/lib/wasm/jsep/webgpu/types.ts | 2 + js/web/test/data/ops/matmulnbits.jsonc | 960 +++++++++++++++++- 4 files changed, 1097 insertions(+), 128 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index adcaa145cdca8..1ceae2394f462 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -92,6 +92,17 @@ class ComputeContextImpl implements ComputeContext { this.inputs = inputs; } + getMaxComputeWorkgroupSizes(): [number, number, number] { + return [ + this.backend.device.limits.maxComputeWorkgroupSizeX, this.backend.device.limits.maxComputeWorkgroupSizeY, + this.backend.device.limits.maxComputeWorkgroupSizeZ + ]; + } + + getMaxComputeWorkgroupStoragesize(): number { + return this.backend.device.limits.maxComputeWorkgroupStorageSize; + } + compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] { // prepare inputs. inputs should always be valid data. const mappedInputs = diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 9bf5e4066139d..7f1a5b96863f7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; +import {DataType, getTensorElementSize} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -50,38 +50,49 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt }; export const createMatMulNBitsProgramInfo = - (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes, + maxComputeWorkgroupSizes: [number, number, number], maxComputeWorkgroupStorageSize: number): ProgramInfo => { const inputShape = inputs[0].dims; const aRank = inputShape.length; - const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n); - const m = inputShape[aRank - 2]; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const dimAOuter = inputShape[aRank - 2]; + const dimInner = attributes.k; + const dimBOuter = attributes.n; + const batchDims = inputShape.slice(0, aRank - 2); + const batchSize = ShapeUtil.size(batchDims); const blobSize = attributes.blockSize / 8 * attributes.bits; const blobSizeInWords = blobSize / 4; - const outputNumber = getMaxComponents(m); - const components = getMaxComponents(attributes.n); + const dataType = inputs[0].dataType; + const outputNumber = getMaxComponents(dimAOuter); const aComponents = getMaxComponents(attributes.k); const bComponents = getMaxComponents(blobSizeInWords); + const elementSize = getTensorElementSize(dataType)!; + const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize; + const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize); + const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0; + const components = (!useBlockwiseMatMulNBits || maxNumberOfComponents >= 4) ? getMaxComponents(dimBOuter) : + ((maxNumberOfComponents >= 2) && getMaxComponents(dimBOuter) >= 2) ? 2 : + 1; + const outputShape = batchDims.concat([dimAOuter, dimBOuter]); const outputSize = ShapeUtil.size(outputShape) / components / outputNumber; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, - {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, - {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} - ]; - const aShape = inputShape.slice(); - aShape.splice(-1, 1, attributes.k / aComponents); + + const programUniforms: ProgramUniform[] = useBlockwiseMatMulNBits ? + [] : + [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.blockSize}]; + const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents]; const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); bShape.splice(-1, 1, blobSizeInWords / bComponents); - programUniforms.push(...createTensorShapeVariables(aShape)); + programUniforms.push(...createTensorShapeVariables(inputShapeTemp)); programUniforms.push(...createTensorShapeVariables(bShape)); programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); if (inputs.length === 4) { programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); } - const oShape = outputShape.slice(); - oShape.splice(-1, 1, attributes.n / components); - programUniforms.push(...createTensorShapeVariables(oShape)); + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); const getShaderSource = (shaderHelper: ShaderHelper) => { - const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const inputRank = inputShapeTemp.length; + const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents); const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); const inputVariables = [a, b, scales]; @@ -90,12 +101,9 @@ export const createMatMulNBitsProgramInfo = if (zeroPoints) { inputVariables.push(zeroPoints); } - const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); - const uniforms: UniformsArrayType = [ - {name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} - ]; - const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const outputRank = outputShapeTemp.length; + const output = outputVariable('output', inputs[0].dataType, outputRank, components); + const uniforms: UniformsArrayType = [{name: 'output_size', type: 'u32'}, {name: 'block_size', type: 'u32'}]; const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const qDqDataType = (() => { @@ -111,43 +119,49 @@ export const createMatMulNBitsProgramInfo = } })(); - const dequantizeImpl = ` - fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} { - ${(() => { + const processOneBlock = ` + for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_data = ${b.getByIndices('b_indices')}; + for (var i: u32 = 0; i < ${bComponents}; i++) { + let b_value: u32 = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'}; + let b_mask: u32 = 0x0F0F0F0Fu; + let b_value_lower: vec4 = unpack4xU8(b_value & b_mask); + let b_value_upper: vec4 = unpack4xU8((b_value >> 4) & b_mask); + let b_quantized_values = ${qDqDataType}(${ + Array.from({length: 4}, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`) + .join(', ')}); + let b_dequantized_values = ${(() => { if (aComponents === 1) { - return `var dequantized = ${qDqDataType}(${ - Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')}); - return dequantized;`; + return `${qDqDataType}(${ + Array.from({length: 8}, (_, i) => `(b_quantized_values[${i}] - zero_point) * scale`).join(', ')});`; } else { - return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')}); - return (quantized - zero_points) * scale;`; - } - })()} - }`; - const ortUnpack8x4snormImpl = ` - fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} { - var quantized: ${qDqDataType}; - var offset: u32 = 0; - let count: u32 = 4; - for (var i: u32 = 0; i < 8u; i++) { - var result = ${dataType}(extractBits(value, offset, count)); - ${(() => { - switch (aComponents) { - case 1: - return 'quantized[i] = result;'; - case 2: - return 'quantized[i / 2][i % 2] = result;'; - case 4: - return 'quantized[i / 4][i % 4] = result;'; - default: - throw new Error(`${aComponents}-component is not supported.`); + return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale;`; } - })()} - offset += count; + })()}; + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + for (var m: u32 = 0; m < ${useBlockwiseMatMulNBits ? dimAOuter : outputNumber}u; m++) { + ${a.indicesSet('a_indices', inputRank - 2, useBlockwiseMatMulNBits ? 'm' : `row * ${outputNumber} + m`)}; + ${a.indicesSet('a_indices', inputRank - 1, 'word_offset')}; + var input_offset = ${a.indicesToOffset('a_indices')}; + var a_data: ${qDqDataType}; + for (var j: u32 = 0; j < ${8 / aComponents}; j++) { + a_data[j] = ${a.getByOffset('input_offset')}; + input_offset++; + } + ${useBlockwiseMatMulNBits ? 'workgroup_shared[workgroup_shared_offset + m]' : 'output_values[m]'}${ + components > 1 ? '[c]' : ''} += ${ + Array + .from( + {length: 8 / aComponents}, + (_, i) => `${ + aComponents === 1 ? `a_data[${i}] * b_dequantized_values[${i}]` : + `dot(a_data[${i}], b_dequantized_values[${i}])`}`) + .join(' + ')}; + } + word_offset += ${8 / aComponents}; } - return quantized; }`; - const updateZeroPointIndex = zeroPoints ? ` zero_point_offset += 4; if (zero_point_offset == 32) { @@ -157,30 +171,84 @@ export const createMatMulNBitsProgramInfo = }` : ''; - return ` - ${dequantizeImpl}; - ${ortUnpack8x4snormImpl}; + return useBlockwiseMatMulNBits ? ` + var workgroup_shared: array<${output.type.value}, ${dimAOuter * nBlocksPerCol}>; + ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart([ + nBlocksPerCol, 1, 1 + ])} + var a_indices: ${a.type.indices}; + var block = local_id.x; + var col = workgroup_id.y; + var batch = workgroup_id.z; + ${a.indicesSet('a_indices', '0', 'batch')}; + // Two zero points are packed into one byte when uniforms.bits is 4. + for (var c: u32 = 0; c < ${components}; c++) { + let col_times_components_plus_c = col * ${components} + c; + ${ + zeroPoints ? ` + var zero_point_bytes_per_col: u32 = (${nBlocksPerCol} + 1) / 2; + var zero_point_byte_count: u32 = col_times_components_plus_c * zero_point_bytes_per_col + (block >> 0x1u); + var zero_point_word_index: u32 = zero_point_byte_count >> 0x2u; + var zero_point_byte_offset: u32 = zero_point_byte_count & 0x3u; + var zero_point_nibble_offset: u32 = block & 0x1u; + var zero_point_bits_offset: u32 = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;` : + ''} + var b_indices: ${b.type.indices}; + ${b.indicesSet('b_indices', '0', 'col_times_components_plus_c')}; + // The scale and zero points are computed per block. + var scales_index = col_times_components_plus_c * ${nBlocksPerCol} + block; + let scale = ${scales.getByOffset('scales_index')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${zeroPoints ? '(zero_point_word) & 0xFu' : 8.0}); + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block * ${attributes.blockSize / aComponents}; + var workgroup_shared_offset: u32 = block * ${dimAOuter}; + ${processOneBlock} + } + workgroupBarrier(); + if (local_id.x == 0u) { + var output_indices: ${output.type.indices}; + ${output.indicesSet('output_indices', '0', 'batch')}; + ${output.indicesSet('output_indices', outputRank - 1, 'col')}; + ${output.indicesSet('output_indices', outputRank - 2, '0')}; + var output_offset = ${output.indicesToOffset('output_indices')}; + for (var m: u32 = 0u; m < ${dimAOuter}u; m++) { + var output_value: ${output.type.value} = ${output.type.value}(0); + var workgroup_shared_offset: u32 = m; + for (var b: u32 = 0u; b < ${nBlocksPerCol}u; b++) { + output_value += workgroup_shared[workgroup_shared_offset]; + workgroup_shared_offset += ${dimAOuter}; + } + ${output.setByOffset('output_offset', 'output_value')}; + output_offset += ${dimBOuter / components}; + } + } + }` : + ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} var output_values: array<${output.type.value}, ${outputNumber}>; var output_indices = ${output.offsetToIndices('global_idx')}; - var n = ${output.indicesGet('output_indices', aRank - 1)}; - var m = ${output.indicesGet('output_indices', aRank - 2)}; + var col = ${output.indicesGet('output_indices', outputRank - 1)}; + var row = ${output.indicesGet('output_indices', outputRank - 2)}; var a_indices: ${a.type.indices} = output_indices; // Two zero points are packed into one byte because uniforms.bits <= 4. // zero_point_offset is either 0 or 4. It is bit offset within one byte. // TODO support zero_point_offset for bits > 4 ${ - zeroPoints ? ` - var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4; + zeroPoints ? ` + var zero_point_abs_offset = col * ${components} * ((${nBlocksPerCol} + 1) / 2); + var zero_point_index: u32 = zero_point_abs_offset / 4; var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; - var zero_point_offset: u32 = 0;` : - ''} - var scale_index = n * ${nBlocksPerCol * components}; + var zero_point_offset: u32 = (zero_point_abs_offset % 4) * 8;` : + ''} + var scale_index = col * ${nBlocksPerCol * components}; var b_indices: ${b.type.indices}; for (var c: u32 = 0; c < ${components}; c++) { - ${b.indicesSet('b_indices', '0', `n * ${components} + c`)}; + ${b.indicesSet('b_indices', '0', `col * ${components} + c`)}; var block_offset: u32 = 0; for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { // The scale and zero points are computed per block. @@ -189,52 +257,35 @@ export const createMatMulNBitsProgramInfo = let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0}); ${b.indicesSet('b_indices', '1', 'block')}; var word_offset: u32 = block_offset; - for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) { - ${b.indicesSet('b_indices', '2', 'word')}; - let b_data = ${b.getByIndices('b_indices')}; - for (var i: u32 = 0; i < ${bComponents}; i++) { - let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'}; - let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value); - let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale); - // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 - var offset: u32 = word_offset; - for (var j: u32 = 0; j < 8/${aComponents}; j++) { - ${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)}; - for (var k: u32 = 0; k < ${outputNumber}u; k++) { - ${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)}; - let a_data = ${a.getByIndices('a_indices')}; - output_values[k]${components > 1 ? '[c]' : ''} += ${ - aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'}; - } - offset += ${aComponents}; - } - word_offset += 8; - } - } + ${processOneBlock} scale_index++; ${updateZeroPointIndex} - block_offset += uniforms.block_size; + block_offset += uniforms.block_size / ${aComponents}; } // Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte. ${ - zeroPoints ? `if (zero_point_offset % 8 > 0) { + zeroPoints ? `if (zero_point_offset % 8 > 0) { ${updateZeroPointIndex} }` : - ''} + ''} } for (var k: u32 = 0u; k < ${outputNumber}u; k++) { - ${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)}; + ${output.indicesSet('output_indices', outputRank - 2, `${outputNumber} * row + k`)}; ${output.setByIndices('output_indices', 'output_values[k]')} } }`; }; return { - name: 'MatMulNBits', - shaderCache: - {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, + name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', + shaderCache: { + hint: `${attributes.cacheKey};${dimAOuter};${dataType};${inputs.length}`, + inputDependencies: Array(inputs.length).fill('rank') + }, getRunData: () => ({ - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + outputs: [{dims: outputShape, dataType}], + name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits', + dispatchGroup: useBlockwiseMatMulNBits ? {x: 1, y: Math.ceil(dimBOuter / components), z: batchSize} : + {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms }), getShaderSource @@ -243,7 +294,10 @@ export const createMatMulNBitsProgramInfo = export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { validateInputs(context.inputs, attributes); - context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); + const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes(); + const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize(); + context.compute(createMatMulNBitsProgramInfo( + context.inputs, attributes, maxComputeWorkgroupSizes, maxComputeWorkgroupStorageSize)); }; export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index a0c4b77d44a77..2a584fc0a2218 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -188,6 +188,8 @@ export interface ComputeContext { compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[]; output(index: number, dims: readonly number[]): number; + getMaxComputeWorkgroupSizes(): [number, number, number]; + getMaxComputeWorkgroupStoragesize(): number; } export type TimestampQuery = 'none'|'inside-passes'|'at-passes'; diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc index 175be78cc0818..63e0a0ed52879 100644 --- a/js/web/test/data/ops/matmulnbits.jsonc +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -1,6 +1,6 @@ [ { - "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4", "operator": "MatMulNBits", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -11,7 +11,64 @@ ], "cases": [ { - "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; asymmetric", "inputs": [ { "data": [ @@ -35,9 +92,592 @@ ] }, { - "dims": [8], - "type": "float32", - "data": [0, 1, 2, 3, 4, 5, 6, 7] + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + }, + { + "dims": [8], + "type": "uint8", + "data": [248, 249, 250, 251, 252, 253, 254, 255] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 0, -505, -1600, -2043, -3904, -4285, -6912, -7231, 0, -1449, -5312, -6027, -12864, -12845, -22656, -21903, + 0, -2393, -9024, -10011, -21824, -21405, -38400, -36575, 0, -3337, -12736, -13995, -30784, -29965, -54144, + -51247, 0, -4281, -16448, -17979, -39744, -38525, -69888, -65919, 0, -5225, -20160, -21963, -48704, + -47085, -85632, -80591, 0, -6169, -23872, -25947, -57664, -55645, -101376, -95263, 0, -7113, -27584, + -29931, -66624, -64205, -117120, -109935 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [8, 32], + "type": "float32" + }, + { + "dims": [8, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -1073, -3763, -5429, -6071, -5689, -4283, -1853, 1601, -2449, -12499, -19477, -23383, -24217, -21979, + -16669, -8287, -3825, -21235, -33525, -40695, -42745, -39675, -31485, -18175, -5201, -29971, -47573, + -58007, -61273, -57371, -46301, -28063, -6577, -38707, -61621, -75319, -79801, -75067, -61117, -37951, + -7953, -47443, -75669, -92631, -98329, -92763, -75933, -47839, -9329, -56179, -89717, -109943, -116857, + -110459, -90749, -57727, -10705, -64915, -103765, -127255, -135385, -128155, -105565, -67615 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [8, 32], + "type": "float32" + }, + { + "dims": [8, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [8], + "type": "uint8", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + 1935, 6941, 12491, 18585, 25223, 32405, 40131, 48401, 4655, 17661, 31211, 45305, 59943, 75125, 90851, + 107121, 7375, 28381, 49931, 72025, 94663, 117845, 141571, 165841, 10095, 39101, 68651, 98745, 129383, + 160565, 192291, 224561, 12815, 49821, 87371, 125465, 164103, 203285, 243011, 283281, 15535, 60541, 106091, + 152185, 198823, 246005, 293731, 342001, 18255, 71261, 124811, 178905, 233543, 288725, 344451, 400721, + 20975, 81981, 143531, 205625, 268263, 331445, 395171, 459441 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 48, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383 + ], + "dims": [8, 48], + "type": "float32" + }, + { + "dims": [8, 3, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192 + ] + }, + { + "dims": [24], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -7569, -13416, -24375, -14292, -20445, 5568, 4221, 46164, -17697, -39528, -73383, -45588, -66861, 10560, + 1869, 128916, -27825, -65640, -122391, -76884, -113277, 15552, -483, 211668, -37953, -91752, -171399, + -108180, -159693, 20544, -2835, 294420, -48081, -117864, -220407, -139476, -206109, 25536, -5187, 377172, + -58209, -143976, -269415, -170772, -252525, 30528, -7539, 459924, -68337, -170088, -318423, -202068, + -298941, 35520, -9891, 542676, -78465, -196200, -367431, -233364, -345357, 40512, -12243, 625428 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 48, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=48, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383 + ], + "dims": [8, 48], + "type": "float32" + }, + { + "dims": [8, 3, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192 + ] + }, + { + "dims": [24], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + }, + { + "dims": [16], + "type": "uint8", + "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -1353, -5984, -24751, -31500, -63509, -72376, -117627, -128612, -6105, -20576, -74527, -94284, -190565, + -215608, -354219, -384548, -10857, -35168, -124303, -157068, -317621, -358840, -590811, -640484, -15609, + -49760, -174079, -219852, -444677, -502072, -827403, -896420, -20361, -64352, -223855, -282636, -571733, + -645304, -1063995, -1152356, -25113, -78944, -273631, -345420, -698789, -788536, -1300587, -1408292, + -29865, -93536, -323407, -408204, -825845, -931768, -1537179, -1664228, -34617, -108128, -373183, -470988, + -952901, -1075000, -1773771, -1920164 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 64, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, + 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, + 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, + 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, + 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, + 505, 506, 507, 508, 509, 510, 511 + ], + "dims": [8, 64], + "type": "float32" + }, + { + "dims": [8, 4, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -13572, -28812, -27668, -10140, 23772, 74068, 140748, 192564, -33796, -91532, -100116, -59548, 30172, + 169044, 357068, 531252, -54020, -154252, -172564, -108956, 36572, 264020, 573388, 869940, -74244, -216972, + -245012, -158364, 42972, 358996, 789708, 1208628, -94468, -279692, -317460, -207772, 49372, 453972, + 1006028, 1547316, -114692, -342412, -389908, -257180, 55772, 548948, 1222348, 1886004, -134916, -405132, + -462356, -306588, 62172, 643924, 1438668, 2224692, -155140, -467852, -534804, -355996, 68572, 738900, + 1654988, 2563380 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 64, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=64, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, + 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, + 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, + 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, + 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, + 505, 506, 507, 508, 509, 510, 511 + ], + "dims": [8, 64], + "type": "float32" + }, + { + "dims": [8, 4, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] + } + ], + "outputs": [ + { + "dims": [8, 8], + "type": "float32", + "data": [ + -26004, -63644, -96932, -125868, -150452, -170684, -186564, -229340, -60564, -157084, -249252, -337068, + -420532, -499644, -574404, -707804, -95124, -250524, -401572, -548268, -690612, -828604, -962244, + -1186268, -129684, -343964, -553892, -759468, -960692, -1157564, -1350084, -1664732, -164244, -437404, + -706212, -970668, -1230772, -1486524, -1737924, -2143196, -198804, -530844, -858532, -1181868, -1500852, + -1815484, -2125764, -2621660, -233364, -624284, -1010852, -1393068, -1770932, -2144444, -2513604, + -3100124, -267924, -717724, -1163172, -1604268, -2041012, -2473404, -2901444, -3578588 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 80, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=80, N=8, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, + 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, + 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, + 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, + 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, + 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, + 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, + 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, + 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, + 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, + 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, + 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, + 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, + 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, + 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, + 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, + 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, + 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, + 631, 632, 633, 634, 635, 636, 637, 638, 639 + ], + "dims": [8, 80], + "type": "float32" + }, + { + "dims": [8, 5, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320 + ] + }, + { + "dims": [40], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 + ] + }, + { + "dims": [24], + "type": "uint8", + "data": [ + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, + 261, 262, 263 + ] } ], "outputs": [ @@ -45,11 +685,12 @@ "dims": [8, 8], "type": "float32", "data": [ - 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, - -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, - 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, - -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, - -16405, -48288, -16247 + -19988, -63429, -155448, -216179, -358428, 740351, 259888, 172481, -56788, -186869, -451128, -632899, + -1053788, 1574031, 1165488, 546481, -93588, -310309, -746808, -1049619, -1749148, 2407711, 2071088, + 920481, -130388, -433749, -1042488, -1466339, -2444508, 3241391, 2976688, 1294481, -167188, -557189, + -1338168, -1883059, -3139868, 4075071, 3882288, 1668481, -203988, -680629, -1633848, -2299779, -3835228, + 4908751, 4787888, 2042481, -240788, -804069, -1929528, -2716499, -4530588, 5742431, 5693488, 2416481, + -277588, -927509, -2225208, -3133219, -5225948, 6576111, 6599088, 2790481 ] } ] @@ -188,7 +829,7 @@ { "dims": [16], "type": "uint8", - "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] } ], "outputs": [ @@ -196,24 +837,25 @@ "dims": [16, 16], "type": "float32", "data": [ - 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, - 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, - 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, - 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, - 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, - 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, - 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, - 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, - 218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384, - 124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560, - 103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360, - 81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320, - 57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976, - 35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864, - 15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800, - 479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256, - 422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200, - 360008, 226848, 451256, 292432, 550440 + 0, 608, 208, 1296, -288, 1280, -1488, 560, -3392, -864, -6000, -2992, -9312, -5824, -13328, -9360, 0, + 1824, 336, 3792, -1568, 3520, -5712, 1008, -12096, -3744, -20720, -10736, -31584, -19968, -44688, -31440, + 0, 3040, 464, 6288, -2848, 5760, -9936, 1456, -20800, -6624, -35440, -18480, -53856, -34112, -76048, + -53520, 0, 4256, 592, 8784, -4128, 8000, -14160, 1904, -29504, -9504, -50160, -26224, -76128, -48256, + -107408, -75600, 0, 5472, 720, 11280, -5408, 10240, -18384, 2352, -38208, -12384, -64880, -33968, -98400, + -62400, -138768, -97680, 0, 6688, 848, 13776, -6688, 12480, -22608, 2800, -46912, -15264, -79600, -41712, + -120672, -76544, -170128, -119760, 0, 7904, 976, 16272, -7968, 14720, -26832, 3248, -55616, -18144, + -94320, -49456, -142944, -90688, -201488, -141840, 0, 9120, 1104, 18768, -9248, 16960, -31056, 3696, + -64320, -21024, -109040, -57200, -165216, -104832, -232848, -163920, 0, 10336, 1232, 21264, -10528, 19200, + -35280, 4144, -73024, -23904, -123760, -64944, -187488, -118976, -264208, -186000, 0, 11552, 1360, 23760, + -11808, 21440, -39504, 4592, -81728, -26784, -138480, -72688, -209760, -133120, -295568, -208080, 0, + 12768, 1488, 26256, -13088, 23680, -43728, 5040, -90432, -29664, -153200, -80432, -232032, -147264, + -326928, -230160, 0, 13984, 1616, 28752, -14368, 25920, -47952, 5488, -99136, -32544, -167920, -88176, + -254304, -161408, -358288, -252240, 0, 15200, 1744, 31248, -15648, 28160, -52176, 5936, -107840, -35424, + -182640, -95920, -276576, -175552, -389648, -274320, 0, 16416, 1872, 33744, -16928, 30400, -56400, 6384, + -116544, -38304, -197360, -103664, -298848, -189696, -421008, -296400, 0, 17632, 2000, 36240, -18208, + 32640, -60624, 6832, -125248, -41184, -212080, -111408, -321120, -203840, -452368, -318480, 0, 18848, + 2128, 38736, -19488, 34880, -64848, 7280, -133952, -44064, -226800, -119152, -343392, -217984, -483728, + -340560 ] } ] @@ -1580,5 +2222,265 @@ ] } ] + }, + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1]", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1]; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [1, 8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [1, 8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [1, 8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1, 2]", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric, batchDim = [1, 2]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [1, 2, 8, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [1, 8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [1, 2, 8, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247, 0, -5889, -22624, -14403, -40896, -18565, -54816, -18375, 0, -6577, -25312, + -16083, -45760, -20725, -61344, -20503, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, 0, + -7953, -30688, -19443, -55488, -25045, -74400, -24759, 0, -8641, -33376, -21123, -60352, -27205, -80928, + -26887, 0, -9329, -36064, -22803, -65216, -29365, -87456, -29015, 0, -10017, -38752, -24483, -70080, + -31525, -93984, -31143, 0, -10705, -41440, -26163, -74944, -33685, -100512, -33271 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; output shape = 8 X 16; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ], + "dims": [8, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [8, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; output shape = 16 X 8; K=16, N=8, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 8, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [8, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ] + }, + { + "dims": [8], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7] + } + ], + "outputs": [ + { + "dims": [16, 8], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0, + -1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, + 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232, + -11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032, + -16405, -48288, -16247, 0, -5889, -22624, -14403, -40896, -18565, -54816, -18375, 0, -6577, -25312, + -16083, -45760, -20725, -61344, -20503, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, 0, + -7953, -30688, -19443, -55488, -25045, -74400, -24759, 0, -8641, -33376, -21123, -60352, -27205, -80928, + -26887, 0, -9329, -36064, -22803, -65216, -29365, -87456, -29015, 0, -10017, -38752, -24483, -70080, + -31525, -93984, -31143, 0, -10705, -41440, -26163, -74944, -33685, -100512, -33271 + ] + } + ] + } + ] } ]