diff --git a/export.py b/export.py
index dbcc6e87c..f8de9e48b 100644
--- a/export.py
+++ b/export.py
@@ -23,7 +23,7 @@
 from model import Transformer
 from generate import _load_model, decode_one_token
-from quantize import quantize_model
+from quantize import quantize_model, name_to_dtype
 from torch._export import capture_pre_autograd_graph
 
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -77,7 +77,7 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
     # dtype:
     if args.dtype:
         model.to(dtype=name_to_dtype(args.dtype))
-    
+
     model = model_wrapper(model, device=device)
 
     output_pte_path = args.output_pte_path