From 6aeeb4be97fa1daad48cbc5e8d04a5d364f10a3b Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Sat, 6 Apr 2024 19:21:51 -0700 Subject: [PATCH] consistent wrapper for ET & AOTI --- export.py | 21 ++++++++++++++++++++- export_aoti.py | 4 ++-- export_et.py | 18 ------------------ 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/export.py b/export.py index b346e7a29..a495f6ead 100644 --- a/export.py +++ b/export.py @@ -37,6 +37,24 @@ def device_sync(device): else: print(f"device={device} is not yet suppported") + +class model_wrapper(nn.Module): + def __init__(self, model, device): + super().__init__() + + max_seq_length = 350 + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + + self.model = model + # init model here if necessary + + def forward(self, idx, input_pos): + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = self.model(idx, input_pos) + return logits # sample(logits, **sampling_kwargs) + def main(checkpoint_path, device, quantize = "{ }", args = None): assert checkpoint_path.is_file(), checkpoint_path @@ -53,7 +71,8 @@ def main(checkpoint_path, device, quantize = "{ }", args = None): print(f"Time to load model: {time.time() - t0:.02f} seconds") quantize_model(model, args.quantize) - + model = model_wrapper(model, device=device) + output_pte_path = args.output_pte_path output_dso_path = args.output_dso_path diff --git a/export_aoti.py b/export_aoti.py index 23789b477..7a5306b5b 100644 --- a/export_aoti.py +++ b/export_aoti.py @@ -33,8 +33,8 @@ def device_sync(device): def export_model(model: nn.Module, device, output_path, args=None): max_seq_length = 350 - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) +# with torch.device(device): +# model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) input = ( torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), diff --git a/export_et.py b/export_et.py index d106787ab..813ea2fc9 100644 --- a/export_et.py +++ b/export_et.py @@ -70,24 +70,6 @@ def materialze_broadcast_of_rope_freq_cis( return module -class model_wrapper(nn.Module): - def __init__(self, model, device): - super().__init__() - - max_seq_length = 350 - with torch.device(device): - model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - - self.model = model - # init model here if necessary - - def forward(self, x, input_pos): - # input_pos: [B, 1] - assert input_pos.shape[-1] == 1 - logits = self.model(x, input_pos) - return logits # sample(logits, **sampling_kwargs) - - def canonical_path(path): return path