Skip to content

Commit

Permalink
Rapid descent
Browse files Browse the repository at this point in the history
- Implement single (but actually 2) pass downsampling
  • Loading branch information
Jozufozu committed Sep 14, 2024
1 parent a527af5 commit ddb0450
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ public class IndirectPrograms extends AtomicReferenceCounted {
private static final ResourceLocation CULL_SHADER_MAIN = Flywheel.rl("internal/indirect/cull.glsl");
private static final ResourceLocation APPLY_SHADER_MAIN = Flywheel.rl("internal/indirect/apply.glsl");
private static final ResourceLocation SCATTER_SHADER_MAIN = Flywheel.rl("internal/indirect/scatter.glsl");
private static final ResourceLocation DEPTH_REDUCE_SHADER_MAIN = Flywheel.rl("internal/indirect/depth_reduce.glsl");
public static final List<ResourceLocation> UTIL_SHADERS = List.of(APPLY_SHADER_MAIN, SCATTER_SHADER_MAIN, DEPTH_REDUCE_SHADER_MAIN);
private static final ResourceLocation DOWNSAMPLE_FIRST = Flywheel.rl("internal/indirect/downsample_first.glsl");
private static final ResourceLocation DOWNSAMPLE_SECOND = Flywheel.rl("internal/indirect/downsample_second.glsl");
public static final List<ResourceLocation> UTIL_SHADERS = List.of(APPLY_SHADER_MAIN, SCATTER_SHADER_MAIN, DOWNSAMPLE_FIRST, DOWNSAMPLE_SECOND);

private static final Compile<InstanceType<?>> CULL = new Compile<>();
private static final Compile<ResourceLocation> UTIL = new Compile<>();
Expand Down Expand Up @@ -184,9 +185,14 @@ public GlProgram getScatterProgram() {
return utils.get(SCATTER_SHADER_MAIN);
}

public GlProgram getDepthReduceProgram() {
return utils.get(DEPTH_REDUCE_SHADER_MAIN);
public GlProgram getDownsampleFirstProgram() {
return utils.get(DOWNSAMPLE_FIRST);
}

public GlProgram getDownsampleSecondProgram() {
return utils.get(DOWNSAMPLE_SECOND);
}

@Override
protected void _delete() {
pipeline.values()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
import net.minecraft.client.Minecraft;

public class DepthPyramid {
private final GlProgram depthReduceProgram;
private final GlProgram downsampleFirstProgram;
private final GlProgram downsampleSecondProgram;

public int pyramidTextureId = -1;

private int lastWidth = -1;
private int lastHeight = -1;

public DepthPyramid(GlProgram depthReduceProgram) {
this.depthReduceProgram = depthReduceProgram;
public DepthPyramid(GlProgram downsampleFirstProgram, GlProgram downsampleSecondProgram) {
this.downsampleFirstProgram = downsampleFirstProgram;
this.downsampleSecondProgram = downsampleSecondProgram;
}

public void generate() {
Expand All @@ -37,26 +39,43 @@ public void generate() {

GL46.glMemoryBarrier(GL46.GL_FRAMEBUFFER_BARRIER_BIT);

GlTextureUnit.T1.makeActive();
GlTextureUnit.T0.makeActive();
GlStateManager._bindTexture(depthBufferId);

depthReduceProgram.bind();
downsampleFirstProgram.bind();
downsampleFirstProgram.setUInt("max_mip_level", mipLevels);

for (int i = 0; i < mipLevels; i++) {
int mipWidth = mipSize(width, i);
int mipHeight = mipSize(height, i);
for (int i = 0; i < Math.min(6, mipLevels); i++) {
GL46.glBindImageTexture(i + 1, pyramidTextureId, i, false, 0, GL32.GL_WRITE_ONLY, GL32.GL_R32F);
}

int srcTexture = (i == 0) ? depthBufferId : pyramidTextureId;
GlStateManager._bindTexture(srcTexture);
GL46.glDispatchCompute(MoreMath.ceilingDiv(width << 1, 64), MoreMath.ceilingDiv(height << 1, 64), 1);

GL46.glBindImageTexture(0, pyramidTextureId, i, false, 0, GL32.GL_WRITE_ONLY, GL32.GL_R32F);
if (mipLevels < 7) {
GL46.glMemoryBarrier(GL46.GL_TEXTURE_FETCH_BARRIER_BIT);

depthReduceProgram.setVec2("imageSize", mipWidth, mipHeight);
depthReduceProgram.setInt("lod", Math.max(0, i - 1));
return;
}

GL46.glDispatchCompute(MoreMath.ceilingDiv(mipWidth, 8), MoreMath.ceilingDiv(mipHeight, 8), 1);
GL46.glMemoryBarrier(GL46.GL_SHADER_IMAGE_ACCESS_BARRIER_BIT);

GL46.glMemoryBarrier(GL46.GL_TEXTURE_FETCH_BARRIER_BIT);
downsampleSecondProgram.bind();
downsampleSecondProgram.setUInt("max_mip_level", mipLevels);

// Note: mip_6 in the shader is actually mip level 5 in the texture
GL46.glBindImageTexture(0, pyramidTextureId, 5, false, 0, GL32.GL_READ_ONLY, GL32.GL_R32F);
for (int i = 6; i < Math.min(12, mipLevels); i++) {
GL46.glBindImageTexture(i - 5, pyramidTextureId, i, false, 0, GL32.GL_WRITE_ONLY, GL32.GL_R32F);
}

GL46.glDispatchCompute(1, 1, 1);

GL46.glMemoryBarrier(GL46.GL_TEXTURE_FETCH_BARRIER_BIT);
}

public void bindForCull() {
GlTextureUnit.T0.makeActive();
GlStateManager._bindTexture(pyramidTextureId);
}

public void delete() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import java.util.List;
import java.util.Map;

import org.lwjgl.opengl.GL46;

import dev.engine_room.flywheel.api.backend.Engine;
import dev.engine_room.flywheel.api.instance.Instance;
import dev.engine_room.flywheel.api.instance.InstanceType;
Expand Down Expand Up @@ -63,7 +61,7 @@ public IndirectDrawManager(IndirectPrograms programs) {
lightBuffers = new LightBuffers();
matrixBuffer = new MatrixBuffer();

depthPyramid = new DepthPyramid(programs.getDepthReduceProgram());
depthPyramid = new DepthPyramid(programs.getDownsampleFirstProgram(), programs.getDownsampleSecondProgram());
}

@Override
Expand Down Expand Up @@ -151,8 +149,7 @@ public void flush(LightStorage lightStorage, EnvironmentStorage environmentStora

matrixBuffer.bind();

GL46.glActiveTexture(GL46.GL_TEXTURE0);
GL46.glBindTexture(GL46.GL_TEXTURE_2D, depthPyramid.pyramidTextureId);
depthPyramid.bindForCull();

for (var group : cullingGroups.values()) {
group.dispatchCull();
Expand Down Expand Up @@ -185,6 +182,8 @@ public void delete() {
crumblingDrawBuffer.delete();

programs.release();

depthPyramid.delete();
}

public void renderCrumbling(List<Engine.CrumblingBlock> crumblingBlocks) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
layout(local_size_x = 256) in;

uniform uint max_mip_level;

/// Generates a hierarchical depth buffer.
/// Based on FidelityFX SPD v2.1 https://github.com/GPUOpen-LibrariesAndSDKs/FidelityFX-SDK/blob/d7531ae47d8b36a5d4025663e731a47a38be882f/sdk/include/FidelityFX/gpu/spd/ffx_spd.h#L528
/// Based on Bevy's more readable implementation https://github.com/JMS55/bevy/blob/ca2c8e63b9562f88c8cd7e1d88a17a4eea20aaf4/crates/bevy_pbr/src/meshlet/downsample_depth.wgsl

shared float[16][16] intermediate_memory;

// These are builtins in wgsl but we can trivially emulate them.
uint extractBits(uint e, uint offset, uint count) {
return (e >> offset) & ((1u << count) - 1u);
}

uint insertBits(uint e, uint newbits, uint offset, uint count) {
uint countMask = ((1u << count) - 1u);
// zero out the bits we're going to replace first
return (e & ~(countMask << offset)) | ((newbits & countMask) << offset);
}

// I do not understand how this works but it seems cool.
uvec2 remap_for_wave_reduction(uint a) {
return uvec2(
insertBits(extractBits(a, 2u, 3u), a, 0u, 1u),
insertBits(extractBits(a, 3u, 3u), extractBits(a, 1u, 2u), 0u, 2u)
);
}

float reduce_4(vec4 v) {
return max(max(v.x, v.y), max(v.z, v.w));
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#include "flywheel:internal/indirect/downsample.glsl"

layout(binding = 0) uniform sampler2D mip_0;
layout(binding = 1, r32f) uniform writeonly image2D mip_1;
layout(binding = 2, r32f) uniform writeonly image2D mip_2;
layout(binding = 3, r32f) uniform writeonly image2D mip_3;
layout(binding = 4, r32f) uniform writeonly image2D mip_4;
layout(binding = 5, r32f) uniform writeonly image2D mip_5;
layout(binding = 6, r32f) uniform writeonly image2D mip_6;

float reduce_load_mip_0(uvec2 tex) {
// NOTE: mip_0 is the actual depth buffer, and mip_1 is the "base" of our depth pyramid and has the next
// smallest Po2 dimensions to mip_1's dimensions. We dispatch enough invocations to cover the entire mip_1
// and will very likely oversample mip_0, but that's okay because we need to ensure conservative coverage.
// All following mip levels are proper halvings of their parents and will not waste any work.
vec2 uv = (vec2(tex) + 0.5) / vec2(imageSize(mip_1)) * 0.5;
return reduce_4(textureGather(mip_0, uv));
}

void downsample_mips_0_and_1(uint x, uint y, ivec2 workgroup_id, uint local_invocation_index) {
vec4 v;

ivec2 tex = workgroup_id * 64 + ivec2(x * 2u, y * 2u);
ivec2 pix = workgroup_id * 32 + ivec2(x, y);
v[0] = reduce_load_mip_0(tex);
imageStore(mip_1, pix, vec4(v[0]));

tex = workgroup_id * 64 + ivec2(x * 2u + 32u, y * 2u);
pix = workgroup_id * 32 + ivec2(x + 16u, y);
v[1] = reduce_load_mip_0(tex);
imageStore(mip_1, pix, vec4(v[1]));

tex = workgroup_id * 64 + ivec2(x * 2u, y * 2u + 32u);
pix = workgroup_id * 32 + ivec2(x, y + 16u);
v[2] = reduce_load_mip_0(tex);
imageStore(mip_1, pix, vec4(v[2]));

tex = workgroup_id * 64 + ivec2(x * 2u + 32u, y * 2u + 32u);
pix = workgroup_id * 32 + ivec2(x + 16u, y + 16u);
v[3] = reduce_load_mip_0(tex);
imageStore(mip_1, pix, vec4(v[3]));

if (max_mip_level <= 1u) { return; }

for (uint i = 0u; i < 4u; i++) {
intermediate_memory[x][y] = v[i];
barrier();
if (local_invocation_index < 64u) {
v[i] = reduce_4(vec4(
intermediate_memory[x * 2u + 0u][y * 2u + 0u],
intermediate_memory[x * 2u + 1u][y * 2u + 0u],
intermediate_memory[x * 2u + 0u][y * 2u + 1u],
intermediate_memory[x * 2u + 1u][y * 2u + 1u]
));
pix = (workgroup_id * 16) + ivec2(
x + (i % 2u) * 8u,
y + (i / 2u) * 8u
);
imageStore(mip_2, pix, vec4(v[i]));
}
barrier();
}

if (local_invocation_index < 64u) {
intermediate_memory[x + 0u][y + 0u] = v[0];
intermediate_memory[x + 8u][y + 0u] = v[1];
intermediate_memory[x + 0u][y + 8u] = v[2];
intermediate_memory[x + 8u][y + 8u] = v[3];
}
}


void downsample_mip_2(uint x, uint y, ivec2 workgroup_id, uint local_invocation_index) {
if (local_invocation_index < 64u) {
float v = reduce_4(vec4(
intermediate_memory[x * 2u + 0u][y * 2u + 0u],
intermediate_memory[x * 2u + 1u][y * 2u + 0u],
intermediate_memory[x * 2u + 0u][y * 2u + 1u],
intermediate_memory[x * 2u + 1u][y * 2u + 1u]
));
imageStore(mip_3, (workgroup_id * 8) + ivec2(x, y), vec4(v));
intermediate_memory[x * 2u + y % 2u][y * 2u] = v;
}
}

void downsample_mip_3(uint x, uint y, ivec2 workgroup_id, uint local_invocation_index) {
if (local_invocation_index < 16u) {
float v = reduce_4(vec4(
intermediate_memory[x * 4u + 0u + 0u][y * 4u + 0u],
intermediate_memory[x * 4u + 2u + 0u][y * 4u + 0u],
intermediate_memory[x * 4u + 0u + 1u][y * 4u + 2u],
intermediate_memory[x * 4u + 2u + 1u][y * 4u + 2u]
));
imageStore(mip_4, (workgroup_id * 4) + ivec2(x, y), vec4(v));
intermediate_memory[x * 4u + y][y * 4u] = v;
}
}

void downsample_mip_4(uint x, uint y, ivec2 workgroup_id, uint local_invocation_index) {
if (local_invocation_index < 4u) {
float v = reduce_4(vec4(
intermediate_memory[x * 8u + 0u + 0u + y * 2u][y * 8u + 0u],
intermediate_memory[x * 8u + 4u + 0u + y * 2u][y * 8u + 0u],
intermediate_memory[x * 8u + 0u + 1u + y * 2u][y * 8u + 4u],
intermediate_memory[x * 8u + 4u + 1u + y * 2u][y * 8u + 4u]
));
imageStore(mip_5, (workgroup_id * 2) + ivec2(x, y), vec4(v));
intermediate_memory[x + y * 2u][0u] = v;
}
}

void downsample_mip_5(ivec2 workgroup_id, uint local_invocation_index) {
if (local_invocation_index < 1u) {
float v = reduce_4(vec4(
intermediate_memory[0u][0u],
intermediate_memory[1u][0u],
intermediate_memory[2u][0u],
intermediate_memory[3u][0u]
));
imageStore(mip_6, workgroup_id, vec4(v));
}
}

void downsample_mips_2_to_5(uint x, uint y, ivec2 workgroup_id, uint local_invocation_index) {
if (max_mip_level <= 2u) { return; }
barrier();
downsample_mip_2(x, y, workgroup_id, local_invocation_index);

if (max_mip_level <= 3u) { return; }
barrier();
downsample_mip_3(x, y, workgroup_id, local_invocation_index);

if (max_mip_level <= 4u) { return; }
barrier();
downsample_mip_4(x, y, workgroup_id, local_invocation_index);

if (max_mip_level <= 5u) { return; }
barrier();
downsample_mip_5(workgroup_id, local_invocation_index);
}

void main() {
uvec2 sub_xy = remap_for_wave_reduction(gl_LocalInvocationIndex % 64u);
uint x = sub_xy.x + 8u * ((gl_LocalInvocationIndex >> 6u) % 2u);
uint y = sub_xy.y + 8u * (gl_LocalInvocationIndex >> 7u);

downsample_mips_0_and_1(x, y, ivec2(gl_WorkGroupID.xy), gl_LocalInvocationIndex);

downsample_mips_2_to_5(x, y, ivec2(gl_WorkGroupID.xy), gl_LocalInvocationIndex);
}
Loading

0 comments on commit ddb0450

Please sign in to comment.