-
Notifications
You must be signed in to change notification settings - Fork 0
/
Subgroups.cpp
72 lines (54 loc) · 1.93 KB
/
Subgroups.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include "Subgroups.h"

#include <cmath>
#include <cstdint>

#include "ComputeUtil.h"
/// CPU-side image of the uniform block that SubgroupSort::Upload writes into
/// the uniform buffer. It is four 32-bit words (16 bytes). Every field is
/// zero-initialized by default so the padding words never carry indeterminate
/// bytes when the whole struct is copied to the GPU.
struct UniformData {
    uint32_t count = 0;     ///< Element count handed to Upload.
    uint32_t step = 0;      ///< 1 for a 1D dispatch, otherwise the X extent used for the 2D dispatch.
    uint32_t padding0 = 0;  ///< Unused filler word.
    uint32_t padding1 = 0;  ///< Unused filler word.
};
static_assert(sizeof(UniformData) == 4 * sizeof(uint32_t),
              "UniformData must stay four tightly packed 32-bit words");
// Creates the GPU objects the sort needs: the bind group layout, the uniform
// buffer, the compute pipeline built from subgroups/sort.wgsl, and the bind
// group that attaches inputBuffer (binding 0) and the uniform buffer
// (binding 1). The inputSize argument is not read by this function.
void SubgroupSort::Init(const wgpu::Device& device, const wgpu::Buffer& inputBuffer, uint32_t inputSize) {
    // Binding 0 is a storage buffer, binding 1 a uniform buffer; both are
    // visible to the compute stage only.
    auto layout = utils::MakeBindGroupLayout(
        device, "SubgroupsSort",
        {
            { 0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage },
            { 1, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform },
        });

    // Sized for exactly one UniformData; CopyDst lets Upload write into it.
    uniformBuffer = utils::CreateBuffer(
        device,
        sizeof(UniformData),
        wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst,
        "SubgroupUniforms");

    // The shader source is pulled in textually as the third argument.
    pipeline = ComputeUtil::CreatePipeline(
        device, layout,
#include "subgroups/sort.wgsl"
        , "Sort::Subgroups");

    bindGroup = utils::MakeBindGroup(
        device, layout,
        {
            { 0, inputBuffer },
            { 1, uniformBuffer },
        });
}
// Writes the per-sort uniforms (element count and dispatch step) into the
// uniform buffer through the device queue.
//
// device: device whose queue performs the buffer write.
// count:  number of elements; also determines how many workgroups are needed.
//
// step is 1 when the workgroups fit a 1D dispatch. Otherwise it is
// floor(sqrt(numWgs)), the same X extent that Sort uses for its 2D dispatch.
// NOTE(review): presumably the shader uses step to flatten the 2D workgroup id
// back into a linear index - confirm against subgroups/sort.wgsl.
void SubgroupSort::Upload(const wgpu::Device& device, uint32_t count) {
    const uint32_t numWgs = ComputeUtil::div_up(count, wgSize);

    // Value-initialize so the padding words are zero instead of indeterminate
    // when the whole struct is copied to the GPU.
    UniformData data{};
    data.count = count;
    data.step = 1u;
    if (numWgs >= 0xffffu) {
        // Too many workgroups for one dimension: record the X extent of the
        // 2D grid. Must stay in sync with the computation in Sort.
        data.step = static_cast<uint32_t>(std::sqrt(static_cast<double>(numWgs)));
    }

    device.GetQueue().WriteBuffer(uniformBuffer, 0, &data, sizeof(UniformData));
}
// Records one timestamped compute pass that runs the sort pipeline over
// `count` elements.
//
// encoder:  command encoder the pass is recorded into.
// querySet: query set handed to the timestamped pass helper (starting index 0).
// count:    number of elements; div_up(count, wgSize) workgroups are required.
//
// When the workgroup count reaches 0xffff the work is spread over a 2D grid
// of x by y workgroups, where x = floor(sqrt(numWgs)) matches the step value
// written by Upload.
void SubgroupSort::Sort(const wgpu::CommandEncoder& encoder, const wgpu::QuerySet& querySet, uint32_t count) {
    auto sortPass = ComputeUtil::CreateTimestampedComputePass(encoder, querySet, 0);
    const uint32_t numWgs = ComputeUtil::div_up(count, wgSize);

    sortPass.SetPipeline(pipeline);
    sortPass.SetBindGroup(0, bindGroup);

    if (numWgs >= 0xffffu) {
        const uint32_t x = static_cast<uint32_t>(std::sqrt(static_cast<double>(numWgs)));
        // Round the row count up so x * y always covers numWgs. A fixed
        // y = x + 1 yields only x*x + x workgroups, which falls short whenever
        // numWgs lies in (x*x + x, x*x + 2x]; e.g. numWgs = 65535 gives
        // x = 255 and 255 * 256 = 65280 < 65535.
        const uint32_t y = ComputeUtil::div_up(numWgs, x);
        sortPass.DispatchWorkgroups(x, y);
    } else {
        sortPass.DispatchWorkgroups(numWgs);
    }

    sortPass.End();
}
// Currently a no-op: the body contains no cleanup code.
void SubgroupSort::Dispose() {
}