forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path pointwise_ops.py
103 lines (92 loc) · 4.41 KB
/
pointwise_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from functools import partial
import torch
from torch.sparse import SparseSemiStructuredTensor
from torchao.sparsity.training.autograd import semi_structured_sparsify_like
def _semi_sparse_pointwise_op(
func, types, args=(), kwargs=None, sparsify_like_args_list=()
):
"""
adds pointwise op support for semi-structured tensors.
Assumes that at least one of the arguments in arg is a SparseSemiStructuredTensor.
The last instance of a SparseSemiStructuredTensor is used as the reference mask to sparsify the others tensors passed in args.
sparsify_like_args_list is used to specify which arguments to sparsify like the reference tensor.
"""
reference_sparse_tensor = None
for tensor in args:
if isinstance(tensor, SparseSemiStructuredTensor):
reference_sparse_tensor = tensor
assert reference_sparse_tensor is not None
def handle_arg(i, tensor):
if isinstance(tensor, torch.Tensor):
# For pointwise ops, dense tensors will be sparsified to match the sparsity pattern of the reference tensor
# if they are specified in `sparsify_like_args_list`.
if not isinstance(tensor, SparseSemiStructuredTensor):
if i in sparsify_like_args_list:
tensor = semi_structured_sparsify_like(
tensor, reference_sparse_tensor
)
else:
raise ValueError(
f"Operation {func.__module__}.{func.__name__} on {type(reference_sparse_tensor)} requires all operands to "
f"be {type(reference_sparse_tensor)}, but operand {i} is a {type(tensor)}"
)
# If the tensor is a SparseSemiStructuredTensor, we make sure that the sparsity pattern is the same as the reference tensor.
# Pointwise ops on tensors containing two different sparsity patterns is not defined, as in the case of addition, where
# adding two semi-structured sparse tensors yields a result that is not semi-structured sparse.
else:
if (
tensor.compressed_swizzled_bitmask is None
or reference_sparse_tensor.compressed_swizzled_bitmask is None
or tensor.compressed_swizzled_bitmask.data_ptr()
!= reference_sparse_tensor.compressed_swizzled_bitmask.data_ptr()
or tensor.compressed_swizzled_bitmask.stride()
!= reference_sparse_tensor.compressed_swizzled_bitmask.stride()
):
raise ValueError(
f"Operation {func.__module__}.{func.__name__} on {type(reference_sparse_tensor)} requires all operands to be "
f"{type(reference_sparse_tensor)} with the same sparsity pattern"
)
return tensor
args_updated = [handle_arg(i, tensor) for i, tensor in enumerate(args)]
return reference_sparse_tensor.__class__(
reference_sparse_tensor.shape,
func(
*[
x.packed if isinstance(x, SparseSemiStructuredTensor) else x
for x in args_updated
]
),
reference_sparse_tensor.meta,
func(
*[
x.packed_t if isinstance(x, SparseSemiStructuredTensor) else x
for x in args_updated
]
),
reference_sparse_tensor.meta_t,
reference_sparse_tensor.compressed_swizzled_bitmask,
)
# Pointwise aten ops handled for semi-structured (CUTLASS) sparse tensors.
#
# Ops mapped directly to `_semi_sparse_pointwise_op` require every tensor
# operand to already be a SparseSemiStructuredTensor sharing the reference's
# sparsity pattern. Ops wrapped in `partial(..., sparsify_like_args_list=...)`
# additionally accept plain dense tensors at the listed positional indices;
# those are sparsified like the reference right before the aten op runs, which
# lets a gradient arrive as a regular `torch.Tensor` in the backward ops.
CUTLASS_POINTWISE_OP_DISPATCH_TABLE = {
    # Forward activations: all tensor operands must be sparse already.
    **dict.fromkeys(
        (torch.ops.aten.relu, torch.ops.aten.gelu, torch.ops.aten.silu),
        _semi_sparse_pointwise_op,
    ),
    # `mul` backward in swiglu: either operand may arrive dense.
    torch.ops.aten.mul: partial(
        _semi_sparse_pointwise_op, sparsify_like_args_list=(0, 1)
    ),
    torch.ops.aten.add: _semi_sparse_pointwise_op,
    # Backward ops: operand 0 may arrive dense (silu_backward also operand 1).
    torch.ops.aten.gelu_backward: partial(
        _semi_sparse_pointwise_op, sparsify_like_args_list=(0,)
    ),
    torch.ops.aten.silu_backward: partial(
        _semi_sparse_pointwise_op, sparsify_like_args_list=(0, 1)
    ),
    # relu backward.
    torch.ops.aten.threshold_backward: partial(
        _semi_sparse_pointwise_op, sparsify_like_args_list=(0,)
    ),
}