consistent wrapper for ET & AOTI
Michael Gschwind committed Apr 7, 2024
1 parent f5ef619 commit 6aeeb4b
Showing 3 changed files with 22 additions and 21 deletions.
export.py (20 additions, 1 deletion)
@@ -37,6 +37,24 @@ def device_sync(device):
    else:
        print(f"device={device} is not yet suppported")


class model_wrapper(nn.Module):
    def __init__(self, model, device):
        super().__init__()

        max_seq_length = 350
        with torch.device(device):
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

        self.model = model
        # init model here if necessary

    def forward(self, idx, input_pos):
        # input_pos: [B, 1]
        assert input_pos.shape[-1] == 1
        logits = self.model(idx, input_pos)
        return logits # sample(logits, **sampling_kwargs)


def main(checkpoint_path, device, quantize = "{ }", args = None):
    assert checkpoint_path.is_file(), checkpoint_path

@@ -53,7 +71,8 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    quantize_model(model, args.quantize)

    model = model_wrapper(model, device=device)

    output_pte_path = args.output_pte_path
    output_dso_path = args.output_dso_path
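
For reference, here is a minimal, hedged sketch of the calling convention the shared wrapper establishes for both export backends. TinyStandIn is a hypothetical stand-in for the repository's Transformer (it only mimics the setup_caches/forward interface the wrapper relies on), and the snippet assumes model_wrapper can be imported from export.py; none of this is part of the commit itself.

import torch
import torch.nn as nn

from export import model_wrapper  # assumption: export.py is importable from the working directory

class TinyStandIn(nn.Module):
    """Hypothetical stand-in for the repository's Transformer."""
    def __init__(self, vocab_size=32000, dim=16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size)

    def setup_caches(self, max_batch_size, max_seq_length):
        # the real model allocates its KV caches here; the stand-in has none
        pass

    def forward(self, idx, input_pos):
        # ignores input_pos; the real model uses it to index its KV caches
        return self.output(self.tok_embeddings(idx))

wrapped = model_wrapper(TinyStandIn(), device="cpu")  # caches are set up once, inside the wrapper
idx = torch.tensor([[42]], dtype=torch.int)           # [B, 1]: exactly one token per step
input_pos = torch.tensor([5], dtype=torch.int)        # position of that token in the sequence
logits = wrapped(idx, input_pos)                      # raw logits; sampling is left to the caller
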
export_aoti.py (2 additions, 2 deletions)

@@ -33,8 +33,8 @@ def device_sync(device):

def export_model(model: nn.Module, device, output_path, args=None):
    max_seq_length = 350
    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
    # with torch.device(device):
    #     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    input = (
        torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
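
Because the wrapper constructed in export.py now owns cache setup, export_model can treat the incoming module as ready to trace. Below is a hedged sketch of that call site under this commit, assuming the torch._export.aot_compile interface of the PyTorch 2.2/2.3 era that this AOTI path builds on; the single-token example inputs and the output-path option are illustrative assumptions, not the file's exact code.

import torch
import torch.nn as nn

def export_model(model: nn.Module, device, output_path, args=None):
    # No setup_caches here any more: the model arrives already wrapped,
    # with its KV caches allocated by model_wrapper in export.py.
    example_inputs = (
        torch.tensor([[1]], dtype=torch.int, device=device),  # one token id, matching the wrapper's one-token-per-step contract
        torch.tensor([0], dtype=torch.int, device=device),    # its position in the sequence
    )
    so_path = torch._export.aot_compile(
        model,
        example_inputs,
        options={"aot_inductor.output_path": str(output_path)},
    )
    return so_path
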
export_et.py (0 additions, 18 deletions)

@@ -70,24 +70,6 @@ def materialze_broadcast_of_rope_freq_cis(
    return module


class model_wrapper(nn.Module):
    def __init__(self, model, device):
        super().__init__()

        max_seq_length = 350
        with torch.device(device):
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

        self.model = model
        # init model here if necessary

    def forward(self, x, input_pos):
        # input_pos: [B, 1]
        assert input_pos.shape[-1] == 1
        logits = self.model(x, input_pos)
        return logits # sample(logits, **sampling_kwargs)


def canonical_path(path):
    return path
