diff --git a/exir/_serialize/TARGETS b/exir/_serialize/TARGETS index 49419a4159..259ea615c5 100644 --- a/exir/_serialize/TARGETS +++ b/exir/_serialize/TARGETS @@ -33,6 +33,7 @@ runtime.python_library( "_dataclass.py", "_flatbuffer.py", "_program.py", + "utils.py", ], resources = { "//executorch/schema:program.fbs": "program.fbs", diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index 00a3d4700f..4001cc487f 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -11,7 +11,7 @@ import re from dataclasses import dataclass -from typing import ClassVar, List, Literal, Optional, Tuple +from typing import ClassVar, List, Optional, Tuple from executorch.exir._serialize._cord import Cord from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass @@ -21,6 +21,13 @@ _program_json_to_flatbuffer, ) +from executorch.exir._serialize.utils import ( + _aligned_size, + _HEADER_BYTEORDER, + _pad_to, + _padding_required, +) + from executorch.exir.schema import ( BackendDelegateDataReference, BackendDelegateInlineData, @@ -33,12 +40,6 @@ from executorch.exir.tensor import ALIGNMENT -# Byte order of numbers written to program headers. Always little-endian -# regardless of the host system, since all commonly-used modern CPUs are little -# endian. -_HEADER_BYTEORDER: Literal["little"] = "little" - - def _program_to_json(program: Program) -> str: """Returns the JSON representation of the given Program.""" return json.dumps(program, cls=_DataclassEncoder) @@ -50,19 +51,6 @@ def _json_to_program(program_json: bytes) -> Program: return _json_to_dataclass(json.loads(program_json), cls=Program) -def _padding_required(offset: int, alignment: int) -> int: - """Returns the padding required to align `offset` to `alignment`.""" - remainder: int = offset % alignment - if remainder != 0: - return alignment - remainder - return 0 - - -def _aligned_size(input_size: int, alignment: int) -> int: - """Returns input_size padded up to the next whole multiple of alignment.""" - return input_size + _padding_required(input_size, alignment) - - def _insert_flatbuffer_header( flatbuffer_data: bytes, magic_regex: str, header_data: bytes ) -> bytes: @@ -211,25 +199,6 @@ def to_bytes(self) -> bytes: return data -def _pad_to(data: bytes, length: int) -> bytes: - """Returns the input followed by enough zero bytes to become the requested length. - - Args: - data: The data to pad. - length: The length of the returned data. - Returns: - The padded data. - Raises: - ValueError: If the requested length is less than the input length. - """ - if length < len(data): - raise ValueError(f"Data length {len(data)} > padded length {length}") - if length > len(data): - data = data + b"\x00" * (length - len(data)) - assert len(data) == length - return data - - def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]: """Returns the extended header of the program data, if present and valid.""" try: diff --git a/exir/_serialize/utils.py b/exir/_serialize/utils.py new file mode 100644 index 0000000000..fe6a97de2c --- /dev/null +++ b/exir/_serialize/utils.py @@ -0,0 +1,42 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Literal + +# Byte order of numbers written to program headers. Always little-endian +# regardless of the host system, since all commonly-used modern CPUs are little +# endian. +_HEADER_BYTEORDER: Literal["little"] = "little" + + +def _pad_to(data: bytes, length: int) -> bytes: + """Returns the input followed by enough zero bytes to become the requested length. + + Args: + data: The data to pad. + length: The length of the returned data. + Returns: + The padded data. + Raises: + ValueError: If the requested length is less than the input length. + """ + if length < len(data): + raise ValueError(f"Data length {len(data)} > padded length {length}") + if length > len(data): + data = data + b"\x00" * (length - len(data)) + assert len(data) == length + return data + + +def _padding_required(offset: int, alignment: int) -> int: + """Returns the padding required to align `offset` to `alignment`.""" + remainder: int = offset % alignment + if remainder != 0: + return alignment - remainder + return 0 + + +def _aligned_size(input_size: int, alignment: int) -> int: + """Returns input_size padded up to the next whole multiple of alignment.""" + return input_size + _padding_required(input_size, alignment)