diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index ae84e0b73..55e5ad350 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -98,6 +98,10 @@ template class DynamicGraph { m_currentId = -1; } + /// Check if currently processing a node. + /// \returns True if currently processing a node, false otherwise. + bool isProcessingNode() { return m_currentId != -1; } + /// Get the nodes in the graph. std::vector getNodes() { return m_nodes; } diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index c667ddd2c..062afcc70 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -407,12 +407,17 @@ namespace clad { Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); Sema::LocalEagerInstantiationScope LocalInstantiations(S); - DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); - while (request.Function != nullptr) { - m_DiffRequestGraph.setCurrentProcessingNode(request); - ProcessDiffRequest(request); - m_DiffRequestGraph.markCurrentNodeProcessed(); - request = m_DiffRequestGraph.getNextToProcessNode(); + if (!m_DiffRequestGraph.isProcessingNode()) { + // This check is to avoid recursive processing of the graph, as + // HandleTopLevelDecl can be called recursively in non-standard + // setup for code generation. + DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); + while (request.Function != nullptr) { + m_DiffRequestGraph.setCurrentProcessingNode(request); + ProcessDiffRequest(request); + m_DiffRequestGraph.markCurrentNodeProcessed(); + request = m_DiffRequestGraph.getNextToProcessNode(); + } } // Put the TUScope in a consistent state after clad is done.