Merge pull request fastmachinelearning#78 from iksnagreb/fix/transpose_into_quant

Fix FoldTransposeIntoQuantInit Transformation
maltanar authored Oct 23, 2023
2 parents e62517a + 0351d9e commit c966b46
Showing 2 changed files with 217 additions and 41 deletions.
116 changes: 75 additions & 41 deletions src/qonnx/transformation/quant_constant_folding.py
@@ -26,57 +26,91 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import warnings
+# Protobuf onnx graph node type
+from onnx import NodeProto
 
+# QONNX wrapper of ONNX model graphs
+from qonnx.core.modelwrapper import ModelWrapper
+
+# QONNX graph transformations base class
 from qonnx.transformation.base import Transformation
+
+# Gets items from protobuf by name
 from qonnx.util.basic import get_by_name
 
 
+# Tests whether a node is a quant-init, i.e., a quantizer with only initializer
+# inputs
+def is_quant_init(node: NodeProto, model: ModelWrapper):
+    # Only handle existing Quant or BipolarQuant type nodes
+    if node is not None and node.op_type in {"Quant", "BipolarQuant"}:
+        # All inputs must have initializers, otherwise this is just a normal
+        # quant, but not a quant-init
+        return all(model.get_initializer(i) is not None for i in node.input)
+    # Did not match the operator type
+    return False
+
+
+# Transpose nodes can be folded into quantized initializers, i.e., Quant nodes
+# where *all* inputs are initializers. Initializers are constants and part of
+# the model graph and thus can be transposed offline.
 class FoldTransposeIntoQuantInit(Transformation):
     """
-    Fueses a Transpose node into the initalizer of a Quant node.
+    Fuses a Transpose node into the initializers of a Quant node.
     """
 
-    def apply(self, model):
+    # Applies the transform to a whole model graph
+    def apply(self, model: ModelWrapper):
+        # Get the model graph out of the model wrapper object
         graph = model.graph
-        node_ind = 0
+        # Keep track of whether the graph has been modified
         graph_modified = False
-        # Find transpose nodes, which have Quant node with initilizer upstream.
-        for n in graph.node:
-            node_ind += 1
-            if n.op_type == "Transpose":
-                predecessors = model.find_direct_predecessors(n)
-                # Check if we reached the top of the graph
-                if predecessors is None:
+        # Iterate all nodes in the graph keeping track of the index
+        for index, node in enumerate(graph.node):
+            # This transformation is triggered by finding a Transpose node
+            if node.op_type == "Transpose":
+                # Get the predecessors feeding into the transpose node
+                predecessors = model.find_direct_predecessors(node)
+                # The transform applies only to transpose with exactly one input
+                if predecessors is None or len(predecessors) != 1:
+                    # Note: Softly skip this node, maybe consider a hard failure
+                    # at least in case there are multiple inputs?
                    continue
-                predecessor = predecessors[0]
-                if predecessor.op_type == "Quant" or predecessor.op_type == "BipolarQuant":
-                    for inp in predecessor.input:
-                        if not isinstance(model.get_initializer(inp), type(None)):
-                            # Explicitly apply the transpose to the initializers
-                            # of the previous node
-                            target_tensor = model.get_initializer(inp)
-                            if target_tensor is None:
-                                warnings.warn(
-                                    f"Cannot fold transpose {n} into Quant/BipolarQuant node {predecessor}, "
-                                    f"due to not initialized tensor: {inp}. "
-                                    f"Exiting FoldTransposeIntoQuantInit transformation."
-                                )
-                                return model, False
-                            # Make sure the tensor has the correct shape
-                            perm = get_by_name(n.attribute, "perm")
-                            if perm is None:
-                                target_tensor = target_tensor.transpose()
-                                model.set_initializer(inp, target_tensor)
-                                graph_modified = True
-                            elif len(perm.ints) == len(target_tensor.shape):
-                                target_tensor = target_tensor.transpose(perm.ints)
-                                model.set_initializer(inp, target_tensor)
-                                graph_modified = True
-                    # Reconnect predecessor and delete transpose node
-                    predecessor.output[0] = n.output[0]
-                    graph.node.remove(n)
-
-        return model, graph_modified
+
+                # Check whether the predecessor is a quantizer with only
+                # initializer inputs
+                if is_quant_init(predecessors[0], model):
+                    # Alias to the single predecessor node
+                    quant_init = predecessors[0]
+                    # Get the (optional) permutation indices of the transpose in
+                    # case it is a multi-axis transpose
+                    perm = get_by_name(node.attribute, "perm")
+                    # Convert permutation indices to list of integers if it is
+                    # given
+                    perm = perm.ints if perm is not None else None
+                    # Transpose all(!) initializer inputs of the quant node
+                    for i in quant_init.input:
+                        # Get the initializer tensor
+                        # Note: No need to validate the presence of the
+                        # initializer here, as we already tested this as the
+                        # applicability condition above
+                        tensor = model.get_initializer(i)
+                        # Skip transposing the initializer if the number of
+                        # dimensions do not match
+                        if perm is not None and len(perm) != tensor.ndim:
+                            # Note: Soft skip ok or is this an error?
+                            continue
+                        # Transpose the tensor, optionally according to the
+                        # permutation indices (perm might be None)
+                        tensor = tensor.transpose(perm)
+                        # Reassign the transposed initializer tensor
+                        model.set_initializer(i, tensor)
+                        # The graph has been modified, this needs to be reported
+                        # back to the caller
+                        graph_modified = True
+                    # Rewire the graph to skip the transpose node
+                    quant_init.output[0] = node.output[0]
+                    # Remove the now absorbed transpose node
+                    graph.node.remove(node)
+        # Return the transformed model and indicate whether the graph actually
+        # has been transformed
+        return model, graph_modified
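
For context on the perm handling in the new apply method: numpy.ndarray.transpose accepts either explicit permutation indices or None, and passing None reverses the axis order, which matches the default behavior of the ONNX Transpose operator when no perm attribute is set. A small illustrative numpy sketch (not part of the commit):

import numpy as np

t = np.arange(24).reshape(2, 3, 4)
# Explicit permutation indices reorder the axes, matching the ONNX
# Transpose "perm" attribute
assert t.transpose([2, 0, 1]).shape == (4, 2, 3)
# Passing perm=None reverses the axis order, matching ONNX Transpose
# without a "perm" attribute
assert t.transpose(None).shape == (4, 3, 2)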
142 changes: 142 additions & 0 deletions tests/transformation/test_quant_constant_folding.py
@@ -0,0 +1,142 @@
+# Set pytest parameters
+import pytest
+
+# Numpy for handling simulation of tensor operations
+import numpy as np
+
+# Helper for creating ONNX nodes
+from onnx import NodeProto, TensorProto  # noqa
+from onnx import helper as oh  # noqa
+
+# QONNX wrapper of ONNX model graphs
+from qonnx.core.modelwrapper import ModelWrapper  # noqa
+
+# Execute QONNX model graphs
+from qonnx.core.onnx_exec import execute_onnx  # noqa
+
+# QONNX quantizer function modeling the behavior of the Quant operator
+from qonnx.custom_op.general.quant import quant as quant_fn  # noqa
+
+# QONNX graph transformations for inferring datatypes and shapes required by
+# test setup
+from qonnx.transformation.infer_datatypes import InferDataTypes  # noqa
+from qonnx.transformation.infer_shapes import InferShapes  # noqa
+
+# Graph transformation to be tested: Transposes the initializers to Quantizer if
+# ALL inputs are initializers
+from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit  # noqa
+
+# QONNX utility for creating models from ONNX graphs
+from qonnx.util.basic import qonnx_make_model  # noqa
+
+
+@pytest.mark.parametrize("quant_init", [True, False])
+@pytest.mark.parametrize("signed", [0, 1])
+@pytest.mark.parametrize("narrow", [0, 1])
+@pytest.mark.parametrize("rounding_mode", ["ROUND"])
+@pytest.mark.parametrize("shape", [(16, 8, 12)])
+@pytest.mark.parametrize(
+    "perm",
+    [
+        # All axis permutations
+        (0, 1, 2),
+        (0, 2, 1),
+        (1, 0, 2),
+        (1, 2, 0),
+        (2, 0, 1),
+        (2, 1, 0),
+    ],
+)
+@pytest.mark.parametrize("scale", [0.01])
+@pytest.mark.parametrize("zeropoint", [0])
+@pytest.mark.parametrize("bitwidth", [8])
+# Tests the FoldTransposeIntoQuantInit transformation
+def test_fold_transpose_into_quant_init(quant_init, signed, narrow, rounding_mode, shape, perm, scale, zeropoint, bitwidth):
+    # Prepare the quantizer node attributes and input/output lists
+    node_attrs = {
+        # Type of the operation
+        "op_type": "Quant",
+        # This operator type is defined within QONNX
+        "domain": "qonnx.custom_op.general",
+        # List the inputs to the operator in order
+        # Note: The proper input followed by initializers configuring the
+        # quantizer
+        "inputs": ["input", "scale", "zeropoint", "bitwidth"],
+        # List the outputs of the operator in order
+        # Note: Intermediate feeds to the next operator input
+        "outputs": ["intermediate"],
+        # Whether the quantization interval should be signed or not
+        # (e.g. at 8b unsigned=[0, 255] vs signed=[-128, 127])
+        "signed": signed,
+        # When signed=1, whether to use narrow range or not
+        # (e.g. at 8b regular=[-128, 127] vs narrow=[-127, 127])
+        "narrow": narrow,
+        # The rounding mode, which is used for the quant function
+        "rounding_mode": rounding_mode,
+    }
+    # Create a dummy quantizer node
+    quant = oh.make_node(**node_attrs, name="Quant")
+    # Attach a Transpose operation to the quantizer
+    transpose = oh.make_node("Transpose", ["intermediate"], ["output"], name="Transpose", perm=perm)
+    # Derive the transposed shape
+    transposed_shape = np.transpose(np.zeros(shape), perm).shape
+    # Create tensor information for the input, intermediate and output tensor
+    x = oh.make_tensor_value_info("input", TensorProto.FLOAT, shape)  # noqa
+    y = oh.make_tensor_value_info("output", TensorProto.FLOAT, transposed_shape)
+    # Create the initializer tensors for quantizer parameters
+    s = oh.make_tensor_value_info("scale", TensorProto.FLOAT, (1,))
+    z = oh.make_tensor_value_info("zeropoint", TensorProto.FLOAT, (1,))
+    b = oh.make_tensor_value_info("bitwidth", TensorProto.FLOAT, (1,))
+    # Create the graph connecting the nodes and tensors
+    graph = oh.make_graph(
+        [quant, transpose],
+        "quant-transpose",
+        [x, s, z, b],
+        [y],
+    )
+    # Wrap the graph in an QONNX model wrapper
+    model = ModelWrapper(qonnx_make_model(graph, producer_name="qonnx-tests"))
+    # Add the actual initializers to the initializer tensors
+    model.set_initializer("scale", np.array(scale))
+    model.set_initializer("zeropoint", np.array(zeropoint))
+    model.set_initializer("bitwidth", np.array(bitwidth))
+    # Prepare the model graph by inferring all missing shape and datatype
+    # information
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    # Get a random dummy input for testing
+    x = np.random.rand(*shape)  # noqa
+    # Fill the execution context with dummy input data
+    context = {"input": x}
+
+    # Some test cases even turn the input into an initializer
+    if quant_init:
+        # Turn the model input into an initializer
+        model.set_initializer("input", x)
+        # Clear the execution context removing the input as it is now baked into
+        # the model graph
+        context = {}
+
+    # Run the transformation to be tested
+    model = model.transform(FoldTransposeIntoQuantInit())
+    # Verify that shape and datatype inference still works
+    # Note: This has been an issue, please see
+    # https://github.com/fastmachinelearning/qonnx/issues/77
+    model = model.transform(InferShapes())
+    model = model.transform(InferDataTypes())
+
+    # For the case of quant-initializers there must not be a Transpose left
+    # after transforming and contrariwise, the Transpose must remain in place if
+    # there is non-initializer input.
+    assert quant_init != ("Transpose" in [n.op_type for n in model.graph.node])
+
+    # Execute the ONNX model
+    o_produced = execute_onnx(model, context)["output"]
+    # Use numpy and QONNX quantizer to generate expectation
+    o_expected = np.transpose(
+        quant_fn(x, np.array(scale), np.array(zeropoint), np.array(bitwidth), signed, narrow, rounding_mode), perm
+    )
+
+    # The output must match the "manual" execution using numpy
+    assert np.allclose(o_produced, o_expected)

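As the test above demonstrates, the transformation is applied through ModelWrapper.transform. A minimal usage sketch for a model loaded from disk (the file names are placeholders, not part of the commit):

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit

# Load a QONNX model (placeholder file name)
model = ModelWrapper("model.onnx")
# Fold Transpose nodes into upstream Quant/BipolarQuant nodes whose inputs
# are all initializers; any other Transpose nodes are left in place
model = model.transform(FoldTransposeIntoQuantInit())
# Save the transformed model (placeholder file name)
model.save("model_folded.onnx")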