diff --git a/src/qonnx/transformation/lower_convs_to_matmul.py b/src/qonnx/transformation/lower_convs_to_matmul.py index 49700cd7..81f0b713 100644 --- a/src/qonnx/transformation/lower_convs_to_matmul.py +++ b/src/qonnx/transformation/lower_convs_to_matmul.py @@ -63,6 +63,8 @@ def apply(self, model): stride_w, group, weight_name, + conv_weight_inp_name, + conv_weight_q_scale_name, W_conv, ifm_ch, ofm_ch, @@ -74,12 +76,18 @@ def apply(self, model): pad, ) = self.extract_conv_params(model, node) + if W_conv is None: + warnings.warn("Found Conv node with non-initialized weight, skipping") + continue + # if depthwise conv create sparse matrix and variable "dw" # to store as attribute in Im2Col that indicates that the created # Im2Col node belongs to a depthwise convolution dw = False if group == ifm_ch and ofm_ch == ifm_ch: W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w)) # (OFM, IFM, k_H, k_W) + # TODO: if the convolution is quantized with a non-zero zeropoint we + # should be using the zeropoint value here instead of np.zeros for ch in range(ifm_ch): W_sparse[ch][ch] = W_conv[ch][0] # W_conv = [OFM, IFM, k_H, k_W] W_conv = W_sparse.astype(np.float32) @@ -104,6 +112,21 @@ def apply(self, model): # transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix W_matmul = W_matmul.T model.set_initializer(weight_name, W_matmul) + if weight_name != conv_weight_inp_name: + # required for convs with quantized weights + model.set_tensor_shape(conv_weight_inp_name, W_matmul.shape) + if conv_weight_q_scale_name is not None: + # required for convs with quantized weights + scale_weight_q = model.get_initializer(conv_weight_q_scale_name) + if scale_weight_q.ndim > 0: + # scale shape is originally [OFM, IFM, k_H, k_W] + # transpose into [OFM, k_H, k_W, IFM] + scale_weight_q = scale_weight_q.transpose(0, 2, 3, 1) + # reshape into [OFM][k_h*k_w*IFM] matrix + scale_weight_q = scale_weight_q.reshape(ofm_ch, -1) + # transpose to be shape-compatible with weight matrix + scale_weight_q = scale_weight_q.T + model.set_initializer(conv_weight_q_scale_name, scale_weight_q) # create new intermediate values inp_trans_out = helper.make_tensor_value_info( @@ -154,7 +177,7 @@ def apply(self, model): matmul_input = im2col_out if need_im2col else inp_trans_out # do matmul - matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out]) + matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out]) # NHWC -> NCHW out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2]) @@ -178,7 +201,16 @@ def extract_conv_params(self, model, node): stride_w = get_by_name(node.attribute, "strides").ints[1] group = get_by_name(node.attribute, "group").i weight_name = node.input[1] + conv_weight_inp_name = node.input[1] + conv_weight_q_scale_name = None W_conv = model.get_initializer(weight_name) + if W_conv is None: + # check to see if there is an immediate quantizer node feeding the weight input + w_producer = model.find_producer(weight_name) + if not (w_producer is None) and w_producer.op_type == "Quant": + W_conv = model.get_initializer(w_producer.input[0]) + weight_name = w_producer.input[0] + conv_weight_q_scale_name = w_producer.input[1] ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW @@ -213,6 +245,8 @@ def extract_conv_params(self, model, node): stride_w, group, weight_name, + conv_weight_inp_name, + conv_weight_q_scale_name, W_conv, ifm_ch, ofm_ch, diff --git a/src/qonnx/util/test.py b/src/qonnx/util/test.py index f18e437e..ff0fcb15 100644 --- a/src/qonnx/util/test.py +++ b/src/qonnx/util/test.py @@ -145,15 +145,20 @@ def qonnx_download_model(): clize.run(download_model) -def get_golden_in_and_output(test_model): - model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) - rng = np.random.RandomState(42) +def get_random_input(test_model, seed=42): + rng = np.random.RandomState(seed) input_shape = test_model_details[test_model]["input_shape"] (low, high) = test_model_details[test_model]["input_range"] size = np.prod(np.asarray(input_shape)) input_tensor = rng.uniform(low=low, high=high, size=size) input_tensor = input_tensor.astype(np.float32) input_tensor = input_tensor.reshape(input_shape) + return input_tensor + + +def get_golden_in_and_output(test_model, seed=42): + model = download_model(test_model, do_cleanup=True, return_modelwrapper=True) + input_tensor = get_random_input(test_model, seed=seed) input_dict = {model.graph.input[0].name: input_tensor} golden_output_dict = oxe.execute_onnx(model, input_dict) golden_result = golden_output_dict[model.graph.output[0].name] diff --git a/tests/transformation/test_conv_lowering.py b/tests/transformation/test_conv_lowering.py index 788d6993..0da57ea3 100644 --- a/tests/transformation/test_conv_lowering.py +++ b/tests/transformation/test_conv_lowering.py @@ -43,6 +43,19 @@ from qonnx.transformation.infer_shapes import InferShapes from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model +from qonnx.util.test import download_model, get_golden_in_and_output + + +@pytest.mark.parametrize("model_name", ["FINN-CNV_W2A2", "MobileNetv1-w4a4"]) +def test_conv_lowering_quant_weights(model_name): + model = download_model(model_name, return_modelwrapper=True, do_cleanup=True) + input_t, golden_t = get_golden_in_and_output(model_name, seed=0) + input_dict = {model.graph.input[0].name: input_t} + model = model.transform(LowerConvsToMatMul()) + assert model.get_nodes_by_op_type("Conv") == [] + prod_dict = oxe.execute_onnx(model, input_dict) + prod_t = prod_dict[model.graph.output[0].name] + assert np.isclose(golden_t, prod_t, atol=1e-04).all() def test_conv_lowering_convmnist():