Skip to content

Commit

Permalink
* Using multiple ray gen nodes in a graph with different root signatu…
Browse files Browse the repository at this point in the history
…res is ok now.

 * It used to throw all ray tracing shaders into the shader table which made problems when there were incompatible root signatures inolved.
 * Now it only includes the shaders actually used by the ray gen node when making a shader table for that node.
* You can use multiple miss shaders in a ray gen node now.
 * The lookup function for translating miss shader names to indices had broken logic.
* @IvarPD
 * Updated OpenFBX and improved the FBX importer to acknowledge multiple materials in a single mesh.
 * If you run the viewer with "-renderdoc", a renderdoc capture button shows up. Disabled by default because this disables raytracing.
  • Loading branch information
Atrix256 committed Nov 22, 2024
1 parent 1a7b426 commit ff062f2
Show file tree
Hide file tree
Showing 60 changed files with 10,897 additions and 10,496 deletions.
28 changes: 16 additions & 12 deletions GigiCompilerLib/Backends/DX12/Backend_DX12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1753,9 +1753,13 @@ void CopyShaderFileDX12(const Shader& shader, const std::unordered_map<std::stri
else if (GetTokenParameter(token.c_str(), "RTHitGroupIndex", param))
{
int foundIndex = -1;
for (int i = 0; i < (int)renderGraph.hitGroups.size(); ++i)
for (int i = 0; i < shader.Used_RTHitGroupIndex.size(); ++i)
{
const RTHitGroup& hitGroup = renderGraph.hitGroups[i];
int HGIndex = GetHitGroupIndex(renderGraph, shader.Used_RTHitGroupIndex[i].c_str());
if (HGIndex < 0)
continue;

const RTHitGroup& hitGroup = renderGraph.hitGroups[HGIndex];
if (shader.scope == hitGroup.scope && param == hitGroup.originalName)
{
foundIndex = i;
Expand All @@ -1765,7 +1769,7 @@ void CopyShaderFileDX12(const Shader& shader, const std::unordered_map<std::stri
break;
}

Assert(foundIndex != -1, "Could not find RTHitGroupIndex for \"%s\"", param.c_str());
Assert(foundIndex != -1, "Could not find RTHitGroupIndex for \"%s\" in shader \"%s\"", param.c_str(), shader.name.c_str());
if (foundIndex != -1)
{
shaderSpecificStringReplacementMap[token] = std::ostringstream();
Expand All @@ -1779,24 +1783,24 @@ void CopyShaderFileDX12(const Shader& shader, const std::unordered_map<std::stri
}
else if (GetTokenParameter(token.c_str(), "RTMissIndex", param))
{
int RTMissIndex = -1;
int foundIndex = 1;
for (const Shader& shader : renderGraph.shaders)
int foundIndex = -1;
for (int i = 0; i < shader.Used_RTMissIndex.size(); ++i)
{
if (shader.type != ShaderType::RTMiss)
int MissIndex = GetShaderIndex(renderGraph, shader.Used_RTMissIndex[i].c_str());
if (MissIndex < 0)
continue;

RTMissIndex++;

if (param == shader.name)
const Shader& missShader = renderGraph.shaders[MissIndex];
if (shader.scope == missShader.scope && param == missShader.originalName)
{
foundIndex = RTMissIndex;
foundIndex = i;
break;
}
if (foundIndex != -1)
break;
}
Assert(foundIndex != -1, "Could not find RTMissIndex for \"%s\"", param.c_str());

Assert(foundIndex != -1, "Could not find RTMissIndex for \"%s\" in shader \"%s\"", param.c_str(), shader.name.c_str());
if (foundIndex != -1)
{
shaderSpecificStringReplacementMap[token] = std::ostringstream();
Expand Down
70 changes: 58 additions & 12 deletions GigiCompilerLib/Backends/DX12/nodes/node_action_rayShader.inl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,47 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os
};
std::vector<ShaderExport> shaderExports;

// gather all the shaders involved
for (const Shader& shader : renderGraph.shaders)
{
if (shader.type != ShaderType::RTClosestHit && shader.type != ShaderType::RTMiss &&
shader.type != ShaderType::RTAnyHit && shader.type != ShaderType::RTIntersection)
continue;

// If this shader is referenced as RTMissIndex, include it
bool includeShader = false;
for (const std::string& RTMissShader : node.shader.shader->Used_RTMissIndex)
{
if (!_stricmp(shader.name.c_str(), RTMissShader.c_str()))
{
includeShader = true;
break;
}
}

// If this shader is part of an RTHitGroupIndex, include it
if (!includeShader)
{
for (const std::string& RTHitGroupName : node.shader.shader->Used_RTHitGroupIndex)
{
int hitGroupIndex = GetHitGroupIndex(renderGraph, RTHitGroupName.c_str());
if (hitGroupIndex >= 0)
{
const RTHitGroup& hitGroup = renderGraph.hitGroups[hitGroupIndex];
if (!_stricmp(shader.name.c_str(), hitGroup.closestHit.name.c_str()) ||
!_stricmp(shader.name.c_str(), hitGroup.anyHit.name.c_str()) ||
!_stricmp(shader.name.c_str(), hitGroup.intersection.name.c_str()))
{
includeShader = true;
break;
}
}
}
}

if (!includeShader)
continue;

ShaderExport newExport;
newExport.shader = &shader;
newExport.fileName = shader.destFileName;
Expand Down Expand Up @@ -175,7 +210,7 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os

stringReplacementMap["/*$(CreateShared)*/"] <<
"\n { \"MAX_RECURSION_DEPTH\", \"" << node.maxRecursionDepth << "\" },"
"\n { \"RT_HIT_GROUP_COUNT\", \"" << renderGraph.hitGroups.size() << "\" },"
"\n { \"RT_HIT_GROUP_COUNT\", \"" << node.shader.shader->Used_RTHitGroupIndex.size() << "\" },"
"\n { nullptr, nullptr }"
"\n };";
}
Expand Down Expand Up @@ -228,22 +263,27 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os
}

// Hit Groups
if (node.shader.shader->Used_RTHitGroupIndex.size() > 0)
{
stateObjectCreation <<
"\n"
"\n // Make the hit group sub objects"
"\n D3D12_HIT_GROUP_DESC hitGroupDescs[" << renderGraph.hitGroups.size() << "];"
"\n D3D12_HIT_GROUP_DESC hitGroupDescs[" << node.shader.shader->Used_RTHitGroupIndex.size() << "];"
;

for (size_t i = 0; i < renderGraph.hitGroups.size(); ++i)
for (size_t hgi = 0; hgi < node.shader.shader->Used_RTHitGroupIndex.size(); ++hgi)
{
const RTHitGroup& hitGroup = renderGraph.hitGroups[i];
int hitGroupIndex = GetHitGroupIndex(renderGraph, node.shader.shader->Used_RTHitGroupIndex[hgi].c_str());
if (hitGroupIndex < 0)
continue;

const RTHitGroup& hitGroup = renderGraph.hitGroups[hitGroupIndex];
stateObjectCreation <<
"\n"
"\n // Hit group: " << hitGroup.name <<
"\n {"
"\n D3D12_HIT_GROUP_DESC& hitGroupDesc = hitGroupDescs[" << i << "];"
"\n hitGroupDesc.HitGroupExport = L\"hitgroup" << i << "\";"
"\n D3D12_HIT_GROUP_DESC& hitGroupDesc = hitGroupDescs[" << hgi << "];"
"\n hitGroupDesc.HitGroupExport = L\"hitgroup" << hgi << "\";"
;

// Any hit Shader Import
Expand Down Expand Up @@ -393,13 +433,13 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os

// Get the count for the number of items in each shader table
int shaderTableMissCount = 0;
for (const Shader& shader : renderGraph.shaders)
for (const ShaderExport& shaderExport : shaderExports)
{
if (shader.type == ShaderType::RTMiss)
if (shaderExport.shaderType == ShaderType::RTMiss)
shaderTableMissCount++;
}
int shaderTableRaygenCount = 1;
int shaderTableHitGroupCount = (int)renderGraph.hitGroups.size();
int shaderTableHitGroupCount = (int)node.shader.shader->Used_RTHitGroupIndex.size(); // How many hit groups used by this shader

// Ray Gen Shader Table
{
Expand Down Expand Up @@ -435,6 +475,7 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os
}

// Miss Shader Table
if (node.shader.shader->Used_RTMissIndex.size() > 0)
{
stringReplacementMap["/*$(CreateShared)*/"] <<
"\n"
Expand Down Expand Up @@ -468,6 +509,7 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os
}

// Hit Group Table
if (node.shader.shader->Used_RTHitGroupIndex.size() > 0)
{
stringReplacementMap["/*$(CreateShared)*/"] <<
"\n"
Expand All @@ -482,7 +524,7 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os
"\n"
;

for (size_t index = 0; index < renderGraph.hitGroups.size(); ++index)
for (int index = 0; index < node.shader.shader->Used_RTHitGroupIndex.size(); ++index)
{
stringReplacementMap["/*$(CreateShared)*/"] <<
"\n memcpy(shaderTableBytes, soprops->GetShaderIdentifier(L\"hitgroup" << index << "\"), D3D12_RAYTRACING_SHADER_RECORD_BYTE_ALIGNMENT);"
Expand Down Expand Up @@ -778,14 +820,18 @@ static void MakeStringReplacementForNode(std::unordered_map<std::string, std::os
"\n dispatchDesc.Depth = ((baseDispatchSize[2] + " << node.dispatchSize.preAdd[2] << ") * " << node.dispatchSize.multiply[2] << ") / " <<
node.dispatchSize.divide[2] << " + " << node.dispatchSize.postAdd[2] << ";";

//if (runtimeData.m_shaderTableMiss)

// write out the table addresses and size
stringReplacementMap["/*$(Execute)*/"] <<
"\n dispatchDesc.RayGenerationShaderRecord.StartAddress = ContextInternal::rayShader_" << node.name << "_shaderTableRayGen->GetGPUVirtualAddress();"
"\n dispatchDesc.RayGenerationShaderRecord.SizeInBytes = ContextInternal::rayShader_" << node.name << "_shaderTableRayGenSize;"
"\n dispatchDesc.MissShaderTable.StartAddress = ContextInternal::rayShader_" << node.name << "_shaderTableMiss->GetGPUVirtualAddress();"
"\n if (ContextInternal::rayShader_" << node.name << "_shaderTableMiss)"
"\n dispatchDesc.MissShaderTable.StartAddress = ContextInternal::rayShader_" << node.name << "_shaderTableMiss->GetGPUVirtualAddress();"
"\n dispatchDesc.MissShaderTable.SizeInBytes = ContextInternal::rayShader_" << node.name << "_shaderTableMissSize;"
"\n dispatchDesc.MissShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;"
"\n dispatchDesc.HitGroupTable.StartAddress = ContextInternal::rayShader_" << node.name << "_shaderTableHitGroup->GetGPUVirtualAddress();"
"\n if (ContextInternal::rayShader_" << node.name << "_shaderTableHitGroup)"
"\n dispatchDesc.HitGroupTable.StartAddress = ContextInternal::rayShader_" << node.name << "_shaderTableHitGroup->GetGPUVirtualAddress();"
"\n dispatchDesc.HitGroupTable.SizeInBytes = ContextInternal::rayShader_" << node.name << "_shaderTableHitGroupSize;"
"\n dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;"
"\n"
Expand Down
Loading

0 comments on commit ff062f2

Please sign in to comment.