diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py index 9c8c9dfd06..f9571143a1 100644 --- a/exir/emit/_emit_program.py +++ b/exir/emit/_emit_program.py @@ -47,6 +47,13 @@ class EmitterOutput: mutable_data: Optional[List[Buffer]] + # Constants are optionally stored in external files. + # Aggregate unique external constants into one buffer. + external_constant_buffer: List[bytes] + # Each constant_tag groups a set of constants together. + # {constant_tag: {fqn: index into external_constant_buffer}} + external_constant_map: Optional[Dict[str, Dict[str, int]]] + def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule: gm = exported_program.graph_module @@ -199,4 +206,6 @@ def emit_program( if len(program_state.mutable_buffer) > 1 else None ), + external_constant_buffer=program_state.external_constant_buffer, + external_constant_map=program_state.external_constant_map, ) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 2d6c066cce..88247d2a27 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -63,6 +63,7 @@ DoubleList, EValue, ExecutionPlan, + ExtraTensorInfo, FreeCall, Instruction, Int, @@ -76,6 +77,7 @@ ScalarType, String, Tensor, + TensorDataLocation, TensorList, TensorShapeDynamism, ) @@ -121,6 +123,14 @@ class _ProgramState: # and should be copied to Program.backend_delegate_data. backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list) + # Constants are optionally stored in external files. + # Aggregate unique external constants into one buffer. + external_constant_buffer: List[bytes] = field(default_factory=list) + external_constant_hash: Dict[str, int] = field(default_factory=dict) + # Each constant_tag groups a set of constants together. + # {constant_tag: {fqn: index into external_constant_buffer}} + external_constant_map: Dict[str, Dict[str, int]] = field(default_factory=dict) + @dataclass class _EmitterState: @@ -363,7 +373,8 @@ def _save_new_const_tensor( spec: TensorSpec, buffer_data: bytes, hashed: str, - allocation_info: Optional[AllocationDetails], + allocation_info: Optional[AllocationDetails] = None, + constant_tag: Optional[str] = None, ) -> int: """Saves a new constant tensor to the constant buffer and returns the buffer idx""" @@ -372,17 +383,45 @@ def _save_new_const_tensor( # Update buffer_idx to point to the end of the list where we are adding the new buffer. buffer = Buffer(storage=buffer_data) + + # Tensor is mutable with initial state. if allocation_info: buffer_idx = len(self.program_state.mutable_buffer) self.program_state.cached_spec_mutable_hash_values[hashed] = buffer_idx self.program_state.mutable_buffer.append(buffer) + + # Tensor is constant. else: - buffer_idx = len(self.program_state.constant_buffer) - self.program_state.cached_spec_hash_values[hashed] = buffer_idx - self.program_state.constant_buffer.append(buffer) + # Tensor is stored outside of the PTE file. + if ( + spec.extra_tensor_info is not None + and spec.extra_tensor_info.fully_qualified_name is not None + and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL + ): + assert ( + constant_tag is not None + ), "Constant tag is not set for external tensor" + + buffer_idx = len(self.program_state.external_constant_buffer) + self.program_state.external_constant_hash[hashed] = buffer_idx + self.program_state.external_constant_buffer.append(buffer_data) + if constant_tag not in self.program_state.external_constant_map: + self.program_state.external_constant_map[constant_tag] = {} + self.program_state.external_constant_map[constant_tag][ + spec.extra_tensor_info.fully_qualified_name # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`. + ] = buffer_idx + + # Tensor is stored in the PTE file. + else: + buffer_idx = len(self.program_state.constant_buffer) + self.program_state.cached_spec_hash_values[hashed] = buffer_idx + self.program_state.constant_buffer.append(buffer) + return buffer_idx - def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: + def _tensor_spec_to_evalue( + self, spec: TensorSpec, constant_tag: Optional[str] = None + ) -> EValue: """Constructs an EValue from the given TensorSpec.""" allocation_info = None @@ -420,13 +459,18 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue: buffer_idx = self.program_state.cached_spec_mutable_hash_values.get( hashed, -1 ) + elif ( + spec.extra_tensor_info is not None + and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL + ): + buffer_idx = self.program_state.external_constant_hash.get(hashed, -1) else: buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1) # Haven't seen this constant before. if buffer_idx == -1: buffer_idx = self._save_new_const_tensor( - spec, buffer_data, hashed, allocation_info + spec, buffer_data, hashed, allocation_info, constant_tag ) if spec.const and spec.nbytes() != len(buffer_data): @@ -1557,11 +1601,26 @@ def placeholder( https://pytorch.org/docs/stable/fx.html#torch.fx.Graph.placeholder """ spec = self.node.meta["spec"] + constant_tag = self.node.meta.get("constant_tag", None) is_user_input = True if isinstance(target, str) and isinstance(spec, TensorSpec): fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec) + # If the placeholder has a constant_tag, it is external to the PTE file + # and requires a fqn and location=TensorDataLocation.EXTERNAL + if constant_tag is not None: + assert ( + fqn is not None + ), "constant tagged tensors require a fully qualified name" + if spec.extra_tensor_info is None: + spec.extra_tensor_info = ExtraTensorInfo( + fully_qualified_name=fqn, location=TensorDataLocation.EXTERNAL + ) + else: + spec.extra_tensor_info.fully_qualified_name = fqn + spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL + # From the fqn find the corresponding tensor real_tensor = None if fqn in self.exported_program.state_dict: @@ -1599,7 +1658,7 @@ def placeholder( spec.const = not (is_user_input or is_mutable_buffer) evalue = ( - self._tensor_spec_to_evalue(spec) + self._tensor_spec_to_evalue(spec, constant_tag) if isinstance(spec, TensorSpec) else self._constant_to_evalue(spec, None) )