
Commit

feat(mgeconvert/tflite): support get qmin qmax np_dtype dtype_name from qparams and fix some bugs
dingshaohua960303 committed Jun 6, 2022
1 parent 739c1b9 commit 7554304
Showing 19 changed files with 338 additions and 182 deletions.
7 changes: 6 additions & 1 deletion mgeconvert/backend/ir_to_tflite/tflite_converter.py
@@ -32,6 +32,7 @@
get_shape_param,
mge2tflite_dtype_mapping,
set_quantization,
set_tensor_format,
)


@@ -53,6 +54,10 @@ def __init__(self, net, graph_name="graph", quantizer=None):

def convert(self, disable_nhwc=False):
# Note the 0th entry of this array must be an empty buffer (sentinel)
if disable_nhwc:
set_tensor_format("nchw")
else:
set_tensor_format("nhwc")
Buffer.BufferStart(self._builder)
buffer = Buffer.BufferEnd(self._builder)
self._buffer_list.append(buffer)
@@ -106,7 +111,7 @@ def need_convert(mge_opr)
)

if isinstance(dtype, QuantDtypeMeta):
dtype = dtype.np_dtype_str
dtype = dtype.name
else:
dtype = tensor.dtype

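A quick illustration of the dtype-key change above, hedged since it leans on MegEngine internals (`_builtin_quant_dtypes` is private API): for a QuantDtypeMeta, `.name` is the quantized dtype name while `.np_dtype_str` is the underlying numpy dtype, and `.name` matches the new string keys in mge2tflite_dtype_mapping below, including "qint8_narrow".

from megengine.core.tensor.dtype import _builtin_quant_dtypes  # private API, assumed stable here

meta = _builtin_quant_dtypes["qint8"]
print(meta.name)          # "qint8"
print(meta.np_dtype_str)  # "int8"
narrow = _builtin_quant_dtypes["qint8_narrow"]
print(narrow.qmin, narrow.qmax)  # -127 127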
82 changes: 70 additions & 12 deletions mgeconvert/backend/ir_to_tflite/tflite_op.py
@@ -11,6 +11,7 @@
from typing import List

import numpy as np
from megengine import get_logger
from numpy import dtype

from ...converter_ir.ir_op import (
@@ -85,10 +86,13 @@
from .tflite.Padding import Padding
from .tflite.TensorType import TensorType

logger = get_logger(__name__)


class Config:
platform = "official"
require_quantize = True
tensor_format = "nhwc"


def set_platform(platform):
@@ -100,23 +104,26 @@ def set_quantization(require_quantize):
Config.require_quantize = require_quantize


def set_tensor_format(tensor_format):
assert tensor_format in ["nchw", "nhwc"]
Config.tensor_format = tensor_format


def get_platform():
return Config.platform


def get_format():
return Config.tensor_format


def get_quantization():
return Config.require_quantize
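A short usage sketch of these module-level switches, using only the names defined above: convert(disable_nhwc=True) pins "nchw", everything else defaults to "nhwc", and ops such as _squeeze read the value back via get_format() at build time.

set_tensor_format("nchw")
assert get_format() == "nchw"
set_tensor_format("nhwc")  # restore the default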


def get_shape_param(
tensor: IRTensor, mge_opr: OpBase, quantizer: IRQuantizer, disable_nhwc=False
):
"""
Return a tuple of shape and bytes(1dim) object for tflite operator, which will
restore its inp/out at runtime by the shape and bytes.
"""
def _get_tensor_shape(tensor, mge_opr, disable_nhwc):
if isinstance(mge_opr, ReshapeOpr):
return tensor.shape, None
return tensor.shape

shape = list(tensor.shape)
if tensor.axis_order and tensor.ndim == 4:
@@ -135,9 +142,25 @@ def get_shape_param(
shape = tensor.axis_order.shape_to_NHWC(shape)
elif isinstance(tensor.axis_order, IOHWFormat):
shape = tensor.axis_order.shape_to_OHWI(shape)
elif tensor.axis_order and mge_opr.name == "Squeeze":
if not disable_nhwc:
            nhwc_axis_order = [0, 3, 1, 2]
            inp_shape = list(mge_opr.inp_tensors[0].shape)
            assert len(inp_shape) == 4
            out_shape = mge_opr.inp_tensors[0].axis_order.shape_to_NHWC(inp_shape)
            # map NCHW squeeze axes to NHWC (reversed so later axes are popped first)
            squeeze_dims = [nhwc_axis_order[i] for i in mge_opr.squeeze_dims[::-1]]
for i in squeeze_dims:
out_shape.pop(i)
shape = out_shape

elif tensor.ndim > 4:
assert False, "ERROR: output ndim {0} is not supported now".format(tensor.ndim)
return shape
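A worked example of the NHWC squeeze-axis remap above, with assumed shapes: an NCHW input (1, 16, 1, 1) squeezed on dims [2, 3] first becomes NHWC [1, 1, 1, 16]; the NCHW axes then map through [0, 3, 1, 2], and the reversed iteration keeps each pop() index valid.

nhwc_axis_order = [0, 3, 1, 2]
inp_shape = [1, 16, 1, 1]                          # NCHW
out_shape = [inp_shape[i] for i in (0, 2, 3, 1)]   # NHWC: [1, 1, 1, 16]
squeeze_dims = [nhwc_axis_order[i] for i in [2, 3][::-1]]  # [2, 1], descending
for i in squeeze_dims:
    out_shape.pop(i)
print(out_shape)  # [1, 16]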


def _get_tensor_value(tensor, mge_opr, quantizer, disable_nhwc):
if isinstance(mge_opr, ReshapeOpr):
return None
number_list: List[np.ndarray] = []
if (
quantizer.require_quantize
@@ -160,15 +183,34 @@ def get_shape_param(
value = tensor.axis_order.data_to_NHWC(value)
elif isinstance(tensor.axis_order, IOHWFormat):
value = tensor.axis_order.data_to_OHWI(value)

if not disable_nhwc and mge_opr.name == "GetSubTensor" and value is not None:
        assert value.shape == (
            4,
        ), "Slice input with ndim != 4 is not supported in nhwc mode"
value = np.array([value[0], value[2], value[3], value[1]])
number_list = value.reshape(-1)

if len(number_list) > 0:
byte_list: List[bytes] = []
for i in number_list:
byte_list.extend(i.tobytes())
return shape, byte_list
return byte_list
else:
return shape, None
return None
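For the GetSubTensor branch above, a small worked example with assumed values: per-axis slice parameters recorded in NCHW order are permuted into NHWC order before serialization.

import numpy as np

value = np.array([10, 20, 30, 40])  # one param per NCHW axis: N, C, H, W
nhwc = np.array([value[0], value[2], value[3], value[1]])
print(nhwc)  # [10 30 40 20] -> N, H, W, C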


def get_shape_param(
tensor: IRTensor, mge_opr: OpBase, quantizer: IRQuantizer, disable_nhwc=False
):
"""
Return a tuple of shape and bytes(1dim) object for tflite operator, which will
restore its inp/out at runtime by the shape and bytes.
"""
return (
_get_tensor_shape(tensor, mge_opr, disable_nhwc),
_get_tensor_value(tensor, mge_opr, quantizer, disable_nhwc),
)


mge2tflite_dtype_mapping = {
@@ -184,11 +226,14 @@ def get_shape_param(
dtype("uint8"): TensorType.UINT8,
dtype("int8"): TensorType.INT8,
"quint8": TensorType.UINT8,
"qint8": TensorType.INT8,
"qint32": TensorType.INT32,
"qint16": TensorType.INT16,
"uint8": TensorType.UINT8,
"int8": TensorType.INT8,
"int16": TensorType.INT16,
"int32": TensorType.INT32,
"qint8_narrow": TensorType.INT8,
}
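An illustration of the new mapping entries, assuming the table above is in scope: quantized dtype names (QuantDtypeMeta.name) now key directly into the table, so the narrow-range qint8 variant resolves without a detour through numpy dtypes.

print(mge2tflite_dtype_mapping["qint8_narrow"])  # TensorType.INT8
print(mge2tflite_dtype_mapping["qint16"])        # TensorType.INT16
print(mge2tflite_dtype_mapping[dtype("int8")])   # TensorType.INT8, numpy keys still work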


@@ -381,6 +426,11 @@ def _deconv(mge_opr, builder):

@_register_op(ConcatOpr)
def _concat(mge_opr, builder):
if len(set([t.scale for t in mge_opr.inp_tensors + mge_opr.out_tensors])) != 1:
        logger.warning(
            "tflite concat doesn't support inputs/outputs with different scales!"
        )

ConcatenationOptions.ConcatenationOptionsStart(builder)
ConcatenationOptions.ConcatenationOptionsAddFusedActivationFunction(
builder, mge2tflite_activation_type[mge_opr.activation]
@@ -528,9 +578,17 @@ def _squeeze(mge_opr, builder):
SqueezeOptions.SqueezeOptionsStartSqueezeDimsVector(
builder, len(mge_opr.squeeze_dims)
)
for i in mge_opr.squeeze_dims:
if get_format() == "nhwc":
assert (
mge_opr.inp_tensors[0].ndim == 4
), "can't support Squeeze input ndim !=4 in nhwc mode"
nhwc_aixs_order = [0, 3, 1, 2]
squeeze_dims = [nhwc_aixs_order[i] for i in mge_opr.squeeze_dims]
else:
squeeze_dims = mge_opr.squeeze_dims
for i in squeeze_dims:
builder.PrependInt32(i)
squeeze_dims = builder.EndVector(len(mge_opr.squeeze_dims))
squeeze_dims = builder.EndVector(len(squeeze_dims))
SqueezeOptions.SqueezeOptionsStart(builder)
SqueezeOptions.SqueezeOptionsAddSqueezeDims(builder, squeeze_dims)
options = SqueezeOptions.SqueezeOptionsEnd(builder)
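A worked example (assumed dims) of the remap in _squeeze above: NCHW squeeze axes [2, 3] (H and W) become NHWC axes [1, 2] before being written into the SqueezeOptions vector.

nhwc_axis_order = [0, 3, 1, 2]
print([nhwc_axis_order[i] for i in [2, 3]])  # [1, 2]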
40 changes: 25 additions & 15 deletions mgeconvert/converter_ir/ir_quantizer.py
@@ -35,16 +35,16 @@ def quantize(self, tensor: IRTensor):
value = np.round(value)
if tensor.zero_point:
value += tensor.zero_point
dt = (
np.dtype(tensor.q_dtype)
if isinstance(tensor.q_dtype, str)
else tensor.q_dtype
)
if np.issubdtype(dt, np.integer):
np_dtype = tensor.np_dtype
dt = np.dtype(np_dtype)
if tensor.qmin is not None and tensor.qmax is not None:
v_min = tensor.qmin
v_max = tensor.qmax
elif np.issubdtype(dt, np.integer):
v_min = np.iinfo(dt).min
v_max = np.iinfo(dt).max
value = np.clip(value, v_min, v_max)
value = value.astype(tensor.q_dtype)
value = np.clip(value, v_min, v_max)
value = value.astype(np_dtype)
return value
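A minimal sketch of the quantize() path above, with assumed values and the division by scale from the elided lines applied first: round, shift by the zero point, clamp to [qmin, qmax] when provided, then cast to the numpy dtype.

import numpy as np

value = np.array([-3.2, 0.4, 2.9]) / 0.05  # tensor data / scale (elided above)
value = np.round(value)
value += 0                                  # zero_point
value = np.clip(value, -127, 127)           # qmin/qmax, e.g. qint8_narrow
value = value.astype("int8")                # np_dtype
print(value)                                # [-64   8  58]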

def save_quantize_params(self, irgraph):
@@ -56,10 +56,20 @@ def save_quantize_params(self, irgraph):
self.parse_quant_info(t)

def parse_quant_info(self, t: IRTensor):
dt = np.dtype(t.q_dtype)
if t.q_dtype is None:
return
np_dtype = t.np_dtype
try:
dt = np.dtype(np_dtype)
except TypeError:
dt = None

v_max, v_min = None, None
is_weight = bool(t.np_data is not None)
if np.issubdtype(dt, np.integer):
if t.qmin is not None and t.qmax is not None:
v_min = t.qmin
v_max = t.qmax
elif dt is not None and np.issubdtype(dt, np.integer):
v_min = np.iinfo(dt).min
v_max = np.iinfo(dt).max
if self.param_fake_quant and is_weight:
@@ -78,11 +88,11 @@ def parse_quant_info(self, t: IRTensor):
)[0].numpy()
else:
param = {
"dtype": str(dt),
"qmin": str(v_min),
"qmax": str(v_max),
"scale": str(t.scale),
"zero_point": str(t.zero_point),
"dtype": np_dtype,
"qmin": v_min,
"qmax": v_max,
"scale": t.scale,
"zero_point": t.zero_point,
"is_weight": is_weight,
}
self.quant_params[t.name] = param
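For reference, a hypothetical quant_params entry after this change: numeric fields are stored as numbers rather than stringified, so a JSON dump of the table needs no post-parsing.

{
    "dtype": "int8",
    "qmin": -127,
    "qmax": 127,
    "scale": 0.05,
    "zero_point": 0,
    "is_weight": True,
}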
57 changes: 49 additions & 8 deletions mgeconvert/converter_ir/ir_tensor.py
@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import List, Sequence, Union
from typing import List, Sequence


class DataFormat:
@@ -94,7 +94,10 @@ def __init__(
dtype,
scale=None,
zero_point=None,
qmin=None,
qmax=None,
q_type=None,
np_dtype=None,
np_data=None,
owner_opr=None,
axis=AxisOrder.NCHW,
@@ -110,7 +113,11 @@ def __init__(

self.scale = scale
self.zero_point = zero_point
self.qmin = qmin
self.qmax = qmax
assert isinstance(q_type, str) or q_type is None
self.q_dtype = q_type
self.np_dtype = np_dtype

@property
def ndim(self):
@@ -123,9 +130,22 @@ def set_dtype(self, target_type):
self.np_data = self.np_data.astype(target_type)
self.dtype = target_type

def set_qparams(
self, scale: Union[float, List[float]], zero_point=None, q_dtype=None
):
def set_qparams_from_other_tensor(self, other):
self.q_dtype = other.q_dtype
self.np_dtype = other.np_dtype
self.qmin = other.qmin
self.qmax = other.qmax
self.scale = other.scale
self.zero_point = other.zero_point

def set_qparams_from_mge_qparams(self, qparams):
dtype_meta = qparams.dtype_meta
self.q_dtype = dtype_meta.name
self.np_dtype = dtype_meta.np_dtype_str
self.qmin = dtype_meta.qmin
self.qmax = dtype_meta.qmax
scale = qparams.scale
zero_point = qparams.zero_point
if not isinstance(scale, Sequence): # per tensor
self.scale = float(scale)
else: # per channel
@@ -137,8 +157,29 @@
else:
self.zero_point = [int(zp) for zp in zero_point]

if self.q_dtype is not None:
self.q_dtype = q_dtype
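A hedged sketch of set_qparams_from_mge_qparams: any object exposing MegEngine's QParams layout works, i.e. a dtype_meta carrying name/np_dtype_str/qmin/qmax plus scale and zero_point; SimpleNamespace and the tensor `t` below are stand-ins.

from types import SimpleNamespace
from megengine.core.tensor.dtype import _builtin_quant_dtypes  # private API, assumed

qparams = SimpleNamespace(
    dtype_meta=_builtin_quant_dtypes["qint8"], scale=0.05, zero_point=0
)
t.set_qparams_from_mge_qparams(qparams)  # t: an existing IRTensor
assert t.np_dtype == "int8" and (t.qmin, t.qmax) == (-128, 127)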
def set_qparams(
self,
*,
scale: float,
q_dtype: str,
qmin: int = None,
qmax: int = None,
zero_point=None,
np_dtype=None,
):
if qmin is None or qmax is None:
assert np_dtype is not None, "must provide np_dtype or qmin and qmax"
if not isinstance(scale, Sequence): # per tensor
self.scale = float(scale)
else: # per channel
self.scale = [float(s) for s in scale]
if zero_point is not None:
if not isinstance(zero_point, Sequence):
self.zero_point = int(zero_point)
else:
self.zero_point = [int(zp) for zp in zero_point]

        self.q_dtype = q_dtype
        self.np_dtype = np_dtype
        self.qmin = qmin
        self.qmax = qmax

    def __repr__(self):
        return self.name
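And a usage sketch of the reworked keyword-only set_qparams (the leading IRTensor constructor arguments are assumed): qmin/qmax may be omitted only when np_dtype is supplied, per the assert above.

t = IRTensor("conv_out", (1, 8, 4, 4), "float32")  # constructor signature assumed
t.set_qparams(scale=0.05, q_dtype="qint8", qmin=-128, qmax=127, np_dtype="int8")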
