Skip to content

Commit

Permalink
Fix recursive processing of DiffRequests
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed May 19, 2024
1 parent 13e4c13 commit b6f8f15
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
4 changes: 4 additions & 0 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ template <typename T> class DynamicGraph {
m_currentId = -1;
}

/// Check if currently processing a node.
/// \returns True if currently processing a node, false otherwise.
bool isProcessingNode() { return m_currentId != -1; }

/// Get the nodes in the graph.
std::vector<T> getNodes() { return m_nodes; }

Expand Down
17 changes: 11 additions & 6 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,17 @@ namespace clad {
Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled);
Sema::LocalEagerInstantiationScope LocalInstantiations(S);

DiffRequest request = m_DiffRequestGraph.getNextToProcessNode();
while (request.Function != nullptr) {
m_DiffRequestGraph.setCurrentProcessingNode(request);
ProcessDiffRequest(request);
m_DiffRequestGraph.markCurrentNodeProcessed();
request = m_DiffRequestGraph.getNextToProcessNode();
if (!m_DiffRequestGraph.isProcessingNode()) {
// This check is to avoid recursive processing of the graph, as
// HandleTopLevelDecl can be called recursively in non-standard
// setup for code generation.
DiffRequest request = m_DiffRequestGraph.getNextToProcessNode();
while (request.Function != nullptr) {
m_DiffRequestGraph.setCurrentProcessingNode(request);
ProcessDiffRequest(request);
m_DiffRequestGraph.markCurrentNodeProcessed();
request = m_DiffRequestGraph.getNextToProcessNode();
}
}

// Put the TUScope in a consistent state after clad is done.
Expand Down

0 comments on commit b6f8f15

Please sign in to comment.