From fd024c5ee13779274a42bb36074471dc694555cc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 4 Oct 2023 16:35:20 +0000 Subject: [PATCH 1/6] Implement aten::all.dims --- .../function_libs/torch_lib/ops/core.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4317ee257..5341a40b8 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -347,6 +347,32 @@ def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: return result +@torch_op("aten::all.dims", private=True) +def _aten_all_keep_dims(self: TTensor, keepdims: bool) -> BOOL: + """Private implementation for when keepdims is True.""" + + self_rank = op.Size(op.Shape(self)) + if self_rank == 0: + result = op.Cast(self, to=BOOL.dtype) + else: + self_bool = op.Cast(self, to=BOOL.dtype) + self_int = op.Cast(self_bool, to=INT64.dtype) + all_true = op.ReduceMin(self_int, keepdims=keepdims) + result = op.Cast(all_true, to=BOOL.dtype) + return result + + +@torch_op("aten::all.dims", trace_only=True) +def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL: + """all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + + if not dim: + return _aten_all_keep_dims(self, keepdim) + for d in dim: + self = aten_all_dim(self, d, keepdim) + return self + + @torch_op("aten::allclose") def aten_allclose( self: TReal, From 259765c5c3e38656e3356326495631258b55539c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 4 Oct 2023 16:40:33 +0000 Subject: [PATCH 2/6] any --- .../function_libs/torch_lib/ops/core.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 5341a40b8..fed5b6bd1 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -471,6 +471,32 @@ def aten_any_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: return result +@torch_op("aten::any.dims", private=True) +def _aten_any_keep_dims(self: TTensor, keepdims: bool) -> BOOL: + """Private implementation for when keepdims is True.""" + + self_rank = op.Size(op.Shape(self)) + if self_rank == 0: + result = op.Cast(self, to=BOOL.dtype) + else: + self_bool = op.Cast(self, to=BOOL.dtype) + self_int = op.Cast(self_bool, to=INT64.dtype) + any_true = op.ReduceMax(self_int, keepdims=keepdims) + result = op.Cast(any_true, to=BOOL.dtype) + return result + + +@torch_op("aten::any.dims", trace_only=True) +def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL: + """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + + if not dim: + return _aten_any_keep_dims(self, keepdim) + for d in dim: + self = aten_any_dim(self, d, keepdim) + return self + + def _range_supported(dtype: int) -> bool: """Returns true if the dtype is supported by the ONNX Range op.""" return dtype in { From 663df8e00d67eeda9a5a4b9b822fcaa0f3999545 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 4 Oct 2023 16:58:16 +0000 Subject: [PATCH 3/6] update --- .../function_libs/torch_lib/ops/core.py | 60 ++++++++++--------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index fed5b6bd1..a254f875a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -347,9 +347,22 @@ def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: return result -@torch_op("aten::all.dims", private=True) -def _aten_all_keep_dims(self: TTensor, keepdims: bool) -> BOOL: - """Private implementation for when keepdims is True.""" +@torch_op("aten::all.dims", trace_only=True) +def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL: + """all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + + if not dim: + return _aten_all_keep_dims(self, keepdim) + for d in dim: + self = aten_all_dim(self, d, keepdim) + return self + + +@torch_op("aten::all.dims") +def aten_all_dims_empty_dim(self: TTensor, keepdims: bool) -> BOOL: + """all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + + # dim is None and thus not supplied self_rank = op.Size(op.Shape(self)) if self_rank == 0: @@ -362,17 +375,6 @@ def _aten_all_keep_dims(self: TTensor, keepdims: bool) -> BOOL: return result -@torch_op("aten::all.dims", trace_only=True) -def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL: - """all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" - - if not dim: - return _aten_all_keep_dims(self, keepdim) - for d in dim: - self = aten_all_dim(self, d, keepdim) - return self - - @torch_op("aten::allclose") def aten_allclose( self: TReal, @@ -471,9 +473,22 @@ def aten_any_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: return result -@torch_op("aten::any.dims", private=True) -def _aten_any_keep_dims(self: TTensor, keepdims: bool) -> BOOL: - """Private implementation for when keepdims is True.""" +@torch_op("aten::any.dims", trace_only=True) +def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL: + """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + + if not dim: + return _aten_any_keep_dims(self, keepdim) + for d in dim: + self = aten_any_dim(self, d, keepdim) + return self + + +@torch_op("aten::any.dims") +def aten_any_dims_empty_dim(self: TTensor, keepdims: bool) -> BOOL: + """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + + # dim is None and thus not supplied self_rank = op.Size(op.Shape(self)) if self_rank == 0: @@ -486,17 +501,6 @@ def _aten_any_keep_dims(self: TTensor, keepdims: bool) -> BOOL: return result -@torch_op("aten::any.dims", trace_only=True) -def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL: - """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" - - if not dim: - return _aten_any_keep_dims(self, keepdim) - for d in dim: - self = aten_any_dim(self, d, keepdim) - return self - - def _range_supported(dtype: int) -> bool: """Returns true if the dtype is supported by the ONNX Range op.""" return dtype in { From 9c47bfb80655885f4e797b50c134cb4054cc548c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 4 Oct 2023 17:54:26 +0000 Subject: [PATCH 4/6] aten_any_dims_empty_dim --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a254f875a..41a271ae0 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -352,7 +352,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) """all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" if not dim: - return _aten_all_keep_dims(self, keepdim) + return aten_all_dims_empty_dim(self, keepdim) for d in dim: self = aten_all_dim(self, d, keepdim) return self @@ -478,7 +478,7 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" if not dim: - return _aten_any_keep_dims(self, keepdim) + return aten_any_dims_empty_dim(self, keepdim) for d in dim: self = aten_any_dim(self, d, keepdim) return self From 2a688d13ca380dd27a1257a53d50f44f0c4ac480 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 4 Oct 2023 12:02:51 -0700 Subject: [PATCH 5/6] Update core.py --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 41a271ae0..843b5e636 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -349,7 +349,7 @@ def aten_all_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL: @torch_op("aten::all.dims", trace_only=True) def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL: - """all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" if not dim: return aten_all_dims_empty_dim(self, keepdim) @@ -360,7 +360,7 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) @torch_op("aten::all.dims") def aten_all_dims_empty_dim(self: TTensor, keepdims: bool) -> BOOL: - """all.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" + """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" # dim is None and thus not supplied From 8eb01377542ff5365b16c0447e6a70168ab8c64d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 12 Oct 2023 17:53:29 -0700 Subject: [PATCH 6/6] Apply suggestions from code review --- onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d132c135e..a643d48fd 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -352,14 +352,14 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" if not dim: - return aten_all_dims_empty_dim(self, keepdim) + return aten_all_dims_no_dim(self, keepdim) for d in dim: self = aten_all_dim(self, d, keepdim) return self @torch_op("aten::all.dims") -def aten_all_dims_empty_dim(self: TTensor, keepdims: bool) -> BOOL: +def aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" # dim is None and thus not supplied @@ -478,14 +478,14 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" if not dim: - return aten_any_dims_empty_dim(self, keepdim) + return aten_any_dims_no_dim(self, keepdim) for d in dim: self = aten_any_dim(self, d, keepdim) return self @torch_op("aten::any.dims") -def aten_any_dims_empty_dim(self: TTensor, keepdims: bool) -> BOOL: +def aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: """any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor""" # dim is None and thus not supplied