diff --git a/export_aoti.py b/export_aoti.py index 9e2e6ab22..1d6e4990c 100644 --- a/export_aoti.py +++ b/export_aoti.py @@ -46,7 +46,7 @@ def export_model( # with torch.device(device): # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) - device= + device="cpu" input = ( torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),