Skip to content

Commit

Permalink
Merge pull request fastmachinelearning#132 from fastmachinelearning/feature/convlower_qnt
Browse files Browse the repository at this point in the history

Preserve weight quantizer while lowering convolutions
  • Loading branch information
maltanar authored Aug 22, 2024
2 parents 84ad7ae + 032681c commit bdf9405
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 4 deletions.
36 changes: 35 additions & 1 deletion src/qonnx/transformation/lower_convs_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def apply(self, model):
stride_w,
group,
weight_name,
conv_weight_inp_name,
conv_weight_q_scale_name,
W_conv,
ifm_ch,
ofm_ch,
Expand All @@ -74,12 +76,18 @@ def apply(self, model):
pad,
) = self.extract_conv_params(model, node)

if W_conv is None:
warnings.warn("Found Conv node with non-initialized weight, skipping")
continue

# if depthwise conv create sparse matrix and variable "dw"
# to store as attribute in Im2Col that indicates that the created
# Im2Col node belongs to a depthwise convolution
dw = False
if group == ifm_ch and ofm_ch == ifm_ch:
W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w)) # (OFM, IFM, k_H, k_W)
# TODO: if the convolution is quantized with a non-zero zeropoint we
# should be using the zeropoint value here instead of np.zeros
for ch in range(ifm_ch):
W_sparse[ch][ch] = W_conv[ch][0] # W_conv = [OFM, IFM, k_H, k_W]
W_conv = W_sparse.astype(np.float32)
Expand All @@ -104,6 +112,21 @@ def apply(self, model):
# transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix
W_matmul = W_matmul.T
model.set_initializer(weight_name, W_matmul)
if weight_name != conv_weight_inp_name:
# required for convs with quantized weights
model.set_tensor_shape(conv_weight_inp_name, W_matmul.shape)
if conv_weight_q_scale_name is not None:
# required for convs with quantized weights
scale_weight_q = model.get_initializer(conv_weight_q_scale_name)
if scale_weight_q.ndim > 0:
# scale shape is originally [OFM, IFM, k_H, k_W]
# transpose into [OFM, k_H, k_W, IFM]
scale_weight_q = scale_weight_q.transpose(0, 2, 3, 1)
# reshape into [OFM][k_h*k_w*IFM] matrix
scale_weight_q = scale_weight_q.reshape(ofm_ch, -1)
# transpose to be shape-compatible with weight matrix
scale_weight_q = scale_weight_q.T
model.set_initializer(conv_weight_q_scale_name, scale_weight_q)

# create new intermediate values
inp_trans_out = helper.make_tensor_value_info(
Expand Down Expand Up @@ -154,7 +177,7 @@ def apply(self, model):

matmul_input = im2col_out if need_im2col else inp_trans_out
# do matmul
matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out])
matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out])
# NHWC -> NCHW
out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])

Expand All @@ -178,7 +201,16 @@ def extract_conv_params(self, model, node):
stride_w = get_by_name(node.attribute, "strides").ints[1]
group = get_by_name(node.attribute, "group").i
weight_name = node.input[1]
conv_weight_inp_name = node.input[1]
conv_weight_q_scale_name = None
W_conv = model.get_initializer(weight_name)
if W_conv is None:
# check to see if there is an immediate quantizer node feeding the weight input
w_producer = model.find_producer(weight_name)
if not (w_producer is None) and w_producer.op_type == "Quant":
W_conv = model.get_initializer(w_producer.input[0])
weight_name = w_producer.input[0]
conv_weight_q_scale_name = w_producer.input[1]
ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW
ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW
ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW
Expand Down Expand Up @@ -213,6 +245,8 @@ def extract_conv_params(self, model, node):
stride_w,
group,
weight_name,
conv_weight_inp_name,
conv_weight_q_scale_name,
W_conv,
ifm_ch,
ofm_ch,
Expand Down
11 changes: 8 additions & 3 deletions src/qonnx/util/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,20 @@ def qonnx_download_model():
clize.run(download_model)


def get_golden_in_and_output(test_model):
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
rng = np.random.RandomState(42)
def get_random_input(test_model, seed=42):
    """Generate a reproducible random float32 input tensor for a test model.

    The shape and value range are looked up in ``test_model_details`` under
    the ``"input_shape"`` and ``"input_range"`` keys of the ``test_model``
    entry. Samples are drawn from a ``np.random.RandomState`` seeded with
    ``seed``, so the same ``(test_model, seed)`` pair always yields the same
    tensor.
    """
    generator = np.random.RandomState(seed)
    details = test_model_details[test_model]
    shape = details["input_shape"]
    lower, upper = details["input_range"]
    # draw a flat vector first, then cast and fold it into the model's shape
    n_elems = np.prod(np.asarray(shape))
    flat = generator.uniform(low=lower, high=upper, size=n_elems)
    return flat.astype(np.float32).reshape(shape)


def get_golden_in_and_output(test_model, seed=42):
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
input_tensor = get_random_input(test_model, seed=seed)
input_dict = {model.graph.input[0].name: input_tensor}
golden_output_dict = oxe.execute_onnx(model, input_dict)
golden_result = golden_output_dict[model.graph.output[0].name]
Expand Down
13 changes: 13 additions & 0 deletions tests/transformation/test_conv_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
from qonnx.util.test import download_model, get_golden_in_and_output


@pytest.mark.parametrize("model_name", ["FINN-CNV_W2A2", "MobileNetv1-w4a4"])
def test_conv_lowering_quant_weights(model_name):
    """Check that lowering convs with quantized weights keeps model outputs.

    Runs ``LowerConvsToMatMul`` on a downloaded model, verifies that no
    ``Conv`` nodes remain, and compares the lowered model's output against
    the golden output of the original model for the same seeded input.
    """
    original = download_model(model_name, return_modelwrapper=True, do_cleanup=True)
    input_t, golden_t = get_golden_in_and_output(model_name, seed=0)
    # the feed dict is keyed by the input name of the model before lowering
    feed = {original.graph.input[0].name: input_t}
    lowered = original.transform(LowerConvsToMatMul())
    remaining_convs = lowered.get_nodes_by_op_type("Conv")
    assert remaining_convs == []
    results = oxe.execute_onnx(lowered, feed)
    produced_t = results[lowered.graph.output[0].name]
    assert np.allclose(golden_t, produced_t, atol=1e-04)


def test_conv_lowering_convmnist():
Expand Down

0 comments on commit bdf9405

Please sign in to comment.