Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[executorch][emitter] Emit FQNs #7192

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions exir/emit/_emit_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ class EmitterOutput:

mutable_data: Optional[List[Buffer]]

# Constants are optionally stored in external files.
# Aggregate unique external constants into one buffer.
external_constant_buffer: List[bytes]
# Each constant_tag groups a set of constants together.
# {constant_tag: {fqn: index into external_constant_buffer}}
external_constant_map: Optional[Dict[str, Dict[str, int]]]


def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
gm = exported_program.graph_module
Expand Down Expand Up @@ -199,4 +206,6 @@ def emit_program(
if len(program_state.mutable_buffer) > 1
else None
),
external_constant_buffer=program_state.external_constant_buffer,
external_constant_map=program_state.external_constant_map,
)
73 changes: 66 additions & 7 deletions exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
DoubleList,
EValue,
ExecutionPlan,
ExtraTensorInfo,
FreeCall,
Instruction,
Int,
Expand All @@ -76,6 +77,7 @@
ScalarType,
String,
Tensor,
TensorDataLocation,
TensorList,
TensorShapeDynamism,
)
Expand Down Expand Up @@ -121,6 +123,14 @@ class _ProgramState:
# and should be copied to Program.backend_delegate_data.
backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)

# Constants are optionally stored in external files.
# Aggregate unique external constants into one buffer.
external_constant_buffer: List[bytes] = field(default_factory=list)
external_constant_hash: Dict[str, int] = field(default_factory=dict)
# Each constant_tag groups a set of constants together.
# {constant_tag: {fqn: index into external_constant_buffer}}
external_constant_map: Dict[str, Dict[str, int]] = field(default_factory=dict)


@dataclass
class _EmitterState:
Expand Down Expand Up @@ -363,7 +373,8 @@ def _save_new_const_tensor(
spec: TensorSpec,
buffer_data: bytes,
hashed: str,
allocation_info: Optional[AllocationDetails],
allocation_info: Optional[AllocationDetails] = None,
constant_tag: Optional[str] = None,
) -> int:
"""Saves a new constant tensor to the constant buffer and returns the buffer idx"""

Expand All @@ -372,17 +383,45 @@ def _save_new_const_tensor(

# Update buffer_idx to point to the end of the list where we are adding the new buffer.
buffer = Buffer(storage=buffer_data)

# Tensor is mutable with initial state.
if allocation_info:
buffer_idx = len(self.program_state.mutable_buffer)
self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx
self.program_state.mutable_buffer.append(buffer)

# Tensor is constant.
else:
buffer_idx = len(self.program_state.constant_buffer)
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
self.program_state.constant_buffer.append(buffer)
# Tensor is stored outside of the PTE file.
if (
spec.extra_tensor_info is not None
and spec.extra_tensor_info.fully_qualified_name is not None
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
):
assert (
constant_tag is not None
), "Constant tag is not set for external tensor"

buffer_idx = len(self.program_state.external_constant_buffer)
self.program_state.external_constant_hash[hashed] = buffer_idx
self.program_state.external_constant_buffer.append(buffer_data)
if constant_tag not in self.program_state.external_constant_map:
self.program_state.external_constant_map[constant_tag] = {}
self.program_state.external_constant_map[constant_tag][
spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
] = buffer_idx

# Tensor is stored in the PTE file.
else:
buffer_idx = len(self.program_state.constant_buffer)
self.program_state.cached_spec_hash_values[hashed] = buffer_idx
self.program_state.constant_buffer.append(buffer)

return buffer_idx

def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
def _tensor_spec_to_evalue(
self, spec: TensorSpec, constant_tag: Optional[str] = None
) -> EValue:
"""Constructs an EValue from the given TensorSpec."""

allocation_info = None
Expand Down Expand Up @@ -420,13 +459,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
buffer_idx = self.program_state.cached_spec_mutable_hash_values.get(
hashed, -1
)
elif (
spec.extra_tensor_info is not None
and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
):
buffer_idx = self.program_state.external_constant_hash.get(hashed, -1)
else:
buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)

# Haven't seen this constant before.
if buffer_idx == -1:
buffer_idx = self._save_new_const_tensor(
spec, buffer_data, hashed, allocation_info
spec, buffer_data, hashed, allocation_info, constant_tag
)

if spec.const and spec.nbytes() != len(buffer_data):
Expand Down Expand Up @@ -1557,11 +1601,26 @@ def placeholder(
https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder
"""
spec = self.node.meta["spec"]
constant_tag = self.node.meta.get("constant_tag", None)
is_user_input = True

if isinstance(target, str) and isinstance(spec, TensorSpec):
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)

# If the placeholder has a constant_tag, it is external to the PTE file
# and requires a fqn and location=TensorDataLocation.EXTERNAL
if constant_tag is not None:
assert (
fqn is not None
), "constant tagged tensors require a fully qualified name"
if spec.extra_tensor_info is None:
spec.extra_tensor_info = ExtraTensorInfo(
fully_qualified_name=fqn, location=TensorDataLocation.EXTERNAL
)
else:
spec.extra_tensor_info.fully_qualified_name = fqn
spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL

# From the fqn find the corresponding tensor
real_tensor = None
if fqn in self.exported_program.state_dict:
Expand Down Expand Up @@ -1599,7 +1658,7 @@ def placeholder(
spec.const = not (is_user_input or is_mutable_buffer)

evalue = (
self._tensor_spec_to_evalue(spec)
self._tensor_spec_to_evalue(spec, constant_tag)
if isinstance(spec, TensorSpec)
else self._constant_to_evalue(spec, None)
)
Expand Down