[IR] Test with onnx backend test files #1452
Changes from all commits
@@ -466,6 +466,7 @@
        inputs,
        outputs,
        nodes=nodes,
        # TODO(justinchuby): Attach the values associated with the initializers
        initializers=initializers,
        doc_string=_get_field(proto, "doc_string"),
        name=_get_field(proto, "name"),
@@ -633,6 +634,17 @@
        doc_string=proto.doc_string,
        metadata_props=deserialize_metadata_props(proto.metadata_props),
    )
    if proto.data_type == _enums.DataType.STRING:
        name = _get_field(proto, "name")
        doc_string = _get_field(proto, "doc_string")
        metadata_props = deserialize_metadata_props(proto.metadata_props)
        return _core.StringTensor(
            proto.string_data,
            shape=_core.Shape(proto.dims),
            name=name,
            doc_string=doc_string,
            metadata_props=metadata_props,
        )
    return TensorProtoTensor(proto)
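For context on this branch: per the ONNX spec, a TensorProto with data_type STRING stores its elements as bytes in string_data rather than raw_data, which is why string tensors get their own StringTensor path. A minimal sketch using only the onnx protobuf API; the tensor name and contents are made up:

import onnx

proto = onnx.TensorProto()
proto.name = "words"  # hypothetical name
proto.data_type = onnx.TensorProto.STRING
proto.dims.extend([2])
proto.string_data.extend([b"hello", b"world"])  # STRING elements never go into raw_data
# proto.string_data and proto.dims are exactly what the branch above hands to StringTensor.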
@@ -691,7 +703,25 @@
            [_deserialize_graph(g, scoped_values) for g in proto.graphs],
            doc_string=doc_string,
        )
    # TODO: Handle type protos etc.
    if type_ == _enums.AttributeType.SPARSE_TENSOR:
        raise NotImplementedError("Sparse tensors are not supported yet")
    if type_ == _enums.AttributeType.SPARSE_TENSORS:
        raise NotImplementedError("Sparse tensors are not supported yet")
    if type_ == _enums.AttributeType.TYPE_PROTO:
        ir_type = deserialize_type_proto_for_type(proto.tp)
        shape = deserialize_type_proto_for_shape(proto.tp)
        return _core.AttrTypeProto(
            name, _core.TypeAndShape(ir_type, shape), doc_string=doc_string
        )
    if type_ == _enums.AttributeType.TYPE_PROTOS:
        type_and_shapes = []
        for type_proto in proto.type_protos:
            ir_type = deserialize_type_proto_for_type(type_proto)
            shape = deserialize_type_proto_for_shape(type_proto)
            type_and_shapes.append(_core.TypeAndShape(ir_type, shape))
        return _core.AttrTypeProtos(name, type_and_shapes, doc_string=doc_string)
    if type_ == _enums.AttributeType.UNDEFINED:
        return _core.Attr(name, type_, None, doc_string=doc_string)
    raise ValueError(f"Unsupported attribute type: '{type_}'")
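For reference, the proto layout the new TYPE_PROTO branch reads can be built with the plain onnx protobuf API; the attribute name and the float [2, 3] tensor type below are illustrative, not the PR's test data:

import onnx

attr = onnx.AttributeProto()
attr.name = "value_type"  # hypothetical attribute name
attr.type = onnx.AttributeProto.TYPE_PROTO
attr.tp.tensor_type.elem_type = onnx.TensorProto.FLOAT  # set the type oneof first
for d in (2, 3):
    attr.tp.tensor_type.shape.dim.add().dim_value = d   # then the dims on that leaf
# deserialize_type_proto_for_type/_shape read attr.tp from a proto shaped like this.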
@@ -1078,6 +1108,8 @@
            entry = tensor_proto.external_data.add()
            entry.key = k
            entry.value = str(v)
    elif isinstance(from_, _core.StringTensor):
        tensor_proto.string_data.extend(from_.string_data())
    else:
        tensor_proto.raw_data = from_.tobytes()
    _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
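As a sanity check on the new elif, onnx.numpy_helper follows the same convention when converting string arrays: elements land in string_data and raw_data stays empty. A small sketch, assuming numpy and a recent onnx release:

import numpy as np
import onnx
from onnx import numpy_helper

t = numpy_helper.from_array(np.array([b"cat", b"dog"], dtype=object), name="labels")
assert t.data_type == onnx.TensorProto.STRING
assert list(t.string_data) == [b"cat", b"dog"]
assert t.raw_data == b""  # raw_data is not used for STRING tensors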
@@ -1102,37 +1134,69 @@
    attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any
) -> None:
    if type_ == _enums.AttributeType.INT:
        # value: int
        attribute_proto.i = value
        attribute_proto.type = onnx.AttributeProto.INT
    elif type_ == _enums.AttributeType.FLOAT:
        # value: float
        attribute_proto.f = value
        attribute_proto.type = onnx.AttributeProto.FLOAT
    elif type_ == _enums.AttributeType.STRING:
        # value: str
        attribute_proto.s = value.encode("utf-8")
        attribute_proto.type = onnx.AttributeProto.STRING
    elif type_ == _enums.AttributeType.INTS:
        # value: Sequence[int]
        attribute_proto.ints.extend(value)
        attribute_proto.type = onnx.AttributeProto.INTS
    elif type_ == _enums.AttributeType.FLOATS:
        # value: Sequence[float]
        attribute_proto.floats.extend(value)
        attribute_proto.type = onnx.AttributeProto.FLOATS
    elif type_ == _enums.AttributeType.STRINGS:
        # value: Sequence[str]
        attribute_proto.strings.extend([s.encode("utf-8") for s in value])
        attribute_proto.type = onnx.AttributeProto.STRINGS
    elif type_ == _enums.AttributeType.TENSOR:
        # value: _protocols.TensorProtocol
        serialize_tensor_into(attribute_proto.t, value)
        attribute_proto.type = onnx.AttributeProto.TENSOR
    elif type_ == _enums.AttributeType.GRAPH:
        # value: _protocols.GraphProtocol
        serialize_graph_into(attribute_proto.g, value)
        attribute_proto.type = onnx.AttributeProto.GRAPH
    elif type_ == _enums.AttributeType.TENSORS:
        # value: Sequence[_protocols.TensorProtocol]
        for tensor in value:
            serialize_tensor_into(attribute_proto.tensors.add(), tensor)
        attribute_proto.type = onnx.AttributeProto.TENSORS
    elif type_ == _enums.AttributeType.GRAPHS:
        # value: Sequence[_protocols.GraphProtocol]
        for graph in value:
            serialize_graph_into(attribute_proto.graphs.add(), graph)
        attribute_proto.type = onnx.AttributeProto.GRAPHS
    elif type_ == _enums.AttributeType.SPARSE_TENSOR:
        raise NotImplementedError("Sparse tensors are not supported yet")
    elif type_ == _enums.AttributeType.SPARSE_TENSORS:
        raise NotImplementedError("Sparse tensors are not supported yet")
    elif type_ == _enums.AttributeType.TYPE_PROTO:
        # value: _core.TypeAndShape
        if value.type is not None:
            serialize_type_into(attribute_proto.tp, value.type)
        # Need to create the type _before_ writing the shape
        if value.shape is not None:
            serialize_shape_into(attribute_proto.tp, value.shape)
        attribute_proto.type = onnx.AttributeProto.TYPE_PROTO
    elif type_ == _enums.AttributeType.TYPE_PROTOS:
        for ir_type in value:
            # ir_type: _core.TypeAndShape
            type_proto = attribute_proto.type_protos.add()
            if ir_type.type is not None:
                serialize_type_into(type_proto, ir_type.type)
            # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
            if ir_type.shape is not None:
                serialize_shape_into(type_proto, ir_type.shape)
        attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS
    else:
        raise TypeError(f"Unsupported attribute type: {type_}")
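To illustrate the TYPE_PROTOS branch, this is roughly the proto it produces for two hypothetical entries, a float [1, 4] tensor and an int64 [4] tensor (plain onnx protobuf API, illustrative names):

import onnx

attr = onnx.AttributeProto()
attr.name = "io_types"  # hypothetical attribute name
attr.type = onnx.AttributeProto.TYPE_PROTOS
for elem_type, dims in ((onnx.TensorProto.FLOAT, (1, 4)), (onnx.TensorProto.INT64, (4,))):
    tp = attr.type_protos.add()
    tp.tensor_type.elem_type = elem_type  # create the type oneof first...
    for d in dims:
        tp.tensor_type.shape.dim.add().dim_value = d  # ...so the shape targets that leaf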
@@ -1179,10 +1243,11 @@
    value_info_proto.name = from_.name
    if from_.metadata_props:
        _serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props)
    if from_.shape is not None:
        serialize_shape_into(value_info_proto.type, from_.shape)
    if from_.type is not None:
        serialize_type_into(value_info_proto.type, from_.type)
    # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
    if from_.shape is not None:
        serialize_shape_into(value_info_proto.type, from_.shape)
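The reordering matters because the updated serialize_shape_into (below) locates the shape's destination via WhichOneof, so the type oneof must already be populated when the shape is written. A one-line check with the plain onnx protos to show the assumption:

import onnx

print(onnx.TypeProto().WhichOneof("value"))  # None until serialize_type_into sets e.g. tensor_type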
def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
@@ -1205,12 +1270,18 @@
def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
    tensor_type_proto = type_proto.tensor_type
    value_field = type_proto.WhichOneof("value")
    tensor_type = getattr(type_proto, value_field)
    while not isinstance(tensor_type.elem_type, int):
        # Find the leaf type that has the shape field
        type_proto = tensor_type.elem_type
        value_field = type_proto.WhichOneof("value")
        tensor_type = getattr(type_proto, value_field)
    # When from is empty, we still need to set the shape field to an empty list by touching it
    tensor_type_proto.shape.ClearField("dim")
    tensor_type.shape.ClearField("dim")

Review comment (on tensor_type.shape.ClearField("dim")): Check naming

    for i, dim in enumerate(from_):
        denotation = from_.get_denotation(i)
        serialize_dimension_into(tensor_type_proto.shape.dim.add(), dim, denotation)
        serialize_dimension_into(tensor_type.shape.dim.add(), dim, denotation)
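With the leaf walk, shapes end up on the innermost tensor type even for nested types. A standalone sketch of the same traversal (not the PR's code) over a hypothetical Sequence(Tensor(float)) type:

import onnx

type_proto = onnx.TypeProto()
type_proto.sequence_type.elem_type.tensor_type.elem_type = onnx.TensorProto.FLOAT

value_field = type_proto.WhichOneof("value")       # "sequence_type"
leaf = getattr(type_proto, value_field)
while not isinstance(leaf.elem_type, int):         # elem_type is a sub-TypeProto until the leaf
    type_proto = leaf.elem_type
    value_field = type_proto.WhichOneof("value")   # "tensor_type" at the leaf
    leaf = getattr(type_proto, value_field)
leaf.shape.dim.add().dim_value = 3                 # the shape lands on the inner tensor type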
def serialize_dimension_into(
@@ -1223,4 +1294,6 @@
    if isinstance(dim, int):
        dim_proto.dim_value = dim
    elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
        dim_proto.dim_param = str(dim.value)
        if dim.value is not None:
            # TODO(justinchuby): None is probably not a valid value for dim_param
            dim_proto.dim_param = str(dim.value)

Review comment: Should we do something when it's None?
Reply: Looks like we are ok doing nothing. The None was from a test case. We should add a pass to check that no dims are None in rewriter-produced values, etc.
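For reference, the three Dimension encodings this function now distinguishes, sketched with the plain onnx protos (the symbolic name "batch" is made up):

import onnx

shape = onnx.TensorShapeProto()
shape.dim.add().dim_value = 3        # static dimension
shape.dim.add().dim_param = "batch"  # symbolic dimension
shape.dim.add()  # unknown dimension: neither field set, which is what a None SymbolicDim value now maps to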
Review comment: Should we show this dependency in code? Like putting the from_.shape handling under the from_.type is not None branch?
Reply: Do you mean lines 1238-1239?
Reply: Yes
Reply: I see what you are saying. Currently we just assume the type to be a type that is an unknown tensor dtype with a known shape, if type is None. I think that's ok?