From 7554304527c8f30412a4c939ffaa05bbafea61c3 Mon Sep 17 00:00:00 2001
From: dingshaohua
Date: Tue, 19 Apr 2022 19:33:42 +0800
Subject: [PATCH] feat(mgeconvert/tflite): support get qmin qmax np_dtype
 dtype_name from qparams and fix some bugs

---
 .../backend/ir_to_tflite/tflite_converter.py  |   7 +-
 mgeconvert/backend/ir_to_tflite/tflite_op.py  |  82 +++++++++++--
 mgeconvert/converter_ir/ir_quantizer.py       |  40 +++---
 mgeconvert/converter_ir/ir_tensor.py          |  57 +++++++--
 mgeconvert/converter_ir/ir_transform.py       | 115 +++++++++++-------
 mgeconvert/converters/tm_to_tflite.py         |   5 +-
 .../frontend/tm_to_ir/op_generators/base.py   |   4 +-
 .../frontend/tm_to_ir/op_generators/concat.py |  10 +-
 .../frontend/tm_to_ir/op_generators/conv2d.py |  26 ++--
 .../tm_to_ir/op_generators/conv_bn2d.py       |  27 ++--
 .../tm_to_ir/op_generators/elemwise.py        |   9 +-
 .../frontend/tm_to_ir/op_generators/matmul.py |  27 ++--
 .../tm_to_ir/op_generators/subtensor.py       |   6 +
 mgeconvert/frontend/tm_to_ir/pattern_utils.py |   6 +-
 mgeconvert/frontend/tm_to_ir/qat_pattern.py   |  48 +++++---
 mgeconvert/frontend/tm_to_ir/tm_frontend.py   |  15 +--
 test/mge/test_tflite.py                       |   8 +-
 test/traced_module/test_tflite.py             |  14 ++-
 test/utils.py                                 |  14 +++
 19 files changed, 338 insertions(+), 182 deletions(-)

diff --git a/mgeconvert/backend/ir_to_tflite/tflite_converter.py b/mgeconvert/backend/ir_to_tflite/tflite_converter.py
index c51e553..54559de 100644
--- a/mgeconvert/backend/ir_to_tflite/tflite_converter.py
+++ b/mgeconvert/backend/ir_to_tflite/tflite_converter.py
@@ -32,6 +32,7 @@
     get_shape_param,
     mge2tflite_dtype_mapping,
     set_quantization,
+    set_tensor_format,
 )
 
 
@@ -53,6 +54,10 @@ def __init__(self, net, graph_name="graph", quantizer=None):
 
     def convert(self, disable_nhwc=False):
         # Note the 0th entry of this array must be an empty buffer (sentinel)
+        if disable_nhwc:
+            set_tensor_format("nchw")
+        else:
+            set_tensor_format("nhwc")
         Buffer.BufferStart(self._builder)
         buffer = Buffer.BufferEnd(self._builder)
         self._buffer_list.append(buffer)
@@ -106,7 +111,7 @@ def need_convert(mge_opr):
         )
 
         if isinstance(dtype, QuantDtypeMeta):
-            dtype = dtype.np_dtype_str
+            dtype = dtype.name
         else:
             dtype = tensor.dtype
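The layout decision is now made once, at the top of convert(): set_tensor_format records it in the backend Config, and per-operator emitters read it back with get_format() instead of having disable_nhwc threaded through every call. A minimal sketch of the toggle, using only names added by this patch:

    # Sketch: how convert() drives the new layout switch.
    from mgeconvert.backend.ir_to_tflite.tflite_op import set_tensor_format, get_format

    set_tensor_format("nchw")      # what convert() does when disable_nhwc=True
    assert get_format() == "nchw"
    set_tensor_format("nhwc")      # the default, TFLite-friendly layout
    assert get_format() == "nhwc"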
- """ +def _get_tensor_shape(tensor, mge_opr, disable_nhwc): if isinstance(mge_opr, ReshapeOpr): - return tensor.shape, None + return tensor.shape shape = list(tensor.shape) if tensor.axis_order and tensor.ndim == 4: @@ -135,9 +142,25 @@ def get_shape_param( shape = tensor.axis_order.shape_to_NHWC(shape) elif isinstance(tensor.axis_order, IOHWFormat): shape = tensor.axis_order.shape_to_OHWI(shape) + elif tensor.axis_order and mge_opr.name == "Squeeze": + if not disable_nhwc: + nhwc_aixs_order = [0, 3, 1, 2] + inp_shape = list(mge_opr.inp_tensors[0].shape) + assert len(inp_shape) == 4 + out_shape = mge_opr.inp_tensors[0].axis_order.shape_to_NHWC(inp_shape) + squeeze_dims = [nhwc_aixs_order[i] for i in mge_opr.squeeze_dims[::-1]] + for i in squeeze_dims: + out_shape.pop(i) + shape = out_shape + elif tensor.ndim > 4: assert False, "ERROR: output ndim {0} is not supported now".format(tensor.ndim) + return shape + +def _get_tensor_value(tensor, mge_opr, quantizer, disable_nhwc): + if isinstance(mge_opr, ReshapeOpr): + return None number_list: List[np.ndarray] = [] if ( quantizer.require_quantize @@ -160,15 +183,34 @@ def get_shape_param( value = tensor.axis_order.data_to_NHWC(value) elif isinstance(tensor.axis_order, IOHWFormat): value = tensor.axis_order.data_to_OHWI(value) + + if not disable_nhwc and mge_opr.name == "GetSubTensor" and value is not None: + assert value.shape == ( + 4, + ), "can't support Slice input ndim !=4 in nhwc mode " + value = np.array([value[0], value[2], value[3], value[1]]) number_list = value.reshape(-1) if len(number_list) > 0: byte_list: List[bytes] = [] for i in number_list: byte_list.extend(i.tobytes()) - return shape, byte_list + return byte_list else: - return shape, None + return None + + +def get_shape_param( + tensor: IRTensor, mge_opr: OpBase, quantizer: IRQuantizer, disable_nhwc=False +): + """ + Return a tuple of shape and bytes(1dim) object for tflite operator, which will + restore its inp/out at runtime by the shape and bytes. + """ + return ( + _get_tensor_shape(tensor, mge_opr, disable_nhwc), + _get_tensor_value(tensor, mge_opr, quantizer, disable_nhwc), + ) mge2tflite_dtype_mapping = { @@ -184,11 +226,14 @@ def get_shape_param( dtype("uint8"): TensorType.UINT8, dtype("int8"): TensorType.INT8, "quint8": TensorType.UINT8, + "qint8": TensorType.INT8, "qint32": TensorType.INT32, + "qint16": TensorType.INT16, "uint8": TensorType.UINT8, "int8": TensorType.INT8, "int16": TensorType.INT16, "int32": TensorType.INT32, + "qint8_narrow": TensorType.INT8, } @@ -381,6 +426,11 @@ def _deconv(mge_opr, builder): @_register_op(ConcatOpr) def _concat(mge_opr, builder): + if len(set([t.scale for t in mge_opr.inp_tensors + mge_opr.out_tensors])) != 1: + logger.warning( + "tflite concat doesn't support inputs outputs with different scale!" 
+ ) + ConcatenationOptions.ConcatenationOptionsStart(builder) ConcatenationOptions.ConcatenationOptionsAddFusedActivationFunction( builder, mge2tflite_activation_type[mge_opr.activation] @@ -528,9 +578,17 @@ def _squeeze(mge_opr, builder): SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector( builder, len(mge_opr.squeeze_dims) ) - for i in mge_opr.squeeze_dims: + if get_format() == "nhwc": + assert ( + mge_opr.inp_tensors[0].ndim == 4 + ), "can't support Squeeze input ndim !=4 in nhwc mode" + nhwc_aixs_order = [0, 3, 1, 2] + squeeze_dims = [nhwc_aixs_order[i] for i in mge_opr.squeeze_dims] + else: + squeeze_dims = mge_opr.squeeze_dims + for i in squeeze_dims: builder.PrependInt32(i) - squeeze_dims = builder.EndVector(len(mge_opr.squeeze_dims)) + squeeze_dims = builder.EndVector(len(squeeze_dims)) SqueezeOptions.SqueezeOptionsStart(builder) SqueezeOptions.SqueezeOptionsAddSqueezeDims(builder, squeeze_dims) options = SqueezeOptions.SqueezeOptionsEnd(builder) diff --git a/mgeconvert/converter_ir/ir_quantizer.py b/mgeconvert/converter_ir/ir_quantizer.py index 2c72872..0da6af2 100644 --- a/mgeconvert/converter_ir/ir_quantizer.py +++ b/mgeconvert/converter_ir/ir_quantizer.py @@ -35,16 +35,16 @@ def quantize(self, tensor: IRTensor): value = np.round(value) if tensor.zero_point: value += tensor.zero_point - dt = ( - np.dtype(tensor.q_dtype) - if isinstance(tensor.q_dtype, str) - else tensor.q_dtype - ) - if np.issubdtype(dt, np.integer): + np_dtype = tensor.np_dtype + dt = np.dtype(np_dtype) + if tensor.qmin is not None and tensor.qmax is not None: + v_min = tensor.qmin + v_max = tensor.qmax + elif np.issubdtype(dt, np.integer): v_min = np.iinfo(dt).min v_max = np.iinfo(dt).max - value = np.clip(value, v_min, v_max) - value = value.astype(tensor.q_dtype) + value = np.clip(value, v_min, v_max) + value = value.astype(np_dtype) return value def save_quantize_params(self, irgraph): @@ -56,10 +56,20 @@ def save_quantize_params(self, irgraph): self.parse_quant_info(t) def parse_quant_info(self, t: IRTensor): - dt = np.dtype(t.q_dtype) + if t.q_dtype is None: + return + np_dtype = t.np_dtype + try: + dt = np.dtype(np_dtype) + except TypeError: + dt = None + v_max, v_min = None, None is_weight = bool(t.np_data is not None) - if np.issubdtype(dt, np.integer): + if t.qmin is not None and t.qmax is not None: + v_min = t.qmin + v_max = t.qmax + elif dt is not None and np.issubdtype(dt, np.integer): v_min = np.iinfo(dt).min v_max = np.iinfo(dt).max if self.param_fake_quant and is_weight: @@ -78,11 +88,11 @@ def parse_quant_info(self, t: IRTensor): )[0].numpy() else: param = { - "dtype": str(dt), - "qmin": str(v_min), - "qmax": str(v_max), - "scale": str(t.scale), - "zero_point": str(t.zero_point), + "dtype": np_dtype, + "qmin": v_min, + "qmax": v_max, + "scale": t.scale, + "zero_point": t.zero_point, "is_weight": is_weight, } self.quant_params[t.name] = param diff --git a/mgeconvert/converter_ir/ir_tensor.py b/mgeconvert/converter_ir/ir_tensor.py index 6024089..d2ddef7 100644 --- a/mgeconvert/converter_ir/ir_tensor.py +++ b/mgeconvert/converter_ir/ir_tensor.py @@ -5,7 +5,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/mgeconvert/converter_ir/ir_tensor.py b/mgeconvert/converter_ir/ir_tensor.py
index 6024089..d2ddef7 100644
--- a/mgeconvert/converter_ir/ir_tensor.py
+++ b/mgeconvert/converter_ir/ir_tensor.py
@@ -5,7 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import List, Sequence, Union
+from typing import List, Sequence
 
 
 class DataFormat:
@@ -94,7 +94,10 @@ def __init__(
         dtype,
         scale=None,
         zero_point=None,
+        qmin=None,
+        qmax=None,
         q_type=None,
+        np_dtype=None,
         np_data=None,
         owner_opr=None,
         axis=AxisOrder.NCHW,
@@ -110,7 +113,11 @@ def __init__(
 
         self.scale = scale
         self.zero_point = zero_point
+        self.qmin = qmin
+        self.qmax = qmax
+        assert isinstance(q_type, str) or q_type is None
         self.q_dtype = q_type
+        self.np_dtype = np_dtype
 
     @property
     def ndim(self):
@@ -123,9 +130,22 @@ def set_dtype(self, target_type):
             self.np_data = self.np_data.astype(target_type)
         self.dtype = target_type
 
-    def set_qparams(
-        self, scale: Union[float, List[float]], zero_point=None, q_dtype=None
-    ):
+    def set_qparams_from_other_tensor(self, other):
+        self.q_dtype = other.q_dtype
+        self.np_dtype = other.np_dtype
+        self.qmin = other.qmin
+        self.qmax = other.qmax
+        self.scale = other.scale
+        self.zero_point = other.zero_point
+
+    def set_qparams_from_mge_qparams(self, qparams):
+        dtype_meta = qparams.dtype_meta
+        self.q_dtype = dtype_meta.name
+        self.np_dtype = dtype_meta.np_dtype_str
+        self.qmin = dtype_meta.qmin
+        self.qmax = dtype_meta.qmax
+        scale = qparams.scale
+        zero_point = qparams.zero_point
         if not isinstance(scale, Sequence):  # per tensor
             self.scale = float(scale)
         else:  # per channel
@@ -137,8 +157,29 @@ def set_qparams(
             else:
                 self.zero_point = [int(zp) for zp in zero_point]
 
-        if self.q_dtype is not None:
-            self.q_dtype = q_dtype
+    def set_qparams(
+        self,
+        *,
+        scale: float,
+        q_dtype: str,
+        qmin: int = None,
+        qmax: int = None,
+        zero_point=None,
+        np_dtype=None,
+    ):
+        if qmin is None or qmax is None:
+            assert np_dtype is not None, "must provide np_dtype or qmin and qmax"
+        if not isinstance(scale, Sequence):  # per tensor
+            self.scale = float(scale)
+        else:  # per channel
+            self.scale = [float(s) for s in scale]
+        if zero_point is not None:
+            if not isinstance(zero_point, Sequence):
+                self.zero_point = int(zero_point)
+            else:
+                self.zero_point = [int(zp) for zp in zero_point]
 
-    def __repr__(self):
-        return self.name
+        self.q_dtype = q_dtype
+        self.np_dtype = np_dtype
+        self.qmin = qmin
+        self.qmax = qmax
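IRTensor.set_qparams is now keyword-only and records q_dtype, np_dtype, qmin and qmax together, with set_qparams_from_mge_qparams and set_qparams_from_other_tensor covering the two common copy cases. A usage sketch (constructor arguments abbreviated; see ir_tensor.py for the full signature):

    import numpy as np
    from mgeconvert.converter_ir.ir_tensor import IRTensor

    bias = IRTensor(
        name="conv_bias",
        shape=(16,),
        dtype=np.int32,
        np_data=np.zeros(16, dtype=np.int32),
        axis=None,
    )
    # Old positional style, bias.set_qparams(0.5, 0), is now a TypeError.
    bias.set_qparams(scale=0.5, zero_point=0, q_dtype="qint32", np_dtype="int32")
    assert bias.qmin is None  # without explicit qmin/qmax, the quantizer falls
                              # back to np_dtype's integer range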
diff --git a/mgeconvert/converter_ir/ir_transform.py b/mgeconvert/converter_ir/ir_transform.py
index 066159f..f4b0696 100644
--- a/mgeconvert/converter_ir/ir_transform.py
+++ b/mgeconvert/converter_ir/ir_transform.py
@@ -239,7 +239,8 @@ def _transpose_pattern_as_input(net):
             dtype=np.int32,
             np_data=np.array(op.pattern, dtype=np.int32),
             owner_opr=op,
-            q_type=np.int32,
+            q_type="int32",
+            np_dtype="int32",
             axis=None,
         )
         op.add_inp_tensors(perm_tensor)
@@ -267,7 +268,8 @@ def _pad_width_as_input(net):
             dtype=np.int32,
             np_data=padddings,
             owner_opr=op,
-            q_type=np.int32,
+            q_type="int32",
+            np_dtype="int32",
             axis=None,
         )
         op.add_inp_tensors(pad_tensor)
@@ -285,7 +287,8 @@ def _reduce_axis_as_input(net):
             dtype=np.int32,
             np_data=np.array(op.axis, dtype=np.int32),
             owner_opr=op,
-            q_type=np.int32,
+            q_type="int32",
+            np_dtype="int32",
             axis=None,
         )
         op.add_inp_tensors(axis_tensor)
@@ -329,7 +332,8 @@ def have_padding(opr):
                 dtype=np.int32,
                 owner_opr=None,
                 np_data=np_data,
-                q_type=np.int32,
+                q_type="int32",
+                np_dtype="int32",
                 axis=None,
             )
             net.add_tensor(new_tensor_id, pad_in_tensor)
@@ -350,8 +354,7 @@ def have_padding(opr):
                 hasattr(op.inp_tensors[0], "scale")
                 and op.inp_tensors[0].scale is not None
             ):
-                pad_out_tensor.scale = op.inp_tensors[0].scale
-                pad_out_tensor.q_dtype = op.inp_tensors[0].q_dtype
+                pad_out_tensor.set_qparams_from_other_tensor(op.inp_tensors[0])
             if hasattr(op.inp_tensors[0], "zero_point"):
                 pad_out_tensor.zero_point = op.inp_tensors[0].zero_point
             net.add_tensor(new_tensor_id, pad_out_tensor)
@@ -364,7 +367,6 @@ def have_padding(opr):
             pad_out_tensor.owner_opr = pad_opr
             op.inp_tensors = [pad_out_tensor] + op.inp_tensors[1:]
             pad_out_tensor.user_opr.append(op)
-
             index = net._opr_ids.index(id(op))
             insert_intended[index] = (id(pad_opr), pad_opr)
@@ -391,7 +393,8 @@ def _deconv_shape_as_input(net: IRGraph):
             dtype=np.int32,
             owner_opr=op,
             np_data=np_data,
-            q_type=np.int32,
+            q_type="int32",
+            np_dtype="int32",
             axis=None,
         )
         shape_tensor = net.get_tensor(new_tensor_id, shape_symvar)
@@ -424,7 +427,8 @@ def _resize_params_as_input(net):
             shape=(2,),
             dtype=np.int32,
             np_data=np.array(op.out_size, dtype=np.int32),
-            q_type=np.int32,
+            q_type="int32",
+            np_dtype="int32",
             axis=None,
         )
         op.add_inp_tensors(out_size_tensor)
@@ -455,9 +459,11 @@ def _add_bias_for_conv(net: IRGraph):
         )
         if op.inp_tensors[0].scale and op.inp_tensors[1].scale:
             bias_tensor.set_qparams(
-                op.inp_tensors[0].scale * op.inp_tensors[1].scale, 0
+                scale=op.inp_tensors[0].scale * op.inp_tensors[1].scale,
+                zero_point=0,
+                q_dtype="int32",
+                np_dtype="int32",
             )
-            bias_tensor.q_dtype = "int32"
         op.inp_tensors.append(bias_tensor)
@@ -486,9 +492,11 @@ def _add_bias_for_deconv(net: IRGraph):
         )
         if op.inp_tensors[0].scale and op.inp_tensors[1].scale:
             bias_tensor.set_qparams(
-                op.inp_tensors[0].scale * op.inp_tensors[1].scale, 0
+                scale=op.inp_tensors[0].scale * op.inp_tensors[1].scale,
+                zero_point=0,
+                q_dtype="int32",
+                np_dtype="int32",
             )
-            bias_tensor.q_dtype = "int32"
         op.inp_tensors.append(bias_tensor)
@@ -536,12 +544,15 @@ def _fuse_activation(net):
             continue
         if prev_op.activation != "IDENTITY" or prev_op.name == "Deconv2d":
             continue
-
+        prev_output = prev_op.out_tensors
         activation = op.name.upper()
         prev_op.activation = activation
         prev_op.out_tensors = op.out_tensors
         for t in prev_op.out_tensors:
             t.owner_opr = prev_op
+        if prev_output[0] in net.graph_outputs:
+            out_idx = net.graph_outputs.index(prev_output[0])
+            net.graph_outputs[out_idx] = prev_op.out_tensors[0]
         delete_intended.append(net._opr_ids.index(op_id))
 
     for delete_idx in delete_intended[::-1]:
@@ -567,7 +578,8 @@ def make_input(axis, param, init_value):
             dtype=np.int32,
             np_data=np.array(ret, dtype=np.int32),
             owner_opr=op,  # pylint:disable=cell-var-from-loop
-            q_type=np.int32,
+            q_type="int32",
+            np_dtype="int32",
         )
         return ret
@@ -576,33 +588,34 @@ def make_input(axis, param, init_value):
         steps_tensor = make_input(op.axis, op.step_params, 1)
         op.inp_tensors = [op.inp_tensors[0], begins_tensor, ends_tensor, steps_tensor]
 
+        if len(op.squeeze_axis) > 0:
+            # TFLite slice does not support squeeze axis, so insert a squeeze opr here.
+            # infer the actual output shape of the tflite slice
+            desired_out_shape = op.out_tensors[0].shape
+            actual_out_shape = [1] * ndim
+            idx = 0
+            for i in range(ndim):
+                if i in op.squeeze_axis:
+                    continue
+                actual_out_shape[i] = desired_out_shape[idx]
+                idx += 1
+            slice_out_tensor = IRTensor(
+                name=op.name + "fake_output",
+                shape=actual_out_shape,
+                dtype=op.out_tensors[0].dtype,
+                q_type=op.out_tensors[0].q_dtype,
+                np_dtype=op.out_tensors[0].np_dtype,
+                owner_opr=op,
+            )
+            old_out = op.out_tensors
+            op.out_tensors = [slice_out_tensor]
 
-        # TFLite slice do not support squeeze axis, so insert a squeeze opr here.
-        # infer actual output shape of tflite slice
-        desired_out_shape = op.out_tensors[0].shape
-        actual_out_shape = [1] * ndim
-        idx = 0
-        for i in range(ndim):
-            if i in op.squeeze_axis:
-                continue
-            actual_out_shape[i] = desired_out_shape[idx]
-            idx += 1
-        slice_out_tensor = IRTensor(
-            name=op.name + "fake_output",
-            shape=actual_out_shape,
-            dtype=op.out_tensors[0].dtype,
-            q_type=op.out_tensors[0].q_dtype,
-            owner_opr=op,
-        )
-        old_out = op.out_tensors
-        op.out_tensors = [slice_out_tensor]
+            squeeze = SqueezeOpr(op.squeeze_axis)
+            squeeze.inp_tensors = [slice_out_tensor]
+            squeeze.out_tensors = old_out
 
-        squeeze = SqueezeOpr(op.squeeze_axis)
-        squeeze.inp_tensors = [slice_out_tensor]
-        squeeze.out_tensors = old_out
-
-        idx = net._opr_ids.index(id(op)) + 1
-        net.add_op(squeeze, idx)
+            idx = net._opr_ids.index(id(op)) + 1
+            net.add_op(squeeze, idx)
 
 
 # caffe transformer rules
@@ -930,6 +943,7 @@ def _expand_mul_add3(net: IRGraph):
         scale=op.out_tensors[0].scale,
         zero_point=op.out_tensors[0].zero_point,
         q_type=op.out_tensors[0].q_dtype,
+        np_dtype=op.out_tensors[0].np_dtype,
     )
     new_tensor_id = max(net._tensor_ids) + 1
     net.add_tensor(new_tensor_id, mul_out_tensor)
@@ -967,6 +981,7 @@ def _expand_add_relu(net: IRGraph):
         scale=op.out_tensors[0].scale,
         zero_point=op.out_tensors[0].zero_point,
         q_type=op.out_tensors[0].q_dtype,
+        np_dtype=op.out_tensors[0].np_dtype,
     )
     new_tensor_id = max(net._tensor_ids) + 1
     net.add_tensor(new_tensor_id, add_out_tensor)
@@ -1005,6 +1020,7 @@ def _expand_add_sigmoid(net: IRGraph):
         scale=op.out_tensors[0].scale,
         zero_point=op.out_tensors[0].zero_point,
         q_type=op.out_tensors[0].q_dtype,
+        np_dtype=op.out_tensors[0].np_dtype,
     )
     new_tensor_id = max(net._tensor_ids) + 1
     net.add_tensor(new_tensor_id, add_out_tensor)
@@ -1140,6 +1156,7 @@ def _add_fake_hsigmoid_tensor(net: IRGraph):
             opr.inp_tensors[0].shape,
             opr.inp_tensors[0].dtype,
             q_type=opr.inp_tensors[0].q_dtype,
+            np_dtype=opr.inp_tensors[0].np_dtype,
             scale=opr.inp_tensors[0].scale,
             zero_point=opr.inp_tensors[0].zero_point,
         )
@@ -1149,6 +1166,7 @@ def _add_fake_hsigmoid_tensor(net: IRGraph):
             opr.inp_tensors[0].shape,
             opr.inp_tensors[0].dtype,
             q_type=opr.inp_tensors[0].q_dtype,
+            np_dtype=opr.inp_tensors[0].np_dtype,
             scale=opr.inp_tensors[0].scale,
             zero_point=opr.inp_tensors[0].zero_point,
         )
@@ -1159,6 +1177,7 @@ def _add_fake_hsigmoid_tensor(net: IRGraph):
             opr.inp_tensors[0].shape,
             opr.inp_tensors[0].dtype,
             q_type=opr.inp_tensors[0].q_dtype,
+            np_dtype=opr.inp_tensors[0].np_dtype,
             scale=opr.inp_tensors[0].scale,
             zero_point=opr.inp_tensors[0].zero_point,
         )
@@ -1258,9 +1277,12 @@ def _fuse_conv_bn(net: IRGraph):
             )
             if conv_op.inp_tensors[0].scale and conv_op.inp_tensors[1].scale:
                 conv_bias.set_qparams(
-                    conv_op.inp_tensors[0].scale * conv_op.inp_tensors[1].scale, 0
+                    scale=conv_op.inp_tensors[0].scale
+                    * conv_op.inp_tensors[1].scale,
+                    zero_point=0,
+                    q_dtype="int32",
+                    np_dtype="int32",
                 )
-                conv_bias.q_dtype = "int32"
             conv_op.inp_tensors.append(conv_bias)
 
         conv_bias = conv_op.inp_tensors[2].np_data.reshape(1, -1, 1, 1)
@@ -1359,10 +1381,12 @@ def _fuse_linear_bn(net: IRGraph):
             )
             if linear_op.inp_tensors[0].scale and linear_op.inp_tensors[1].scale:
                 linear_bias.set_qparams(
-                    linear_op.inp_tensors[0].scale * linear_op.inp_tensors[1].scale,
-                    0,
+                    scale=linear_op.inp_tensors[0].scale
+                    * linear_op.inp_tensors[1].scale,
+                    zero_point=0,
+                    q_dtype="int32",
+                    np_dtype="int32",
                 )
-                linear_bias.q_dtype = "int32"
             linear_op.inp_tensors.append(linear_bias)
 
         linear_bias = linear_op.inp_tensors[2].np_data.reshape(1, -1)
@@ -1430,6 +1454,7 @@ def _expand_conv_relu(net: IRGraph):
         scale=opr.out_tensors[0].scale,
         zero_point=opr.out_tensors[0].zero_point,
         q_type=opr.out_tensors[0].q_dtype,
+        np_dtype=opr.out_tensors[0].np_dtype,
         owner_opr=conv_op,
     )
     conv_op.out_tensors = [conv_out_tensor]
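Guarding the slice-plus-squeeze rewrite with `if len(op.squeeze_axis) > 0` means a plain slice no longer grows a spurious Squeeze operator; the inner shape inference is unchanged, only re-indented. A standalone sketch of that shape bookkeeping (helper name invented):

    # The raw TFLite SLICE output keeps size-1 dims at every squeezed axis;
    # the appended SQUEEZE then removes them to reach the desired shape.
    def infer_slice_out_shape(desired_out_shape, squeeze_axis, ndim):
        actual = [1] * ndim
        idx = 0
        for i in range(ndim):
            if i in squeeze_axis:
                continue
            actual[i] = desired_out_shape[idx]
            idx += 1
        return actual

    # e.g. x[1:3, 4:8, 4:9, 2]: axis 3 is squeezed away by the indexing
    assert infer_slice_out_shape([2, 4, 5], squeeze_axis={3}, ndim=4) == [2, 4, 5, 1]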
diff --git a/mgeconvert/converters/tm_to_tflite.py b/mgeconvert/converters/tm_to_tflite.py
index b4dfbad..23d039e 100644
--- a/mgeconvert/converters/tm_to_tflite.py
+++ b/mgeconvert/converters/tm_to_tflite.py
@@ -38,6 +38,7 @@ def tracedmodule_to_tflite(
     outspec=None,
     remove_relu=False,
     prefer_same_pad_mode=False,
+    disable_nhwc=False,
 ):
     """
     Convert traced model to TFLite,
@@ -78,6 +79,7 @@ def tracedmodule_to_tflite(
         TransformerRule.REMOVE_IDENTITY,
         TransformerRule.REPLACE_FLATTEN_TO_RESHAPE,
         TransformerRule.PAD_WIDTH_AS_INPUT,
+        TransformerRule.EXPAND_ADD_RELU,
     ]
     if mtk:
         # MTK devices only support batch_size 1
@@ -88,7 +90,6 @@ def tracedmodule_to_tflite(
 
     transformer = IRTransform(transformer_options)
     transformed_irgraph = transformer.transform(irgraph)
-
     quantizer = IRQuantizer(
         require_quantize=require_quantize, param_fake_quant=param_fake_quant
     )
@@ -98,7 +99,7 @@ def tracedmodule_to_tflite(
         quantizer.dump_quant_param(path=quantize_file_path)
 
     converter = TFLiteConverter(transformed_irgraph, graph_name, quantizer=quantizer)
-    model = converter.convert()
+    model = converter.convert(disable_nhwc)
 
     assert isinstance(output, str), "tflite_fpath must be string"
     with open(output, "wb") as fout:
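tracedmodule_to_tflite now forwards the new disable_nhwc flag down to TFLiteConverter.convert(). A hedged end-to-end sketch (ConvNet is a stand-in model, not part of the patch):

    import numpy as np
    import megengine as mge
    import megengine.module as M
    from megengine.traced_module import trace_module
    from mgeconvert.converters.tm_to_tflite import tracedmodule_to_tflite

    class ConvNet(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(3, 8, 3)

        def forward(self, x):
            return self.conv(x)

    inp = mge.tensor(np.random.random((1, 3, 32, 32)).astype(np.float32))
    traced = trace_module(ConvNet(), inp)
    # disable_nhwc=True keeps tensors in NCHW; the default transposes to NHWC.
    tracedmodule_to_tflite(traced, output="convnet.tflite", disable_nhwc=True)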
diff --git a/mgeconvert/frontend/tm_to_ir/op_generators/base.py b/mgeconvert/frontend/tm_to_ir/op_generators/base.py
index 52b5ad7..2dde8f3 100644
--- a/mgeconvert/frontend/tm_to_ir/op_generators/base.py
+++ b/mgeconvert/frontend/tm_to_ir/op_generators/base.py
@@ -67,7 +67,5 @@ def add_opr_out_tensors(self):
         for o in self.expr.outputs:
             t = self.resolver.get_ir_tensor(o, owner_opr=self.op)
             if is_qat:
-                t.scale = self.op.inp_tensors[0].scale
-                t.zero_point = self.op.inp_tensors[0].zero_point
-                t.q_dtype = self.op.inp_tensors[0].q_dtype
+                t.set_qparams_from_other_tensor(self.op.inp_tensors[0])
             self.op.add_out_tensors(t)
diff --git a/mgeconvert/frontend/tm_to_ir/op_generators/concat.py b/mgeconvert/frontend/tm_to_ir/op_generators/concat.py
index a67c3fd..5ce662f 100644
--- a/mgeconvert/frontend/tm_to_ir/op_generators/concat.py
+++ b/mgeconvert/frontend/tm_to_ir/op_generators/concat.py
@@ -52,10 +52,10 @@ def __init__(self, expr, irgraph) -> None:
         self.module = expr.inputs[0].owner
         if hasattr(self.module.act_fake_quant, "get_qparams"):
             self.act_qparams = self.module.act_fake_quant.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         elif hasattr(self.module.act_observer, "get_qparams"):
             self.act_qparams = self.module.act_observer.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         else:
             logger.error("Observer and FakeQuantize do not have get_qparams().")
         super().__init__(expr, irgraph=irgraph)
@@ -63,10 +63,6 @@ def __init__(self, expr, irgraph) -> None:
     def add_opr_out_tensors(self):
         for o in self.expr.outputs:
             t = self.resolver.get_ir_tensor(o, owner_opr=self.op)
-            t.set_qparams(
-                *self.resolver.resolve_qparams(
-                    self.act_qparams.scale, self.act_qparams.zero_point
-                )
-            )
+            t.set_qparams_from_mge_qparams(self.act_qparams)
             t.q_dtype = self.act_dtype
             self.op.add_out_tensors(t)
diff --git a/mgeconvert/frontend/tm_to_ir/op_generators/conv2d.py b/mgeconvert/frontend/tm_to_ir/op_generators/conv2d.py
index 24e5e22..0e9555b 100644
--- a/mgeconvert/frontend/tm_to_ir/op_generators/conv2d.py
+++ b/mgeconvert/frontend/tm_to_ir/op_generators/conv2d.py
@@ -106,14 +106,14 @@ def __init__(self, expr, irgraph, op_cls):
         conv_module = expr.inputs[0].owner
         if hasattr(conv_module.weight_fake_quant, "get_qparams"):
             self.weight_qparams = conv_module.weight_fake_quant.get_qparams()
-            self.weight_dtype = self.weight_qparams.dtype_meta.np_dtype_str
+            self.weight_dtype = self.weight_qparams.dtype_meta.name
             self.act_qparams = conv_module.act_fake_quant.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         elif hasattr(conv_module.weight_observer, "get_qparams"):
             self.weight_qparams = conv_module.weight_observer.get_qparams()
-            self.weight_dtype = self.weight_qparams.dtype_meta.np_dtype_str
+            self.weight_dtype = self.weight_qparams.dtype_meta.name
             self.act_qparams = conv_module.act_observer.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         else:
             logger.error("Observer and FakeQuantize do not have get_qparams().")
         super().__init__(expr, irgraph, op_cls)
@@ -121,11 +121,7 @@ def __init__(self, expr, irgraph, op_cls):
     def add_opr_out_tensors(self):
         for o in self.expr.outputs:
             out_tensor = self.resolver.get_ir_tensor(o, self.op)
-            out_tensor.set_qparams(
-                *self.resolver.resolve_qparams(
-                    self.act_qparams.scale, self.act_qparams.zero_point
-                )
-            )
+            out_tensor.set_qparams_from_mge_qparams(self.act_qparams)
             out_tensor.q_dtype = self.act_dtype
             self.op.add_out_tensors(out_tensor)
@@ -137,11 +133,7 @@ def add_const_inputs(self, weight_format):
                 name=self.expr.inputs[0]._name + "_weight",
                 axis_order=weight_format,
             )
-            weight_tensor.set_qparams(
-                *self.resolver.resolve_qparams(
-                    self.weight_qparams.scale, self.weight_qparams.zero_point
-                )
-            )
+            weight_tensor.set_qparams_from_mge_qparams(self.weight_qparams)
             weight_tensor.q_dtype = self.weight_dtype
             self.op.add_inp_tensors(weight_tensor)
 
@@ -150,9 +142,11 @@ def add_const_inputs(self, weight_format):
                 self.bias, user_opr=self.op, name=self.expr.inputs[0]._name + "_bias"
             )
             bias_tensor.set_qparams(
-                self.op.inp_tensors[0].scale * weight_tensor.scale, 0
+                scale=self.op.inp_tensors[0].scale * weight_tensor.scale,
+                zero_point=0,
+                q_dtype="int32",
+                np_dtype="int32",
             )
-            bias_tensor.q_dtype = "int32"
             self.op.add_inp_tensors(bias_tensor)
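As in the transform rules earlier in the patch, every generated bias tensor uses scale = input_scale * weight_scale with zero_point 0, so the int32 bias shares the fixed-point scale of the input-times-weight accumulator. A toy check with invented numbers:

    # q_input * q_weight accumulations and q_bias must live on one scale.
    input_scale, weight_scale = 0.02, 0.005
    bias_scale = input_scale * weight_scale        # 0.0001

    real_bias = 0.37
    q_bias = round(real_bias / bias_scale)         # 3700, zero_point = 0
    assert abs(q_bias * bias_scale - real_bias) < bias_scale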
diff --git a/mgeconvert/frontend/tm_to_ir/op_generators/conv_bn2d.py b/mgeconvert/frontend/tm_to_ir/op_generators/conv_bn2d.py
index 58b58e0..0459fba 100644
--- a/mgeconvert/frontend/tm_to_ir/op_generators/conv_bn2d.py
+++ b/mgeconvert/frontend/tm_to_ir/op_generators/conv_bn2d.py
@@ -83,14 +83,14 @@ def __init__(self, expr, irgraph, op_cls):
         conv_module = expr.inputs[0].owner
         if hasattr(conv_module.weight_fake_quant, "get_qparams"):
             self.weight_qparams = conv_module.weight_fake_quant.get_qparams()
-            self.weight_dtype = self.weight_qparams.dtype_meta.np_dtype_str
+            self.weight_dtype = self.weight_qparams.dtype_meta.name
             self.act_qparams = conv_module.act_fake_quant.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         elif hasattr(conv_module.weight_observer, "get_qparams"):
             self.weight_qparams = conv_module.weight_observer.get_qparams()
-            self.weight_dtype = self.weight_qparams.dtype_meta.np_dtype_str
+            self.weight_dtype = self.weight_qparams.dtype_meta.name
             self.act_qparams = conv_module.act_observer.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         else:
             logger.error("Observer and FakeQuantize do not have get_qparams().")
         super().__init__(expr, irgraph, op_cls)
@@ -98,11 +98,7 @@ def __init__(self, expr, irgraph, op_cls):
     def add_opr_out_tensors(self):
         for o in self.expr.outputs:
             out_tensor = self.resolver.get_ir_tensor(o, self.op)
-            out_tensor.set_qparams(
-                *self.resolver.resolve_qparams(
-                    self.act_qparams.scale, self.act_qparams.zero_point
-                )
-            )
+            out_tensor.set_qparams_from_mge_qparams(self.act_qparams)
             out_tensor.q_dtype = self.act_dtype
             self.op.add_out_tensors(out_tensor)
@@ -114,12 +110,7 @@ def add_const_inputs(self, weight_format):
             name=self.expr.inputs[0]._name + "_weight",
             axis_order=weight_format,
         )
-        weight_tensor.set_qparams(
-            *self.resolver.resolve_qparams(
-                self.weight_qparams.scale, self.weight_qparams.zero_point
-            )
-        )
-        weight_tensor.q_dtype = self.weight_dtype
+        weight_tensor.set_qparams_from_mge_qparams(self.weight_qparams)
         self.op.add_inp_tensors(weight_tensor)
 
         if self.bias is not None:
@@ -127,9 +118,11 @@ def add_const_inputs(self, weight_format):
             self.bias, user_opr=self.op, name=self.expr.inputs[0]._name + "_bias"
         )
         bias_tensor.set_qparams(
-            self.op.inp_tensors[0].scale * weight_tensor.scale, 0
+            scale=self.op.inp_tensors[0].scale * weight_tensor.scale,
+            zero_point=0,
+            q_dtype="int32",
+            np_dtype="int32",
         )
-        bias_tensor.q_dtype = "int32"
         self.op.add_inp_tensors(bias_tensor)
diff --git a/mgeconvert/frontend/tm_to_ir/op_generators/elemwise.py b/mgeconvert/frontend/tm_to_ir/op_generators/elemwise.py
index e59b225..03f76f0 100644
--- a/mgeconvert/frontend/tm_to_ir/op_generators/elemwise.py
+++ b/mgeconvert/frontend/tm_to_ir/op_generators/elemwise.py
@@ -94,9 +94,7 @@ def add_opr_vars(self):
             and self.op.inp_tensors[0].scale is not None
         ):
             for o in self.op.out_tensors:
-                o.scale = self.op.inp_tensors[0].scale
-                o.zero_point = self.op.inp_tensors[0].zero_point
-                o.dtype = self.op.inp_tensors[0].dtype
+                o.set_qparams_from_other_tensor(self.op.inp_tensors[0])
 
         # set dtype for const value
@@ -283,9 +281,6 @@ def get_elemwise_op(expr, net):
         else:
             qparams = module.act_observer.get_qparams()
         for o in op_gen.get_opr().out_tensors:
+            o.set_qparams_from_mge_qparams(qparams)
             o.scale = float(qparams.scale) if method != "sigmoid" else 1 / 256.0
-            o.zero_point = (
-                int(qparams.zero_point) if qparams.zero_point is not None else None
-            )
-            o.q_dtype = qparams.dtype_meta.np_dtype_str
     return op_gen
diff --git a/mgeconvert/frontend/tm_to_ir/op_generators/matmul.py b/mgeconvert/frontend/tm_to_ir/op_generators/matmul.py
index ab7c132..3df70fb 100644
--- a/mgeconvert/frontend/tm_to_ir/op_generators/matmul.py
+++ b/mgeconvert/frontend/tm_to_ir/op_generators/matmul.py
@@ -87,14 +87,14 @@ def __init__(self, expr, irgraph):
         self.module = expr.inputs[0].owner
         if hasattr(self.module.weight_fake_quant, "get_qparams"):
             self.weight_qparams = self.module.weight_fake_quant.get_qparams()
-            self.weight_dtype = self.weight_qparams.dtype_meta.np_dtype_str
+            self.weight_dtype = self.weight_qparams.dtype_meta.name
             self.act_qparams = self.module.act_fake_quant.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         elif hasattr(self.module.weight_observer, "get_qparams"):
             self.weight_qparams = self.module.weight_observer.get_qparams()
-            self.weight_dtype = self.weight_qparams.dtype_meta.np_dtype_str
+            self.weight_dtype = self.weight_qparams.dtype_meta.name
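register_pattern (and register_fusion_pattern built on it, used for the new pat_conv_bias_relu_1 below) now also accepts a list of patterns that all dispatch to one handler. A toy sketch, with string keys standing in for the real expr patterns:

    from collections import OrderedDict
    from mgeconvert.frontend.tm_to_ir.pattern_utils import register_pattern

    TABLE: OrderedDict = OrderedDict()

    @register_pattern(["pat_conv_bias_relu", "pat_conv_bias_relu_1"], TABLE)
    def fuse(*args):
        ...  # one fusion handler shared by both patterns

    assert TABLE["pat_conv_bias_relu"] is TABLE["pat_conv_bias_relu_1"] is fuse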
             self.act_qparams = self.module.act_observer.get_qparams()
-            self.act_dtype = self.act_qparams.dtype_meta.np_dtype_str
+            self.act_dtype = self.act_qparams.dtype_meta.name
         else:
             logger.error("Observer and FakeQuantize do not have get_qparams().")
         super().__init__(expr, irgraph)
@@ -103,11 +103,7 @@ def add_const_inputs(self):
         weight_tensor = self.resolver.get_ir_tensor(
             self.weight, user_opr=self.op, name=self.expr.inputs[0]._name + "_weight",
         )
-        weight_tensor.set_qparams(
-            *self.resolver.resolve_qparams(
-                self.weight_qparams.scale, self.weight_qparams.zero_point
-            )
-        )
+        weight_tensor.set_qparams_from_mge_qparams(self.weight_qparams)
         weight_tensor.q_dtype = self.weight_dtype
         self.op.add_inp_tensors(weight_tensor)
         if self.has_bias:
@@ -117,18 +113,15 @@ def add_const_inputs(self):
                 name=self.expr.inputs[0]._name + "_bias",
             )
             bias_tensor.set_qparams(
-                self.op.inp_tensors[0].scale * weight_tensor.scale, 0
+                scale=self.op.inp_tensors[0].scale * weight_tensor.scale,
+                zero_point=0,
+                q_dtype="int32",
+                np_dtype="int32",
             )
-            bias_tensor.q_dtype = "int32"
             self.op.add_inp_tensors(bias_tensor)
 
     def add_opr_out_tensors(self):
         for o in self.expr.outputs:
             out_tensor = self.resolver.get_ir_tensor(o, owner_opr=self.op)
-            out_tensor.set_qparams(
-                *self.resolver.resolve_qparams(
-                    self.act_qparams.scale, self.act_qparams.zero_point
-                )
-            )
-            out_tensor.q_dtype = self.act_dtype
+            out_tensor.set_qparams_from_mge_qparams(self.act_qparams)
             self.op.add_out_tensors(out_tensor)
diff --git a/mgeconvert/frontend/tm_to_ir/op_generators/subtensor.py b/mgeconvert/frontend/tm_to_ir/op_generators/subtensor.py
index 9c08113..4e20c93 100644
--- a/mgeconvert/frontend/tm_to_ir/op_generators/subtensor.py
+++ b/mgeconvert/frontend/tm_to_ir/op_generators/subtensor.py
@@ -75,3 +75,9 @@ def add_opr_vars(self):
             inp_tensor = self.resolver.get_ir_tensor(inp, user_opr=self.op)
             self.op.add_inp_tensors(inp_tensor)
         self.add_opr_out_tensors()
+        if (
+            hasattr(self.op.inp_tensors[0], "scale")
+            and self.op.inp_tensors[0].scale is not None
+        ):
+            for o in self.op.out_tensors:
+                o.set_qparams_from_other_tensor(self.op.inp_tensors[0])
diff --git a/mgeconvert/frontend/tm_to_ir/pattern_utils.py b/mgeconvert/frontend/tm_to_ir/pattern_utils.py
index 65e2cf3..e9c4c0b 100644
--- a/mgeconvert/frontend/tm_to_ir/pattern_utils.py
+++ b/mgeconvert/frontend/tm_to_ir/pattern_utils.py
@@ -61,7 +61,11 @@ class InputNode:
 
 def register_pattern(pattern, default_dict: OrderedDict):
     def insert(func):
-        default_dict[pattern] = func
+        if isinstance(pattern, list):
+            for p in pattern:
+                default_dict[p] = func
+        else:
+            default_dict[pattern] = func
         return func
 
     return insert
diff --git a/mgeconvert/frontend/tm_to_ir/qat_pattern.py b/mgeconvert/frontend/tm_to_ir/qat_pattern.py
index 0a2e0a3..4f2cd51 100644
--- a/mgeconvert/frontend/tm_to_ir/qat_pattern.py
+++ b/mgeconvert/frontend/tm_to_ir/qat_pattern.py
@@ -36,8 +36,16 @@ def gen_qat_conv_opr(module, conv_function_expr, qat_expr, irgraph, is_deconv=Fa
     )
     assert len(module.graph.inputs) == 2
 
-    act_qparams = module.act_fake_quant.get_qparams()
-    weight_qparams = module.weight_fake_quant.get_qparams()
+    act_qparams = (
+        module.act_fake_quant.get_qparams()
+        if module.act_observer is None
+        else module.act_observer.get_qparams()
+    )
+    weight_qparams = (
+        module.weight_fake_quant.get_qparams()
+        if module.weight_observer is None
+        else module.weight_observer.get_qparams()
+    )
 
     module.stride = conv_function_expr.args[3]
     module.padding = conv_function_expr.args[4]
@@ -54,17 +62,15 @@ def gen_qat_conv_opr(module, conv_function_expr, qat_expr, irgraph, is_deconv=Fa
         else GenDeconv2dOpr(qat_expr, irgraph).get_opr()
     )
 
-    op.inp_tensors[1].scale = float(weight_qparams.scale)
-    op.inp_tensors[1].zero_point = int(weight_qparams.zero_point)
-    op.inp_tensors[1].q_dtype = weight_qparams.dtype_meta.np_dtype_str
+    op.inp_tensors[1].set_qparams_from_mge_qparams(weight_qparams)
     if len(op.inp_tensors) == 3:
-        op.inp_tensors[2].scale = op.inp_tensors[0].scale * op.inp_tensors[1].scale
-        op.inp_tensors[2].q_dtype = "int32"
-        op.inp_tensors[2].zero_point = 0
-
-    op.out_tensors[0].scale = act_qparams.scale.numpy()[0]
-    op.out_tensors[0].zero_point = act_qparams.zero_point.numpy()[0]
-    op.out_tensors[0].q_dtype = act_qparams.dtype_meta.np_dtype_str
+        op.inp_tensors[2].set_qparams(
+            scale=op.inp_tensors[0].scale * op.inp_tensors[1].scale,
+            zero_point=0,
+            q_dtype="int32",
+            np_dtype="int32",
+        )
+    op.out_tensors[0].set_qparams_from_mge_qparams(act_qparams)
     return op
@@ -80,6 +86,17 @@ def gen_qat_conv_opr(module, conv_function_expr, qat_expr, irgraph, is_deconv=Fa
     MatchAnyNode,
 )
 
+pat_conv_bias_relu_1 = (
+    QATModule._apply_fakequant_with_observer,
+    MatchAnyNode,
+    (
+        F.relu,
+        (F.conv2d, InputNode, QATModule._apply_fakequant_with_observer, MatchAnyNode),
+    ),
+    MatchAnyNode,
+    MatchAnyNode,
+)
+
 pat_conv_bias = (
     QATModule._apply_fakequant_with_observer,
     MatchAnyNode,
@@ -121,7 +138,7 @@
 )
 
 
-@register_fusion_pattern(pat_conv_bias_relu)
+@register_fusion_pattern([pat_conv_bias_relu, pat_conv_bias_relu_1])
 def qat_conv_bias_relu(module, expr, call_expr, irgraph, _):
     relu = expr.inputs[1].expr
     op = gen_qat_conv_opr(module, relu.inputs[0].expr, call_expr, irgraph)
@@ -174,9 +191,7 @@ def qat_deconv_relu_bias(
     relu_op.out_tensors.append(resolver.resolve(call_expr.outputs[0], relu_op)[0])
     relu_op.out_tensors[0].name += "_relu"
-    relu_op.out_tensors[0].q_dtype = relu_op.inp_tensors[0].q_dtype
-    relu_op.out_tensors[0].scale = relu_op.inp_tensors[0].scale
-    relu_op.out_tensors[0].zero_point = relu_op.inp_tensors[0].zero_point
+    relu_op.out_tensors[0].set_qparams_from_other_tensor(relu_op.inp_tensors[0])
     irgraph.all_tensors[
         irgraph._tensor_ids.index(call_expr.outputs[0]._id)
     ] = relu_op.out_tensors[0]
@@ -186,6 +201,7 @@
 MATCH_RULE[QATModule._apply_fakequant_with_observer] = [
     pat_conv_bias_relu,
+    pat_conv_bias_relu_1,
     pat_conv_bias,
     pat_deconv_relu,
     pat_conv_relu,
diff --git a/mgeconvert/frontend/tm_to_ir/tm_frontend.py b/mgeconvert/frontend/tm_to_ir/tm_frontend.py
index 570b488..df94bf0 100644
--- a/mgeconvert/frontend/tm_to_ir/tm_frontend.py
+++ b/mgeconvert/frontend/tm_to_ir/tm_frontend.py
@@ -76,12 +76,8 @@ def add_net_inputs(self):
         for node in self.inputs:
             inp_tensor = self.tensor_resolver.get_ir_tensor(node, owner_opr=self)
             if node.qparams is not None:
-                inp_tensor.set_qparams(
-                    *self.tensor_resolver.resolve_qparams(
-                        node.qparams.scale, node.qparams.zero_point
-                    )
-                )
-                inp_tensor.q_dtype = node.qparams.dtype_meta.np_dtype_str
+                inp_tensor.set_qparams_from_mge_qparams(node.qparams)
+
             self.irgraph.add_net_inputs(inp_tensor)
 
     def get_all_oprs(self):
@@ -129,17 +125,12 @@ def get_all_oprs(self):
                     out_tensor = self.irgraph.get_tensor(
                         expr.outputs[0]._id, None, origin_tensor=inp_tensor
                     )
-                    qdtype = module.get_activation_dtype()
                     qparams = (
                         module.act_fake_quant.get_qparams()
                         if hasattr(module.act_fake_quant, "get_qparams")
                         else module.act_observer.get_qparams()
                     )
-                    scale = qparams.scale
-                    zero_point = qparams.zero_point
-                    out_tensor.q_dtype = qdtype
-                    out_tensor.scale = float(scale)
-                    out_tensor.zero_point = int(zero_point) if zero_point else None
+                    out_tensor.set_qparams_from_mge_qparams(qparams)
                 elif isinstance(m.owner, (FloatQuantStub, FloatDequantStub)):
                     module = m.owner
                     inp_tensor = self.tensor_resolver.get_ir_tensor(expr.inputs[1])
diff --git a/test/mge/test_tflite.py b/test/mge/test_tflite.py
index 553cc5e..0997eaf 100644
--- a/test/mge/test_tflite.py
+++ b/test/mge/test_tflite.py
@@ -157,7 +157,13 @@ def test_slice():
     net = SubtensorOpr()
     mge_result = dump_mge_model(net, net.data, tmp_file)
     _test_convert_result(
-        net.data, tmp_file, mge_result, max_error, nhwc=False, nhwc2=False
+        net.data,
+        tmp_file,
+        mge_result,
+        max_error,
+        nhwc=False,
+        nhwc2=False,
+        disable_nhwc=True,
     )
 
diff --git a/test/traced_module/test_tflite.py b/test/traced_module/test_tflite.py
index 4c921bc..7b5875e 100644
--- a/test/traced_module/test_tflite.py
+++ b/test/traced_module/test_tflite.py
@@ -16,6 +16,7 @@
     ElemwiseOpr,
     FConcatOpr,
     LinearOpr,
+    NCHW_SubtensorOpr,
     PadOpr,
     PoolOpr,
     ReduceOpr,
@@ -67,7 +68,10 @@ def _test_convert_result(
             inputs[i] = inp.transpose((0, 2, 3, 1))
 
     tracedmodule_to_tflite(
-        tm, output=tmp_file + ".tflite", require_quantize=require_quantize
+        tm,
+        output=tmp_file + ".tflite",
+        require_quantize=require_quantize,
+        disable_nhwc=not nhwc,
     )
 
     tfl_model = interpreter.Interpreter(model_path=tmp_file + ".tflite")
@@ -236,7 +240,9 @@ def test_squeeze():
     net = SqueezeOpr()
     traced_module, tm_result = get_traced_module(net, mge.tensor(net.data))
     print(traced_module.flatten().graph)
-    _test_convert_result(mge.tensor(net.data), traced_module, tm_result)
+    _test_convert_result(
+        mge.tensor(net.data), traced_module, tm_result, nhwc=False, nhwc2=False
+    )
 
 
 def test_slice():
@@ -244,6 +250,10 @@ def test_slice():
     tm, tm_result = get_traced_module(net, mge.tensor(net.data))
     print(tm.flatten().graph)
     _test_convert_result(mge.tensor(net.data), tm, tm_result, nhwc=False, nhwc2=False)
+    net1 = NCHW_SubtensorOpr()
+    tm, tm_result = get_traced_module(net1, mge.tensor(net1.data))
+    tm_result = mge.tensor(net1.data).transpose(0, 2, 3, 1)[1:3, 4:9, 2, 4:8]
+    _test_convert_result(mge.tensor(net1.data), tm, tm_result, nhwc=True, nhwc2=False)
 
 
 def test_typecvt():
diff --git a/test/utils.py b/test/utils.py
index 2d7e76c..f5f2149 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -245,6 +245,20 @@ def forward(self, x):
         return x
 
 
+class NCHW_SubtensorOpr(M.Module):
+    def __init__(self, fix_batch=False):
+        super().__init__()
+        self.fix_batch = fix_batch
+        self.data = np.random.random((5, 10, 20, 20)).astype(np.float32)
+
+    def forward(self, x):
+        if self.fix_batch:
+            x = x[:, 4:8, 4:9, 2]
+        else:
+            x = x[1:3, 4:8, 4:9, 2]
+        return x
+
+
 class TransposeOpr(M.Module):
     def __init__(self):
         super().__init__()