/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cute/tensor.hpp"
#include "cute/numeric/numeric_types.hpp"
#include "gather_tensor.hpp"
namespace cutlass::epilogue::collective {
/// Applies an elementwise operation to all elements within the fragment
/// and scatter-writes them out to destination storage.
/// GatherC and ScatterD are types of user-defined functions that apply the
/// transformation of the strided coordinate (e.g. through an index array).
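///
/// As an illustrative sketch only (the functor name and index-array layout
/// are assumptions, not part of this file), a row-remapping functor usable
/// as GatherC or ScatterD might look like:
///
///   struct RowIndexedGather {
///     int const* row_indices = nullptr;  // maps logical row -> storage row
///     CUTLASS_HOST_DEVICE
///     int operator()(int m) const { return row_indices[m]; }
///   };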
template <
  class StrideC_,
  class StrideD_,
  class ThreadEpilogueOp_,
  class EpilogueSchedule_,
  class GatherC_,
  class ScatterD_
>
class EpilogueGatherScatter {
public:
  //
  // Type Aliases
  //
  using EpilogueSchedule = EpilogueSchedule_;

  // Derived types of the output thread-level operator
  using ThreadEpilogueOp = ThreadEpilogueOp_;
  using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
  using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
  using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
  using ElementScalar = ElementCompute;
  using ElementC = typename ThreadEpilogueOp::ElementC;
  using StrideC = StrideC_;
  using ElementD = typename ThreadEpilogueOp::ElementD;
  using StrideD = StrideD_;

  // Every epilogue needs these two GmemTiledCopy{C,D} aliases.
  // If you don't know what they should be, just use void.
  using GmemTiledCopyC = void;
  using GmemTiledCopyD = void;

  using GatherC = GatherC_;
  using ScatterD = ScatterD_;

  static const int kOutputAlignment = ThreadEpilogueOp::kCount;
  using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

  static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
  static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

  struct SharedStorage { };

  // Host side epilogue arguments
  struct Arguments {
    typename ThreadEpilogueOp::Params thread_params{};
    ElementC const* ptr_C = nullptr;
    StrideC dC{};
    ElementD* ptr_D = nullptr;
    StrideD dD{};
    GatherC gather_C{};
    ScatterD scatter_D{};
  };
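
  // A hedged host-side sketch of populating Arguments (the alpha/beta thread
  // params, pointer names, and the RowIndexedGather functor sketched above
  // are assumptions for illustration, not part of this file):
  //
  //   Arguments args{
  //     {alpha, beta},                      // thread_params
  //     ptr_C, dC,                          // source pointer and stride
  //     ptr_D, dD,                          // destination pointer and stride
  //     RowIndexedGather{gather_indices},   // GatherC: remaps C row reads
  //     RowIndexedGather{scatter_indices},  // ScatterD: remaps D row writes
  //   };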
  // Device side epilogue params
  using Params = Arguments;

  //
  // Methods
  //
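  // Params is identical to Arguments here: this epilogue needs no workspace
  // and no device-side transformation, so host arguments pass through as-is.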
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(
      [[maybe_unused]] ProblemShape const& _,
      Arguments const& args,
      [[maybe_unused]] void* workspace) {
    return args;
  }

  template <class ProblemShape>
  CUTLASS_HOST_DEVICE static bool
  can_implement(
      [[maybe_unused]] ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    return true;
  }
  CUTLASS_HOST_DEVICE
  EpilogueGatherScatter(Params const& params_) : params(params_) { }

  template <
    class ProblemShapeMNKL,
    class BlockShapeMNK,
    class BlockCoordMNKL,
    class FrgEngine, class FrgLayout,
    class TiledMma,
    class ResidueMNK
  >
  CUTLASS_DEVICE void
  operator()(
      ProblemShapeMNKL problem_shape_mnkl,
      BlockShapeMNK blk_shape_MNK,
      BlockCoordMNKL blk_coord_mnkl,
      cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
      TiledMma tiled_mma,
      ResidueMNK residue_mnk,
      int thread_idx,
      char* smem_buf)
  {
    using namespace cute;
    using X = Underscore;

    static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
    static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
    static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
    static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

    (void) smem_buf;
    ThreadEpilogueOp epilogue_op{params.thread_params};

    // Separate out problem shape for convenience
    auto M = get<0>(problem_shape_mnkl);
    auto N = get<1>(problem_shape_mnkl);
    auto L = get<3>(problem_shape_mnkl);

    auto stride_c = detail::get_epilogue_stride<EpilogueSchedule>(params.dC);
    auto stride_d = detail::get_epilogue_stride<EpilogueSchedule>(params.dD);

    // Represent the full output tensor
    Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C);   // (m,n,l)
    Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D);  // (m,n,l)
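    // make_gather_tensor wraps each raw pointer in a custom-stride tensor
    // whose index computation routes one coordinate mode through the user
    // functor, so the ordinary tiling and partitioning below transparently
    // read C and write D through the remapped (gathered/scattered) addresses.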
    Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});  // (BLK_M,BLK_N,m,n,l)
    Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});  // (BLK_M,BLK_N,m,n,l)

    // Slice to get the tile this CTA is responsible for
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
    Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord);  // (BLK_M,BLK_N)
    Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord);  // (BLK_M,BLK_N)

    // Partition source and destination tiles to match the accumulator partitioning
    auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
    Tensor tCgD = thr_mma.partition_C(gD);  // (VEC,THR_M,THR_N)
    Tensor tCgC = thr_mma.partition_C(gC);  // (VEC,THR_M,THR_N)

    static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
    CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
        "Source and destination must have the same number of elements.");
    CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
        "Accumulator count must match the destination element count.");

    // Make an identity coordinate tensor for predicating our output MN tile
    auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
    Tensor tCcD = thr_mma.partition_C(cD);
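    // Each element of tCcD is the logical (m,n) coordinate of the matching
    // accumulator element; comparing it against the MN residue below masks
    // off out-of-bounds elements when this CTA covers a partial boundary tile.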
    // source is needed
    if (epilogue_op.is_source_needed()) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
        }
      }
    }
    // source is not needed, avoid load
    else {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
          tCgD(i) = epilogue_op(accumulators(i));
        }
      }
    }
  }

private:
  Params params;
};

} // namespace cutlass::epilogue::collective