From 8f39da5fc6fe2400bb1f4fdf7522c40539acafd4 Mon Sep 17 00:00:00 2001 From: "wujiawei.aml" Date: Fri, 31 May 2024 00:05:27 +0800 Subject: [PATCH 1/2] [torch-frontend] fx export FakeTensor weight as SplatElementsAttr From 1e74803aa48c06d11a10e9a7b7a1e11fc965a53b Mon Sep 17 00:00:00 2001 From: "wujiawei.aml" Date: Fri, 31 May 2024 00:11:10 +0800 Subject: [PATCH 2/2] [torch-frontend] fx export FakeTensor weight as SplatElementsAttr --- .../third_party/patches/fx_importer.patch | 67 +++++++++++++++---- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/frontends/torch-frontend/third_party/patches/fx_importer.patch b/frontends/torch-frontend/third_party/patches/fx_importer.patch index 3fefbc804..5b72dc6b0 100644 --- a/frontends/torch-frontend/third_party/patches/fx_importer.patch +++ b/frontends/torch-frontend/third_party/patches/fx_importer.patch @@ -1,5 +1,5 @@ diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py -index 381f8f9ad..99e75e8bd 100644 +index 381f8f9a..a5e2d32b 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -52,6 +52,10 @@ from torch._subclasses import ( @@ -95,23 +95,64 @@ index 381f8f9ad..99e75e8bd 100644 # check support for bfloat16 assert not ( tensor.dtype == torch.bfloat16 and ml_dtypes is None -@@ -1732,11 +1776,17 @@ def _make_vtensor_literal_op( +@@ -1732,29 +1776,42 @@ def _make_vtensor_literal_op( # detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above) - np_tensor = np.array(tensor.tolist()).astype(npy_dtype) +- # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not +- # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling +- # 0d tensors. +- if np_tensor.size == 1: + + # NOTE: if we torch.export a torch.nn.Module under fake mode, the parameters in the fx.GraphModule will be FakeTensor. -+ # So we specifically handle FakeTensor here by randomly generating a tensor of the same shape and dtype. ++ # So we specifically handle FakeTensor here by creating a splat DenseElementsAttr with value 0. + if isinstance(tensor, TorchFakeTensor): -+ np_tensor = np.random.rand(*list(tensor.shape)).astype(npy_dtype) -+ else: -+ np_tensor = np.array(tensor.tolist()).astype(npy_dtype) - # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not - # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling - # 0d tensors. -- if np_tensor.size == 1: -+ if True: ++ array = np.array([0]).astype(npy_dtype) try: - dtype = tensor.dtype - element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() +- dtype = tensor.dtype +- element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() ++ element_type = TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype]() + except KeyError: +- raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") ++ raise TypeError(f"Could not map Torch dtype {tensor.dtype} to an MLIR type") + elements_attr = DenseElementsAttr.get( +- type=element_type, array=np_tensor, shape=np_tensor.shape ++ array=array, type=element_type, shape=list(tensor.shape) + ) + else: +- bytes_view = np_tensor.view(npy_dtype) +- tensor_type = create_mlir_tensor_type(tensor) +- shape_desc = "_".join([str(d) for d in tensor.shape]) +- blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" +- elements_attr = DenseResourceElementsAttr.get_from_buffer( +- bytes_view, +- blob_name, +- tensor_type, +- ) ++ np_tensor = np.array(tensor.tolist()).astype(npy_dtype) ++ # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not ++ # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling ++ # 0d tensors. ++ if np_tensor.size == 1: ++ try: ++ dtype = tensor.dtype ++ element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() ++ except KeyError: ++ raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") ++ elements_attr = DenseElementsAttr.get( ++ type=element_type, array=np_tensor, shape=np_tensor.shape ++ ) ++ else: ++ bytes_view = np_tensor.view(npy_dtype) ++ tensor_type = create_mlir_tensor_type(tensor) ++ shape_desc = "_".join([str(d) for d in tensor.shape]) ++ blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" ++ elements_attr = DenseResourceElementsAttr.get_from_buffer( ++ bytes_view, ++ blob_name, ++ tensor_type, ++ ) + mapping.value = elements_attr + else: + elements_attr = mapping.value