Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[torch-frontend] fx export FakeTensor weight as SplatElementsAttr. #299

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions frontends/torch-frontend/third_party/patches/fx_importer.patch
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 381f8f9ad..99e75e8bd 100644
index 381f8f9a..a5e2d32b 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -52,6 +52,10 @@ from torch._subclasses import (
Expand Down Expand Up @@ -95,23 +95,64 @@ index 381f8f9ad..99e75e8bd 100644
# check support for bfloat16
assert not (
tensor.dtype == torch.bfloat16 and ml_dtypes is None
@@ -1732,11 +1776,17 @@ def _make_vtensor_literal_op(
@@ -1732,29 +1776,42 @@ def _make_vtensor_literal_op(
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
- np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
- # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
- # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
- # 0d tensors.
- if np_tensor.size == 1:
+
+ # NOTE: if we torch.export a torch.nn.Module under fake mode, the parameters in the fx.GraphModule will be FakeTensor.
+ # So we specifically handle FakeTensor here by randomly generating a tensor of the same shape and dtype.
+ # So we specifically handle FakeTensor here by creating a splat DenseElementsAttr with value 0.
+ if isinstance(tensor, TorchFakeTensor):
+ np_tensor = np.random.rand(*list(tensor.shape)).astype(npy_dtype)
+ else:
+ np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
# One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
# support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
# 0d tensors.
- if np_tensor.size == 1:
+ if True:
+ array = np.array([0]).astype(npy_dtype)
try:
dtype = tensor.dtype
element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
- dtype = tensor.dtype
- element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+ element_type = TORCH_DTYPE_TO_MLIR_TYPE[tensor.dtype]()
except KeyError:
- raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+ raise TypeError(f"Could not map Torch dtype {tensor.dtype} to an MLIR type")
elements_attr = DenseElementsAttr.get(
- type=element_type, array=np_tensor, shape=np_tensor.shape
+ array=array, type=element_type, shape=list(tensor.shape)
)
else:
- bytes_view = np_tensor.view(npy_dtype)
- tensor_type = create_mlir_tensor_type(tensor)
- shape_desc = "_".join([str(d) for d in tensor.shape])
- blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
- elements_attr = DenseResourceElementsAttr.get_from_buffer(
- bytes_view,
- blob_name,
- tensor_type,
- )
+ np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
+ # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
+ # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
+ # 0d tensors.
+ if np_tensor.size == 1:
+ try:
+ dtype = tensor.dtype
+ element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+ except KeyError:
+ raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+ elements_attr = DenseElementsAttr.get(
+ type=element_type, array=np_tensor, shape=np_tensor.shape
+ )
+ else:
+ bytes_view = np_tensor.view(npy_dtype)
+ tensor_type = create_mlir_tensor_type(tensor)
+ shape_desc = "_".join([str(d) for d in tensor.shape])
+ blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+ elements_attr = DenseResourceElementsAttr.get_from_buffer(
+ bytes_view,
+ blob_name,
+ tensor_type,
+ )
mapping.value = elements_attr
else:
elements_attr = mapping.value
Loading