-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
A few fixes relating to constant propagation (#1892)
Fixes a few different issues. Helps resolve an issue relating to ir-based optimization for the Blender model in the benchmark. * Move the utility for evaluating `Constant` op into the IR, and make `const_value` automatically perform the related computation. * Eliminate the dependence on the reference-implementation for evaluation of Constant op. * There are still a couple of issues relating to the use of reference-implementation (eg., when we have tensor-valued attributes in external-data format, and the use of float16) which will need to be addressed separately, but the above bypasses this issue for Constant op (and the Blender model). * Make the optimizer robust to external-data-tensors whose files are not available.
- Loading branch information
1 parent
12f9209
commit ed28222
Showing
10 changed files
with
100 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,13 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""This is a temporary utility to assist new IR while it's still under development.""" | ||
|
||
from __future__ import annotations | ||
|
||
import typing | ||
|
||
import numpy as np | ||
|
||
from onnxscript import ir | ||
|
||
GRAPH_OUTPUT_META_KEY = "pkg.onnxscript.rewriter.generic_pattern.graph_output" | ||
|
||
|
||
def propagate_const_value(ir_value: ir.Value) -> ir.Value: | ||
"""Temporary method to propagate a constant value to the IR value.""" | ||
node = ir_value.producer() | ||
if node is None: | ||
return ir_value | ||
if node.op_type != "Constant": | ||
return ir_value | ||
attr_name, attr_value = next(iter(node.attributes.items())) | ||
if attr_value is None or not isinstance(attr_value, ir.Attr): | ||
return ir_value | ||
import onnxscript.ir as ir | ||
from onnxscript.optimizer import basic_constant_propagation | ||
|
||
const_value: ir.TensorProtocol | ||
if attr_name in {"value_float", "value_floats"}: | ||
const_value = ir.Tensor( | ||
np.array(attr_value.value, dtype=np.float32), name=ir_value.name | ||
) | ||
elif attr_name in {"value_int", "value_ints"}: | ||
const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name) | ||
elif attr_name in {"value_string", "value_strings"}: | ||
const_value = ir.StringTensor( | ||
np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name | ||
) | ||
elif attr_name == "value": | ||
const_value = typing.cast(ir.TensorProtocol, attr_value.value) | ||
else: | ||
return ir_value | ||
|
||
ir_value.const_value = const_value | ||
ir_value.shape = const_value.shape # type: ignore | ||
ir_value.dtype = const_value.dtype | ||
return ir_value | ||
def get_const_value(value: ir.Value) -> ir.TensorProtocol | None: | ||
node = value.producer() | ||
if node is not None: | ||
basic_constant_propagation([node]) | ||
return value.const_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters