add packaging to aoti
angelayi committed Apr 12, 2024
1 parent 479b24b commit 7c34043
Showing 5 changed files with 241 additions and 9 deletions.
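At a high level, the new _package_aoti helpers compile an exported program with AOT Inductor and bundle the resulting artifacts (.so, .cpp, .cubin files, extern-kernel JSON) into a single PT2 archive that can later be loaded back into a callable. A minimal sketch of the intended flow, assuming a model that works with torch.export (the MyModule class, input shapes, and the model.pt2 path below are illustrative, not part of this commit):

import torch
from _package_aoti import aoti_compile, aoti_load

# Hypothetical module and example inputs, for illustration only.
class MyModule(torch.nn.Module):
    def forward(self, x):
        return x * 2

example_args = (torch.randn(4),)
ep = torch.export.export(MyModule(), example_args)

# Compile with AOT Inductor and package everything into model.pt2.
package_path = aoti_compile(
    ep, example_args, options={"aot_inductor.output_path": "model.pt2"}
)

# Load the package back as a callable and run it.
optimized = aoti_load(package_path, device="cpu")
out = optimized(torch.randn(4))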
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
__pycache__/
192 changes: 192 additions & 0 deletions _package_aoti.py
@@ -0,0 +1,192 @@
import glob
import os
import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch._inductor
import torch.utils._pytree as pytree
from torch.export._tree_utils import reorder_kwargs
from torch.export import ExportedProgram
from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact


from _pt2_archive_constants import (
AOTINDUCTOR_DIR,
ARCHIVE_ROOT_NAME,
CONSTANTS_DIR,
MODELS_FILENAME_FORMAT,
SAMPLE_INPUTS_DIR,
WEIGHTS_DIR,
)


ARCHIVE_VERSION = 0

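# Thin wrapper around torch._C.PyTorchFileWriter: stamps the archive with the
# "pt2" format marker and exposes byte/string/file writers.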
class PT2ArchiveWriter:
def __init__(self, archive_path: str):
self.archive_file = torch._C.PyTorchFileWriter(archive_path)
self.archive_file.set_min_version(ARCHIVE_VERSION)
self.write_string("archive_format", "pt2")

def __enter__(self):
return self

def __exit__(self, *args):
self.close()

def write_bytes(self, name: str, data: bytes) -> None:
assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
self.archive_file.write_record(name, data, len(data))

def write_string(self, name: str, data: str) -> None:
assert isinstance(data, str), f"Expected string but got {type(data)}"
data_bytes = data.encode()
self.write_bytes(name, data_bytes)

def write_file(self, name: str, file_path: str) -> None:
"""
Copy a file into the archive.
name: The destination file inside the archive.
file_path: The source file on disk.
"""
assert os.path.isfile(file_path), f"{file_path} is not a valid file path"

with open(file_path, "rb") as f:
file_bytes = f.read()
self.write_bytes(name, file_bytes)

def close(self) -> None:
self.archive_file.write_end_of_file()


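# Read-side counterpart to PT2ArchiveWriter; validates the archive_format
# record before exposing the stored files.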
class PT2ArchiveReader:
def __init__(self, archive_path: str):
self.archive_file = torch._C.PyTorchFileReader(archive_path)
assert self.read_string("archive_format") == "pt2", "Invalid archive format"

def __enter__(self):
return self

def __exit__(self, *args):
# torch._C.PyTorchFileReader doesn't have a close method
pass

def read_bytes(self, name: str) -> bytes:
return self.archive_file.get_record(name)

def read_string(self, name: str) -> str:
data = self.read_bytes(name)
return data.decode()

def get_file_names(self) -> List[str]:
return self.archive_file.get_all_records()


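# Serialize the ExportedProgram and write its pieces (program, weights,
# constants, sample inputs) to the standard locations from _pt2_archive_constants.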
def _package_exported_program(
archive_writer: PT2ArchiveWriter, exported_program: ExportedProgram
) -> None:
exported_artifact: SerializedArtifact = serialize(exported_program)
archive_writer.write_bytes(MODELS_FILENAME_FORMAT.format("model"), exported_artifact.exported_program)
archive_writer.write_bytes(os.path.join(WEIGHTS_DIR, "weights.pt"), exported_artifact.state_dict)
archive_writer.write_bytes(os.path.join(CONSTANTS_DIR, "constants.pt"), exported_artifact.constants)
archive_writer.write_bytes(os.path.join(SAMPLE_INPUTS_DIR, "example_inputs.pt"), exported_artifact.example_inputs)


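# Collect the AOT Inductor outputs that live next to the compiled .so (the
# generated .cpp, any .cubin kernels, and the extern kernel nodes .json, if
# present) and copy them into data/aotinductor/ inside the archive.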
def _package_aoti_files(archive_writer: PT2ArchiveWriter, so_path: str):
cpp_file_path = so_path[:-3] + ".cpp"
extern_nodes_file_path = so_path[:-3] + ".json"
work_dir = pathlib.Path(so_path).parent
cubin_file_paths = glob.glob(f"{work_dir}/*.cubin")

package_files = [so_path, cpp_file_path]
package_files.extend(cubin_file_paths)

if os.path.isfile(extern_nodes_file_path):
package_files.append(extern_nodes_file_path)

for path in package_files:
filename = os.path.basename(path)
archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path)


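# Inverse of _package_exported_program: read the serialized pieces back out of
# the archive and deserialize them into an ExportedProgram.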
def _extract_exported_program(archive_reader: PT2ArchiveReader) -> ExportedProgram:
exported_program_bytes = archive_reader.read_bytes(MODELS_FILENAME_FORMAT.format("model"))
state_dict_bytes = archive_reader.read_bytes(os.path.join(WEIGHTS_DIR, "weights.pt"))
constants_bytes = archive_reader.read_bytes(os.path.join(CONSTANTS_DIR, "constants.pt"))
example_inputs_bytes = archive_reader.read_bytes(os.path.join(SAMPLE_INPUTS_DIR, "example_inputs.pt"))

artifact: SerializedArtifact = SerializedArtifact(
exported_program_bytes,
state_dict_bytes,
constants_bytes,
example_inputs_bytes,
)

deserialized_exported_program = deserialize(artifact)
return deserialized_exported_program


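# Unpack the AOTInductor files into a temporary directory, locate the single
# .so, and wrap it in an AOTIModelContainerRunner for the requested device.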
def _extract_so(archive_reader: PT2ArchiveReader, device: str) -> Callable:
tmp_output_dir = pathlib.Path("/tmp/aotinductor_loaded_model")
tmp_output_dir.mkdir(exist_ok=True)

file_names = archive_reader.get_file_names()
aoti_files = [file for file in file_names if file.startswith(AOTINDUCTOR_DIR)]

so_path = None
for file in aoti_files:
filename = os.path.basename(file)
with open(tmp_output_dir / filename, 'wb') as f:
f.write(archive_reader.read_bytes(file))
if file.endswith('.so'):
assert so_path is None
so_path = tmp_output_dir / filename
assert so_path is not None
so_path = str(so_path)

if device == "cpu":
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
elif device == "cuda" or device.startswith("cuda:"):
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
else:
raise RuntimeError("Unsupported device " + device)

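    # The runner works on flat lists of tensors; use the call spec recorded at
    # compile time to flatten (args, kwargs) on the way in and rebuild the
    # original output structure on the way out.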
def optimized(*args, **kwargs):
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)

return optimized


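# Compile the exported program with AOT Inductor, then package the resulting
# artifacts into a PT2 archive at options["aot_inductor.output_path"].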
def aoti_compile(
exported_program: ExportedProgram,
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
*,
options: Optional[Dict[str, Any]] = None,
):
    assert options is not None and "aot_inductor.output_path" in options, (
        "aoti_compile requires options['aot_inductor.output_path'] to be set"
    )
    archive_path = options["aot_inductor.output_path"]
    options["aot_inductor.output_path"] = ""

so_path = torch._inductor.aot_compile(
exported_program.module(), args, kwargs, options=options
)

with PT2ArchiveWriter(archive_path) as archive_writer:
# _package_exported_program(archive_writer, exported_program)
_package_aoti_files(archive_writer, so_path)

return archive_path


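# Open a PT2 archive produced by aoti_compile and return a callable backed by
# the packaged AOTInductor .so.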
def aoti_load(path: str, device: str):
with PT2ArchiveReader(path) as archive_reader:
# exported_program = _extract_exported_program(archive_reader)
optimized = _extract_so(archive_reader, device)

return optimized
36 changes: 36 additions & 0 deletions _pt2_archive_constants.py
@@ -0,0 +1,36 @@
# This file codifies the PT2 Inference Archive Spec
# https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit?usp=sharing

# Naming convention
# *_DIR: path to a folder, e.g. "data/aotinductor/"
# *_PATH: absolute path to a file, e.g. "models/merge.json"
# *_FORMAT: naming format of a file, e.g. "models/{}.json"

ARCHIVE_ROOT_NAME: str = "package"

# Archive format
ARCHIVE_FORMAT_PATH: str = "archive_format"

# Model definitions
MODELS_DIR: str = "models/"
MODELS_FILENAME_FORMAT: str = "models/{}.json"  # {model_name}

# AOTInductor artifacts
AOTINDUCTOR_DIR: str = "data/aotinductor/"

# weights, including parameters and buffers
WEIGHTS_DIR: str = "data/weights/"
WEIGHT_FILENAME_PREFIX: str = "weight_"

# constants, including tensor_constants, non-persistent buffers and script objects
CONSTANTS_DIR: str = "data/constants/"
TENSOR_CONSTANT_FILENAME_PREFIX: str = "tensor_"
CUSTOM_OBJ_FILENAME_PREFIX: str = "custom_obj_"

# sample inputs
SAMPLE_INPUTS_DIR: str = "data/sample_inputs/"
SAMPLE_INPUTS_FILENAME_FORMAT: str = "data/sample_inputs/{}.pt"  # {model_name}

# extra folder
EXTRA_DIR: str = "extra/"
MODULE_INFO_PATH: str = "extra/module_info.json"
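These paths show up as record names inside the archive. A small sketch of inspecting a package written by aoti_compile (the model.pt2 filename is illustrative):

from _package_aoti import PT2ArchiveReader

with PT2ArchiveReader("model.pt2") as reader:
    # Record names follow the layout above, e.g. "archive_format" and
    # "data/aotinductor/<model>.so", "data/aotinductor/<model>.cpp", ...
    for name in reader.get_file_names():
        print(name)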
16 changes: 9 additions & 7 deletions export_aoti.py
@@ -19,6 +19,8 @@

from model import Transformer

+from _package_aoti import aoti_compile

default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'


@@ -47,11 +49,11 @@ def export_model(model: nn.Module, device, output_path, args=None):
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}

-    so = torch._export.aot_compile(
-        model,
-        args=input,
-        options={"aot_inductor.output_path": output_path},
-        dynamic_shapes=dynamic_shapes,
+    ep = torch.export.export(
+        model, args=input, dynamic_shapes=dynamic_shapes,
+    )
+    package_path = aoti_compile(
+        ep, input, options={"aot_inductor.output_path": output_path}
    )
-    print(f"The generated DSO model can be found at: {so}")
-    return so
+    print(f"The generated PT2 model can be found at: {package_path}")
+    return package_path
5 changes: 3 additions & 2 deletions generate.py
@@ -362,7 +362,8 @@ def main(
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
-            model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
+            from _package_aoti import aoti_load
+            model.forward = aoti_load(str(dso_path.absolute()), device)
except:
raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
elif pte_path:
@@ -387,7 +388,7 @@ def main(
# dtype:
if model_dtype:
model.to(dtype=model_dtype)

if is_speculative:
draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
else:
