Handling bfloat16 constant propagation (#1484)
Just a temporary workaround for #1471, for experimentation/discussion.

It looks like the issue is that, when converted to a numpy array, bfloat16
tensor constants are represented as float32 numpy arrays (in ONNX itself).
In the context of constant propagation, this means that we cannot rely
solely on the numpy value's dtype to figure out the ONNX type.

The hack below suppresses constant propagation for bfloat16 constants:
partly because of the above reason, and partly because I am still unclear
whether this convention is supported by the onnx reference implementation
(or ORT), etc. Assuming the backend supports it, we can also try
alternative solutions. One possibility is to simply suppress constant
propagation if the output types are unknown (in the onnx model).
gramalingam authored May 14, 2024
1 parent ac7ce49 commit fe9f29a
Showing 2 changed files with 29 additions and 0 deletions.
11 changes: 11 additions & 0 deletions onnxscript/_legacy_ir/visitor.py
@@ -590,6 +590,17 @@ def get_constant_value(i: int) -> onnx.TensorProto | None:
for output in node.output:
    info = self.lookup_or_create(output)
    if output in output_types:
        if info.type is not None:
            if (
                info.type.tensor_type.elem_type
                != output_types[output].tensor_type.elem_type
            ):
                logger.warning(
                    "Overriding existing type %s with inferred type %s for %s",
                    info.type,
                    output_types[output],
                    output,
                )
        # TODO: merge types
        info.type = output_types[output]

18 changes: 18 additions & 0 deletions onnxscript/optimizer/constant_folding.py
@@ -207,6 +207,24 @@ def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None:
if any(x is ir.NotConstant for x in input_values):
    return None

input_types = [x.type for x in inputs if x is not None]

def is_excluded_type(type_proto: onnx.TypeProto | None) -> bool:
    if type_proto is None:
        return True
    if type_proto.HasField("tensor_type"):
        return type_proto.tensor_type.elem_type in {
            onnx.TensorProto.BFLOAT16,
            onnx.TensorProto.FLOAT8E4M3FN,
            onnx.TensorProto.FLOAT8E4M3FNUZ,
            onnx.TensorProto.FLOAT8E5M2,
            onnx.TensorProto.FLOAT8E5M2FNUZ,
        }
    return False

if any(is_excluded_type(x) for x in input_types):
    return None

outputs = self.evaluate(domain, op, version, *input_values, **attrs)
# TODO: what if evaluated value is None?
if outputs is None:
