Skip to content

Commit

Permalink
fix gbuffer normal & improve stochastic samping
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiQing-R committed Dec 6, 2024
1 parent 1c7c1a5 commit cb0ed4f
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 115 deletions.
4 changes: 3 additions & 1 deletion shaders/gbuffer.vert
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ layout(set = 0, binding = 0, scalar) uniform _SceneCamera { SceneCamera sceneCa

layout(push_constant) uniform Instance {
mat4 model;
mat4 modelInvTrp;
uint id;
} instanceData;

Expand All @@ -26,7 +27,8 @@ layout(location = 1) out vec3 normal;
void main()
{
instanceID = instanceData.id;
normal = decompress_unit_vec(in_normal);
vec3 normalL = decompress_unit_vec(in_normal);
normal = mat3(instanceData.modelInvTrp) * normalL;

vec4 wpos = instanceData.model * vec4(in_pos, 1);
wpos /= wpos.w;
Expand Down
2 changes: 1 addition & 1 deletion shaders/lightPass.frag
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ void main()
vec3 indirectLight = texture(indirectLightMap, uv).rgb;
//vec3 diffuseAlbedo = state.mat.albedo * (M_1_OVER_PI * (1.0 - state.mat.metallic));
vec3 diffuseAlbedo = state.mat.albedo * (1.0 - F_SchlickRoughness(state.mat.f0, max(0.0, dot(-camRay.direction, state.normal)), state.mat.roughness)
* (1.0 - state.mat.metallic));
* (1.0 - state.mat.metallic)) * 0.5;
vec3 directLighting = hit ? vec3(0) : directLight.radiance;
vec3 reflectionColor = texelFetch(filteredReflectionColor, ivec2(gl_FragCoord.xy), 0).rgb;

Expand Down
170 changes: 81 additions & 89 deletions shaders/reflection_generation.comp
Original file line number Diff line number Diff line change
Expand Up @@ -141,109 +141,101 @@ void main()
vec2 stocasticUV = uv + vec2(rand(randSeed) - 0.5, rand(randSeed) - 0.5) / vec2(imageRes);

prd.seed = tea(rtxState.size.x * gl_GlobalInvocationID.y + gl_GlobalInvocationID.x, rtxState.totalFrames * rtxState.maxSamples);
vec2 d = stocasticUV * 2.0 - 1.0;

vec4 origin = sceneCamera.viewInverse * vec4(0, 0, 0, 1);
vec4 target = sceneCamera.projInverse * vec4(d.x, d.y, 1, 1);
vec4 direction = sceneCamera.viewInverse * vec4(normalize(target.xyz), 0);
Ray ray = Ray(origin.xyz, direction.xyz);

float firstDepth = -1.f;
BsdfSampleRec reflectBsdfSampleRec;
float weight;
vec3 radiance = surfelRefelctionTrace(ray, 8, firstDepth, reflectBsdfSampleRec, weight);
if (firstDepth == INFINITY)
{
imageStore(reflectionColor, imageCoords, vec4(0.0));
imageStore(reflectionDirection, imageCoords, vec4(0.0));
return;
}
float lum = dot(radiance, vec3(0.212671f, 0.715160f, 0.072169f));
if(lum > rtxState.fireflyClampThreshold)
{
radiance *= rtxState.fireflyClampThreshold / lum;
}
float invPdf = 1.f / reflectBsdfSampleRec.pdf;
//radiance *= invPdf * reflectBsdfSampleRec.f * max(0.0, dot(reflectBsdfSampleRec.L, state.normal));

imageStore(reflectionColor, imageCoords, vec4(radiance, weight));
imageStore(reflectionDirection, imageCoords, vec4(reflectBsdfSampleRec.L, invPdf));



// ivec2 gbufferCoords = ivec2(imageCoords * 2);
// uint primObjID = texelFetch(gbufferPrim, gbufferCoords, 0).r;
// path trace solution
// vec2 d = stocasticUV * 2.0 - 1.0;
//
// // reconstruct world position from depth
// float depth = texelFetch(gbufferDepth, gbufferCoords, 0).r;
// if (depth == 1.0)
// {
// imageStore(reflectionColor, imageCoords, vec4(0.0));
// imageStore(reflectionDirection, imageCoords, vec4(0.0));
// return;
// }
//
// vec3 worldPos = WorldPosFromDepth(stocasticUV, depth);
//
// uint nodeID = primObjID >> 23;
// uint instanceID = sceneNodes[nodeID].primMesh;
// mat4 worldMat = sceneNodes[nodeID].worldMatrix;
// uint primID = primObjID & 0x007FFFFF;
// InstanceData pinfo = geoInfo[instanceID];
//
// // Primitive buffer addresses
// Indices indices = Indices(pinfo.indexAddress);
// Vertices vertices = Vertices(pinfo.vertexAddress);
//
// // Indices of this triangle primitive.
// uvec3 tri = indices.i[primID];
//
// // All vertex attributes of the triangle.
// VertexAttributes attr0 = vertices.v[tri.x];
// VertexAttributes attr1 = vertices.v[tri.y];
// VertexAttributes attr2 = vertices.v[tri.z];
//
// // camera ray
// vec3 camPos = (sceneCamera.viewInverse * vec4(0, 0, 0, 1)).xyz;
// Ray camRay = Ray(camPos, normalize(worldPos - camPos));
// vec4 origin = sceneCamera.viewInverse * vec4(0, 0, 0, 1);
// vec4 target = sceneCamera.projInverse * vec4(d.x, d.y, 1, 1);
// vec4 direction = sceneCamera.viewInverse * vec4(normalize(target.xyz), 0);
// Ray ray = Ray(origin.xyz, direction.xyz);
//
// // decompress normal
// vec3 normal = decompress_unit_vec(texelFetch(gbufferNormal, gbufferCoords, 0).r);
// float firstDepth = -1.f;
// BsdfSampleRec reflectBsdfSampleRec;
// float weight;
//
// // reflected direction and brdf
// State state = GetState(primObjID, normal, depth, uv);
// // ignore rough surface
// if (state.mat.roughness > 0.95)
// {
// imageStore(reflectionColor, imageCoords, vec4(0.0));
// vec3 radiance = surfelRefelctionTrace(ray, 2, firstDepth, reflectBsdfSampleRec, weight);
// if (firstDepth == INFINITY)
// {
// imageStore(reflectionColor, imageCoords, vec4(0.0));
// imageStore(reflectionDirection, imageCoords, vec4(0.0));
// return;
// }
// BsdfSampleRec reflectBsdfSampleRec;
// float weight = 0.0;
// uint maxItr = 0;
//
// while(weight < 1e-5 && maxItr < 3)
// float lum = dot(radiance, vec3(0.212671f, 0.715160f, 0.072169f));
// if(lum > rtxState.fireflyClampThreshold)
// {
// reflectBsdfSampleRec.f = SpecSample(state, -camRay.direction, state.ffnormal, reflectBsdfSampleRec.L, reflectBsdfSampleRec.pdf, prd.seed);
// // calculate brdf weight
// weight = brdfWeight(-camRay.direction, state.ffnormal, reflectBsdfSampleRec.L, state.mat.roughness);
// reflectBsdfSampleRec.pdf = max(1e-4, reflectBsdfSampleRec.pdf);
// maxItr++;
// radiance *= rtxState.fireflyClampThreshold / lum;
// }
//
// // reflection color
// Ray reflectedRay = Ray(worldPos + 1e-2 * normal, normalize(reflectBsdfSampleRec.L));


// float firstDepth = -1.f;
// vec3 radiance = surfelRefelctionTrace(reflectedRay, 1, firstDepth);
// float invPdf = 1.f / reflectBsdfSampleRec.pdf;
// radiance *= invPdf * reflectBsdfSampleRec.f * max(0.0, dot(reflectBsdfSampleRec.L, state.normal));
// //radiance *= invPdf * reflectBsdfSampleRec.f * max(0.0, dot(reflectBsdfSampleRec.L, state.normal));
//
// imageStore(reflectionColor, imageCoords, vec4(radiance, weight));
// imageStore(reflectionDirection, imageCoords, vec4(reflectBsdfSampleRec.L, invPdf));



ivec2 gbufferCoords = ivec2(imageCoords * 2);
uint primObjID = texelFetch(gbufferPrim, gbufferCoords, 0).r;

// reconstruct world position from depth
float depth = texelFetch(gbufferDepth, gbufferCoords, 0).r;
if (depth == 1.0)
{
imageStore(reflectionColor, imageCoords, vec4(0.0));
imageStore(reflectionDirection, imageCoords, vec4(0.0));
return;
}

vec3 worldPos = WorldPosFromDepth(stocasticUV, depth);

uint nodeID = primObjID >> 23;
uint instanceID = sceneNodes[nodeID].primMesh;
mat4 worldMat = sceneNodes[nodeID].worldMatrix;
uint primID = primObjID & 0x007FFFFF;
InstanceData pinfo = geoInfo[instanceID];

// camera ray
vec3 camPos = (sceneCamera.viewInverse * vec4(0, 0, 0, 1)).xyz;
Ray camRay = Ray(camPos, normalize(worldPos - camPos));

// decompress normal
vec3 normal = decompress_unit_vec(texelFetch(gbufferNormal, gbufferCoords, 0).r);

// reflected direction and brdf
State state = GetState(primObjID, normal, depth, uv);
// ignore rough surface
if (state.mat.roughness > 0.95)
{
imageStore(reflectionColor, imageCoords, vec4(0.0));
imageStore(reflectionDirection, imageCoords, vec4(0.0));
return;
}
BsdfSampleRec reflectBsdfSampleRec;
float weight = 0.0;
uint maxItr = 0;

while(weight < 1e-5 && maxItr < 3)
{
reflectBsdfSampleRec.f = SpecSample(state, -camRay.direction, state.ffnormal, reflectBsdfSampleRec.L, reflectBsdfSampleRec.pdf, prd.seed);
// calculate brdf weight
weight = brdfWeight(-camRay.direction, state.ffnormal, reflectBsdfSampleRec.L, state.mat.roughness);
reflectBsdfSampleRec.pdf = max(1e-4, reflectBsdfSampleRec.pdf);
maxItr++;
}

// reflection color
Ray reflectedRay = Ray(worldPos + 1e-2 * normal, normalize(reflectBsdfSampleRec.L));

float firstDepth = -1.f;
BsdfSampleRec tmp;
vec3 radiance = surfelRefelctionTrace(reflectedRay, 1, firstDepth, tmp, weight);
float invPdf = 1.f / reflectBsdfSampleRec.pdf;
radiance *= invPdf * reflectBsdfSampleRec.f * max(0.0, dot(reflectBsdfSampleRec.L, state.ffnormal));

imageStore(reflectionColor, imageCoords, vec4(radiance, weight));
imageStore(reflectionDirection, imageCoords, vec4(vec3(state.matID + 1), invPdf));


//ClosestHit(reflectedRay);


Expand Down
42 changes: 23 additions & 19 deletions shaders/shaderUtils_surfel_cell.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,17 @@ bool finalizePathWithSurfel(vec3 worldPos, vec3 worldNor, inout vec4 irradiance)
const uint searchRange = min(16, cellInfo.surfelCount);
uint searchCnt = 0;

for (uint i = 0; i < cellInfo.surfelCount; i++)
uint randSeed = initRandom(uvec2(rtxState.totalFrames, floatBitsToUint(worldPos.x)),
uvec2(floatBitsToUint(worldPos.y), floatBitsToUint(worldPos.z)), rtxState.frame);

uint targetCnt = min(64, cellInfo.surfelCount);
float surfelCntF = float(cellInfo.surfelCount);

for (uint i = 0; i < targetCnt; i++)
{
if (searchCnt == searchRange) break;
uint surfelIndex = cellToSurfel[cellOffset + i];
uint currIndex = uint(rand(randSeed) * surfelCntF);

uint surfelIndex = cellToSurfel[cellOffset + currIndex];
Surfel surfel = surfelBuffer[surfelIndex];
vec3 neiNor = decompress_unit_vec(surfel.normal);
bool isSleeping = (surfelRecycleInfo[surfelIndex].status & 0x0001) != 0;
Expand Down Expand Up @@ -208,7 +215,6 @@ bool finalizePathWithSurfel(vec3 worldPos, vec3 worldNor, inout vec4 irradiance)
contribution *= pow(1.f - dist / surfel.radius, 2.0);
irradiance += vec4(surfel.radiance, 1.f) * contribution;
}
searchCnt++;
surfelRecycleInfo[surfelIndex].status |= 0x0004u;
}

Expand All @@ -219,9 +225,7 @@ bool finalizePathWithSurfel(vec3 worldPos, vec3 worldNor, inout vec4 irradiance)
{
irradiance /= irradiance.w;
}

//uint randSeed = initRandom(uvec2(rtxState.totalFrames, floatBitsToUint(worldPos.x)),
// uvec2(floatBitsToUint(worldPos.y), floatBitsToUint(worldPos.z)), rtxState.frame);

//
// spawn sleeping surfel if coverage is low.
//if (surfelCounter.aliveSurfelCnt < kMaxSurfelCount &&
Expand Down Expand Up @@ -582,18 +586,18 @@ vec3 surfelRefelctionTrace(Ray r, int maxDepth, inout float firstDepth, inout Bs
}

// use surfel indirect when the path reach max depth
//if (depth == maxDepth && valid)
// {
// vec4 irradiance = vec4(0.0);
// bool rst = finalizePathWithSurfel(sstate.position, sstate.normal, irradiance);
// //bool rst = false;
// if (rst)
// {
// // apply diffuse ratio
// irradiance.rgb *= diffuseRatio;
// radiance += irradiance.xyz * throughput;
// }
// }
if (depth == maxDepth && valid)
{
vec4 irradiance = vec4(0.0);
bool rst = finalizePathWithSurfel(sstate.position, sstate.normal, irradiance);
//bool rst = false;
if (rst)
{
// apply diffuse ratio
irradiance.rgb *= diffuseRatio;
radiance += irradiance.xyz * throughput;
}
}


return radiance;
Expand Down
12 changes: 9 additions & 3 deletions shaders/temporal_spatial_pass.comp
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ void main()
vec4 centerDirection = imageLoad(reflectionDirection, halfImageCoords);
if (centerDirection == vec4(0.0))
{
//imageStore(filteredReflectionColor, imageCoords, vec4(0.0));
imageStore(filteredReflectionColor, imageCoords, vec4(0.0));
return;
}
float matID = centerDirection.x;

//vec4 centerBrdf = imageLoad(reflectionPointBrdf, halfImageCoords);

Expand All @@ -151,7 +152,7 @@ void main()
result += centerContrib;
weightSum += centerWeight;

uint quadID = (imageCoords.y & 0x1) << 1 + (imageCoords.x & 0x1);
uint quadID = (imageCoords.x & 0x1) << 1 + (imageCoords.y & 0x1);
float variance = 0.0;

for (int i = 0; i < 16; ++i)
Expand All @@ -164,7 +165,12 @@ void main()
continue;

float brdfWeight = neighborColor.a;
float pdfInv = max(imageLoad(reflectionDirection, neighborCoords).a, 1e-3);
vec4 neiDir = imageLoad(reflectionDirection, neighborCoords);
float neiMatID = neiDir.x;
if (neiMatID != matID)
continue;

float pdfInv = max(neiDir.a, 1e-3);

// Combine BRDF weight with distance weight
float weight = max(brdfWeight * pdfInv, 0.0);
Expand Down
1 change: 1 addition & 0 deletions src/gbuffer_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ void GbufferPass::run(const VkCommandBuffer& cmdBuf, const VkExtent2D& size, nvv
instanceData.id = nodeID++;
uint32_t primID = node.primMesh;
instanceData.model = node.worldMatrix;
instanceData.modelInvTrp = glm::mat4(glm::inverse(glm::transpose(glm::mat3(instanceData.model))));

// Sending the push constant information
vkCmdPushConstants(cmdBuf, m_pipelineLayout, VK_SHADER_STAGE_VERTEX_BIT, 0, sizeof(InstanceData), &instanceData);
Expand Down
1 change: 1 addition & 0 deletions src/gbuffer_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class GbufferPass : Renderer
struct InstanceData
{
glm::mat4 model;
glm::mat4 modelInvTrp;
uint32_t id;
};

Expand Down
4 changes: 2 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ int main(int argc, char** argv)
InputParser parser(argc, argv);
//std::string sceneFile = parser.getString("-f", "Sponza/Sponza.gltf");
//std::string sceneFile = parser.getString("-f", "Street/scene.gltf");
//std::string sceneFile = parser.getString("-f", "Hospital/scene.gltf");
std::string sceneFile = parser.getString("-f", "station/station.gltf");
std::string sceneFile = parser.getString("-f", "apocal/apocal.gltf");
//std::string sceneFile = parser.getString("-f", "station/station.gltf");
std::string hdrFilename = parser.getString("-e", "std_env.hdr");

// Setup GLFW window
Expand Down

0 comments on commit cb0ed4f

Please sign in to comment.