From aa13c879e67edc6052ebfddb07c5c5a1b543a747 Mon Sep 17 00:00:00 2001
From: lucylq
Date: Tue, 10 Dec 2024 09:42:51 -0800
Subject: [PATCH] [executorch][serialization] Serialize PTD files.

Introduce top-level serialization file that calls:
- serialize_pte_binary for PTE file
- FlatTensor.serialize_tensors for PTD files.

Differential Revision: [D66523267](https://our.internmc.facebook.com/intern/diff/D66523267/)

[ghstack-poisoned]
---
Review note: the blob ids on the `index` lines were not regenerated after the
review fixes to _serialize.py and _program.py.

 exir/_serialize/TARGETS                      |  1 +
 exir/_serialize/_serialize.py                | 90 ++++++++++++++++++++
 exir/program/TARGETS                         |  1 +
 exir/program/_program.py                     | 41 ++++++---
 extension/flat_tensor/test/test_serialize.py |  2 +-
 5 files changed, 120 insertions(+), 15 deletions(-)
 create mode 100644 exir/_serialize/_serialize.py

diff --git a/exir/_serialize/TARGETS b/exir/_serialize/TARGETS
index 27b2a7d4c4..1e37c788fd 100644
--- a/exir/_serialize/TARGETS
+++ b/exir/_serialize/TARGETS
@@ -33,6 +33,7 @@ runtime.python_library(
         "_dataclass.py",
         "_flatbuffer.py",
         "_program.py",
+        "_serialize.py",
         "utils.py",
         "data_serializer.py",
     ],
diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py
new file mode 100644
index 0000000000..e1d2ee52c6
--- /dev/null
+++ b/exir/_serialize/_serialize.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+from typing import Dict, Tuple
+
+from executorch.exir._serialize import _serialize_pte_binary
+
+from executorch.exir._serialize._cord import Cord
+from executorch.exir._serialize.data_serializer import (
+    DataSerializer,
+    SerializationInfo,
+    TensorLayout,
+)
+
+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.emit import EmitterOutput
+from executorch.exir.schema import Tensor, TensorDataLocation
+
+
+def serialize(
+    emitter_output: EmitterOutput,
+    config: ExecutorchBackendConfig,
+    data_serializer: DataSerializer,
+) -> Tuple[Cord, Dict[str, Cord]]:
+    """Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files.
+
+    Args:
+        emitter_output: Emitter result providing the program, its mutable data,
+            and the external constant buffer and map.
+        config: Backend config whose segment and alignment options are
+            forwarded to the PTE serializer.
+        data_serializer: Serializer used to produce each PTD file.
+
+    Returns:
+        A tuple of the PTE contents and a dict mapping each key of the external
+        constant map to its serialized PTD contents. The dict is empty when no
+        tensor in the program is tagged as external.
+    """
+    # Serialize PTE file.
+    pte: Cord = _serialize_pte_binary(
+        program=emitter_output.program,
+        mutable_data=emitter_output.mutable_data,
+        extract_delegate_segments=config.extract_delegate_segments,
+        segment_alignment=config.segment_alignment,
+        constant_tensor_alignment=config.constant_tensor_alignment,
+        delegate_alignment=config.delegate_alignment,
+    )
+
+    # Serialize PTD files.
+    ptd_files: Dict[str, Cord] = {}
+
+    # Find all external tensors and organize into {fqn: TensorLayout}.
+    fqn_to_tensor_layout: Dict[str, TensorLayout] = {}
+    for plan in emitter_output.program.execution_plan:
+        for evalue in plan.values:
+            if isinstance(evalue.val, Tensor):
+                tensor = evalue.val
+                if (
+                    tensor.extra_tensor_info is not None
+                    and tensor.extra_tensor_info.fully_qualified_name is not None
+                    and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
+                ):
+                    # pyre-ignore Undefined attribute [16]: Optional type has no attribute `fully_qualified_name`.
+                    fqn_to_tensor_layout[
+                        tensor.extra_tensor_info.fully_qualified_name
+                    ] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
+    if len(fqn_to_tensor_layout) > 0:
+        assert emitter_output.external_constant_map is not None
+        for (
+            file,
+            fqn_map,
+        ) in (
+            # pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`.
+            emitter_output.external_constant_map.items()
+        ):
+            ptd_files[file] = data_serializer.serialize_tensors(
+                SerializationInfo(
+                    emitter_output.external_constant_buffer,
+                    fqn_map,
+                    fqn_to_tensor_layout,
+                )
+            )
+
+    return pte, ptd_files
diff --git a/exir/program/TARGETS b/exir/program/TARGETS
index 674d7baa35..33e417e732 100644
--- a/exir/program/TARGETS
+++ b/exir/program/TARGETS
@@ -44,6 +44,7 @@ python_library(
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/passes:weights_to_outputs_pass",
         "//executorch/exir/verification:verifier",
+        "//executorch/extension/flat_tensor/serialize:serialize",
     ] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
 )
 
diff --git a/exir/program/_program.py b/exir/program/_program.py
index fd1d0aca3d..f899e986d3 100644
--- a/exir/program/_program.py
+++ b/exir/program/_program.py
@@ -13,8 +13,9 @@
 
 import torch
 import torch._export
-from executorch.exir._serialize import _serialize_pte_binary
 from executorch.exir._serialize._cord import Cord
+from executorch.exir._serialize._serialize import serialize
+from executorch.exir._serialize.data_serializer import DataSerializer
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.partitioner import Partitioner
@@ -56,6 +57,7 @@
     EXIREdgeDialectVerifier,
     get_aten_verifier,
 )
+from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
 from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
 from torch.export import ExportedProgram
 from torch.export._remove_auto_functionalized_pass import (
@@ -494,6 +496,7 @@ def __init__(
         )
         self.exported_program = exir_exported_program.exported_program
         self._pte_data: Optional[Cord] = None
+        self._data_files: Optional[Dict[str, Cord]] = None
         self._buffer: Optional[bytes] = None
         self._emitter_output: Optional[EmitterOutput] = None
         self._emit_stacktrace: bool = emit_stacktrace
@@ -501,16 +504,24 @@ def __init__(
         self._segment_alignment: int = segment_alignment
         self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
         self._delegate_alignment: Optional[int] = delegate_alignment
+        self._data_serializer: DataSerializer = FlatTensorSerializer()
 
     def _get_pte_data(self) -> Cord:
         if self._pte_data is None:
-            self._pte_data = _serialize_pte_binary(
-                program=self.program,
-                extract_delegate_segments=self._extract_delegate_segments,
-                segment_alignment=self._segment_alignment,
-                constant_tensor_alignment=self._constant_tensor_alignment,
-                delegate_alignment=self._delegate_alignment,
+            # NOTE(review): assumes `_emitter_output` was already populated - confirm.
+            assert self._emitter_output is not None
+            self._pte_data, self._data_files = serialize(
+                self._emitter_output,
+                # Forward this instance's settings; a default config would drop them.
+                ExecutorchBackendConfig(
+                    extract_delegate_segments=self._extract_delegate_segments,
+                    segment_alignment=self._segment_alignment,
+                    constant_tensor_alignment=self._constant_tensor_alignment,
+                    delegate_alignment=self._delegate_alignment,
+                ),
+                self._data_serializer,
             )
+        assert self._pte_data is not None
         return self._pte_data
 
     @property
@@ -1443,14 +1454,11 @@ def __init__(
             self._config_methods,
         )
 
+        self._data_serializer = FlatTensorSerializer()
+
         # Serialize emitter output, ready to be written to a file.
-        self._pte_data: Cord = _serialize_pte_binary(
-            program=self._emitter_output.program,
-            mutable_data=self._emitter_output.mutable_data,
-            extract_delegate_segments=backend_config.extract_delegate_segments,
-            segment_alignment=backend_config.segment_alignment,
-            constant_tensor_alignment=backend_config.constant_tensor_alignment,
-            delegate_alignment=backend_config.delegate_alignment,
+        self._pte_data, self._data_files = serialize(
+            self._emitter_output, backend_config, self._data_serializer
         )
         self._buffer: Optional[bytes] = None
 
@@ -1532,3 +1540,8 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
         reducing the peak memory usage.
         """
         self._pte_data.write_to_file(open_file)
+
+        for filename, cord in self._data_files.items():
+            filename = filename + ".ptd"
+            with open(filename, "wb") as file:
+                cord.write_to_file(file)
diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py
index e5e339e0f3..73ad35b2fb 100644
--- a/extension/flat_tensor/test/test_serialize.py
+++ b/extension/flat_tensor/test/test_serialize.py
@@ -25,7 +25,7 @@
 )
 
 # Test artifacts
-TEST_TENSOR_BUFFER = [b"\x11"*4, b"\x22"*32]
+TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
 TEST_TENSOR_MAP = {
     "fqn1": 0,
     "fqn2": 0,