Skip to content

Commit

Permalink
[executorch][serialization] Serialize PTD files.
Browse files Browse the repository at this point in the history
Introduce top-level serialization file that calls:
- serialize_pte_binary for PTE file
- FlatTensor.serialize_tensors for PTD files.

Differential Revision: [D66523267](https://our.internmc.facebook.com/intern/diff/D66523267/)

[ghstack-poisoned]
  • Loading branch information
lucylq committed Dec 10, 2024
1 parent 4ec0731 commit aa13c87
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 15 deletions.
1 change: 1 addition & 0 deletions exir/_serialize/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ runtime.python_library(
"_dataclass.py",
"_flatbuffer.py",
"_program.py",
"_serialize.py",
"utils.py",
"data_serializer.py",
],
Expand Down
77 changes: 77 additions & 0 deletions exir/_serialize/_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from typing import Dict, Tuple

from executorch.exir._serialize import _serialize_pte_binary

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize.data_serializer import (
DataSerializer,
SerializationInfo,
TensorLayout,
)

from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.emit import EmitterOutput
from executorch.exir.schema import Tensor, TensorDataLocation


def serialize(
emitter_output: EmitterOutput,
config: ExecutorchBackendConfig,
data_serializer: DataSerializer,
) -> Tuple[Cord, Dict[str, Cord]]:
"""Serialize the output from Emitter into ExecuTorch artifacts; PTE and PTD files."""
# Serialize PTE file.
pte: Cord = _serialize_pte_binary(
program=emitter_output.program,
mutable_data=emitter_output.mutable_data,
extract_delegate_segments=config.extract_delegate_segments,
segment_alignment=config.segment_alignment,
constant_tensor_alignment=config.constant_tensor_alignment,
delegate_alignment=config.delegate_alignment,
)

# Serialize PTD files.
ptd_files: Dict[str, Cord] = {}

# Find all external tensors and organize into {fqn: Tensor}.
fqn_to_tensor_layout: Dict[str, TensorLayout] = {}
for plan in emitter_output.program.execution_plan:
for evalue in plan.values:
if isinstance(evalue.val, Tensor):
tensor = evalue.val
if (
tensor.extra_tensor_info is not None
and tensor.extra_tensor_info.fully_qualified_name is not None
and tensor.extra_tensor_info.location is TensorDataLocation.EXTERNAL
):
# pyre-ignore Undefined attribute [16]: Optional type has no attribute `fully_qualified_name`.
fqn_to_tensor_layout[
tensor.extra_tensor_info.fully_qualified_name
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
if len(fqn_to_tensor_layout) > 0:
assert emitter_output.external_constant_map is not None
for (
file,
fqn_map,
) in (
# pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`.
emitter_output.external_constant_map.items()
):
ptd_files[file] = data_serializer.serialize_tensors(
SerializationInfo(
emitter_output.external_constant_buffer,
fqn_map,
fqn_to_tensor_layout,
)
)

return pte, ptd_files
1 change: 1 addition & 0 deletions exir/program/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ python_library(
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/passes:weights_to_outputs_pass",
"//executorch/exir/verification:verifier",
"//executorch/extension/flat_tensor/serialize:serialize",
] + (["//executorch/exir/program/fb:logger"] if not runtime.is_oss else [])
)

Expand Down
32 changes: 18 additions & 14 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@

import torch
import torch._export
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._serialize import serialize
from executorch.exir._serialize.data_serializer import DataSerializer
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner
Expand Down Expand Up @@ -56,6 +57,7 @@
EXIREdgeDialectVerifier,
get_aten_verifier,
)
from executorch.extension.flat_tensor.serialize.serialize import FlatTensorSerializer
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
Expand Down Expand Up @@ -494,23 +496,23 @@ def __init__(
)
self.exported_program = exir_exported_program.exported_program
self._pte_data: Optional[Cord] = None
self._data_files: Optional[Dict[str, Cord]] = None
self._buffer: Optional[bytes] = None
self._emitter_output: Optional[EmitterOutput] = None
self._emit_stacktrace: bool = emit_stacktrace
self._extract_delegate_segments: bool = extract_delegate_segments
self._segment_alignment: int = segment_alignment
self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment
self._delegate_alignment: Optional[int] = delegate_alignment
self._data_serializer: DataSerializer = FlatTensorSerializer()

def _get_pte_data(self) -> Cord:
if self._pte_data is None:
self._pte_data = _serialize_pte_binary(
program=self.program,
extract_delegate_segments=self._extract_delegate_segments,
segment_alignment=self._segment_alignment,
constant_tensor_alignment=self._constant_tensor_alignment,
delegate_alignment=self._delegate_alignment,
assert self._emitter_output is not None
self._pte_data, self._data_files = serialize(
self._emitter_output, ExecutorchBackendConfig(), self._data_serializer
)
assert self._pte_data is not None
return self._pte_data

@property
Expand Down Expand Up @@ -1443,14 +1445,11 @@ def __init__(
self._config_methods,
)

self._data_serializer = FlatTensorSerializer()

# Serialize emitter output, ready to be written to a file.
self._pte_data: Cord = _serialize_pte_binary(
program=self._emitter_output.program,
mutable_data=self._emitter_output.mutable_data,
extract_delegate_segments=backend_config.extract_delegate_segments,
segment_alignment=backend_config.segment_alignment,
constant_tensor_alignment=backend_config.constant_tensor_alignment,
delegate_alignment=backend_config.delegate_alignment,
self._pte_data, self._data_files = serialize(
self._emitter_output, ExecutorchBackendConfig(), self._data_serializer
)
self._buffer: Optional[bytes] = None

Expand Down Expand Up @@ -1532,3 +1531,8 @@ def write_to_file(self, open_file: io.BufferedIOBase) -> None:
reducing the peak memory usage.
"""
self._pte_data.write_to_file(open_file)

for filename, cord in self._data_files.items():
filename = filename + ".ptd"
with open(filename, "wb") as file:
cord.write_to_file(file)
2 changes: 1 addition & 1 deletion extension/flat_tensor/test/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

# Test artifacts
TEST_TENSOR_BUFFER = [b"\x11"*4, b"\x22"*32]
TEST_TENSOR_BUFFER = [b"\x11" * 4, b"\x22" * 32]
TEST_TENSOR_MAP = {
"fqn1": 0,
"fqn2": 0,
Expand Down

0 comments on commit aa13c87

Please sign in to comment.