Skip to content

Commit

Permalink
Compute and process Differentiation Request graph
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 26, 2024
1 parent d879f1b commit 22ec2d9
Show file tree
Hide file tree
Showing 12 changed files with 393 additions and 67 deletions.
17 changes: 13 additions & 4 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ namespace clad {
class CladPlugin;
clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P,
DiffRequest& request);
// FIXME: This function should be removed and the entire plans array
// should be somehow made accessible to all the visitors.
void AddRequestToSchedule(CladPlugin& P, const DiffRequest& request);
} // namespace plugin

} // namespace clad
Expand Down Expand Up @@ -87,6 +84,7 @@ namespace clad {
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
const DerivedFnCollector& m_DFC;
clad::DynamicGraph<DiffRequest>& m_DiffRequestGraph;
std::unique_ptr<utils::StmtClone> m_NodeCloner;
clang::NamespaceDecl* m_BuiltinDerivativesNSD;
/// A reference to the model to use for error estimation (if any).
Expand Down Expand Up @@ -137,7 +135,8 @@ namespace clad {

public:
DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC);
const DerivedFnCollector& DFC,
clad::DynamicGraph<DiffRequest>& DRG);
~DerivativeBuilder();
/// Reset the model use for error estimation (if any).
/// \param[in] estModel The error estimation model, can be either
Expand Down Expand Up @@ -172,6 +171,16 @@ namespace clad {
///
/// \returns The derived function if found, nullptr otherwise.
clang::FunctionDecl* FindDerivedFunction(const DiffRequest& request);
/// Add edge from current request to the given request in the DiffRequest
/// graph.
///
/// \param[in] request The request to add the edge to.
void AddEdgeToGraph(const DiffRequest& request);
/// Add edge between two requests in the DiffRequest graph.
///
/// \param[in] from The source request.
/// \param[in] to The destination request.
void AddEdgeToGraph(const DiffRequest& from, const DiffRequest& to);
};

} // end namespace clad
Expand Down
8 changes: 8 additions & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ enum class DiffMode {
reverse_mode_forward_pass,
error_estimation
};

/// Tells whether \p mode is one of the pullback/pushforward style modes.
///
/// \param[in] mode The differentiation mode to classify.
/// \returns true for experimental_pushforward, experimental_pullback,
/// experimental_vector_pushforward and reverse_mode_forward_pass; false for
/// every other mode.
inline bool IsPullbackOrPushforwardMode(DiffMode mode) {
  switch (mode) {
  case DiffMode::experimental_pushforward:
  case DiffMode::experimental_pullback:
  case DiffMode::experimental_vector_pushforward:
  case DiffMode::reverse_mode_forward_pass:
    return true;
  default:
    return false;
  }
}
}

#endif
63 changes: 57 additions & 6 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#ifndef CLAD_DIFF_PLANNER_H
#define CLAD_DIFF_PLANNER_H

#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "llvm/ADT/SmallSet.h"
#include "clad/Differentiator/DiffMode.h"
#include "clad/Differentiator/DynamicGraph.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"

namespace clang {
class ASTContext;
Expand Down Expand Up @@ -90,9 +91,47 @@ struct DiffRequest {
/// 3) If no argument is provided, a default argument is used. The
/// function will be differentiated w.r.t. to its every parameter.
void UpdateDiffParamsInfo(clang::Sema& semaRef);

/// Equality comparison for DiffRequest.
///
/// Two requests are equal when they target the same function and all of the
/// compared request fields match. Functions are treated as the same when the
/// pointers are identical, or when one is the previous declaration of the
/// other.
///
/// \param[in] other The request to compare against.
/// \returns true if both requests describe the same differentiation.
bool operator==(const DiffRequest& other) const {
  // Either the function matches or a previous declaration matches.
  const bool sameFunction = Function == other.Function ||
                            Function->getPreviousDecl() == other.Function ||
                            Function == other.Function->getPreviousDecl();
  if (!sameFunction)
    return false;

  // The remaining fields are compared in declaration-independent order and
  // short-circuit on the first mismatch.
  return BaseFunctionName == other.BaseFunctionName &&
         CurrentDerivativeOrder == other.CurrentDerivativeOrder &&
         RequestedDerivativeOrder == other.RequestedDerivativeOrder &&
         CallContext == other.CallContext && Args == other.Args &&
         Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis &&
         DVI == other.DVI && use_enzyme == other.use_enzyme &&
         DeclarationOnly == other.DeclarationOnly;
}

/// Conversion to a printable node label.
///
/// The label has the form "<BaseFunctionName>__order_<N>__mode_<mode>", where
/// <mode> is "forward", "reverse", "vector_forward_mode", "pushforward" or
/// "pullback". For any other mode nothing is appended after "__mode_".
operator std::string() const {
  const char* modeName = "";
  switch (Mode) {
  case DiffMode::forward:
    modeName = "forward";
    break;
  case DiffMode::reverse:
    modeName = "reverse";
    break;
  case DiffMode::vector_forward_mode:
    modeName = "vector_forward_mode";
    break;
  case DiffMode::experimental_pushforward:
    modeName = "pushforward";
    break;
  case DiffMode::experimental_pullback:
    modeName = "pullback";
    break;
  default:
    // Modes without a dedicated name keep the empty suffix.
    break;
  }

  std::string res = BaseFunctionName + "__order_" +
                    std::to_string(CurrentDerivativeOrder) + "__mode_";
  res += modeName;
  return res;
}
};

using DiffSchedule = llvm::SmallVector<DiffRequest, 16>;
using DiffInterval = std::vector<clang::SourceRange>;

struct RequestOptions {
Expand All @@ -106,9 +145,9 @@ struct DiffRequest {
///
DiffInterval& m_Interval;

/// The diff step-by-step plan for differentiation.
/// Graph to store the dependencies between different requests.
///
DiffSchedule& m_DiffPlans;
clad::DynamicGraph<DiffRequest>& m_DiffRequestGraph;

/// If set it means that we need to find the called functions and
/// add them for implicit diff.
Expand All @@ -120,12 +159,24 @@ struct DiffRequest {

public:
DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval,
DiffSchedule& plans, clang::Sema& S, RequestOptions& opts);
clad::DynamicGraph<DiffRequest>& requestGraph, clang::Sema& S,
RequestOptions& opts);
bool VisitCallExpr(clang::CallExpr* E);

private:
bool isInInterval(clang::SourceLocation Loc) const;
};
}

// Hash support for DiffRequest so that it can be used as a key in unordered
// containers.
template <> struct std::hash<clad::DiffRequest> {
  // Only the function pointer contributes to the hash; the other request
  // fields are left to the equality comparison. When the function has a
  // previous declaration, that declaration's address is hashed instead of the
  // function's own address.
  std::size_t operator()(const clad::DiffRequest& DR) const {
    const void* key = DR.Function;
    if (const auto* previous = DR.Function->getPreviousDecl())
      key = previous;
    return std::hash<const void*>{}(key);
  }
};

#endif
166 changes: 166 additions & 0 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#ifndef CLAD_DYNAMICGRAPH_H
#define CLAD_DYNAMICGRAPH_H

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace clad {
/// A directed graph that can grow while it is being processed.
///
/// Nodes are identified internally by their insertion index. Besides the
/// edges, the graph tracks which nodes are sources, which node is currently
/// being processed, which nodes are already processed, and a FIFO queue of
/// nodes that should be processed next.
///
/// \tparam T The node type. It must be copyable, default constructible,
/// usable as a key of std::unordered_map (std::hash<T> and operator==) and,
/// if print() is used, convertible to std::string.
template <typename T> class DynamicGraph {
private:
  // Storing nodes in the graph. The index of the node in the vector is used as
  // a unique identifier for the node in the adjacency list.
  std::vector<T> nodes;

  // Map from a node to a pair: the first element tells whether the node is
  // processed, the second element is the id of the node in the nodes vector.
  std::unordered_map<T, std::pair<bool, size_t>> nodeMap;

  // Adjacency list: map from a node id to the set of node ids it has an edge
  // to. Ids are stored instead of nodes to avoid copying the nodes.
  std::unordered_map<size_t, std::set<size_t>> adjList;

  // Reverse adjacency list: map from a node id to the set of node ids that
  // have an edge to it.
  std::unordered_map<size_t, std::set<size_t>> revAdjList;

  // Set of source nodes in the graph.
  std::set<size_t> sources;

  // Id of the node being processed right now; -1 means no node is being
  // processed.
  int currentId = -1;

  // Queue of nodes to be processed next.
  std::queue<size_t> toProcessQueue;

public:
  DynamicGraph() = default;

  /// Add an edge from \p src to \p dest, inserting either node if it is not
  /// part of the graph yet.
  void addEdge(const T& src, const T& dest) {
    const size_t srcId = addNode(src).second;
    const size_t destId = addNode(dest).second;
    adjList[srcId].insert(destId);
    revAdjList[destId].insert(srcId);
  }

  /// Add a node to the graph if it is not present yet.
  ///
  /// \param[in] node The node to add.
  /// \param[in] isSource If true and the node is newly inserted, the node is
  /// recorded as a source and queued for processing. The flag has no effect
  /// on a node that is already in the graph.
  /// \returns The {processed, id} information stored for the node.
  std::pair<bool, size_t> addNode(const T& node, bool isSource = false) {
    const auto existing = nodeMap.find(node);
    if (existing != nodeMap.end())
      return existing->second;

    const size_t id = nodes.size();
    const std::pair<bool, size_t> info(false, id); // not processed yet.
    nodes.push_back(node);
    nodeMap[node] = info;
    // Make sure every node has (possibly empty) adjacency entries.
    adjList[id] = {};
    revAdjList[id] = {};
    if (isSource) {
      sources.insert(id);
      toProcessQueue.push(id);
    }
    return info;
  }

  /// Add an edge from the node currently being processed to \p dest. Does
  /// nothing when no node is being processed.
  void addEdgeToCurrentNode(const T& dest) {
    if (currentId == -1)
      return;
    // The id of the current node is already known, so use it directly rather
    // than passing a reference into `nodes`, which may reallocate when `dest`
    // is inserted.
    const size_t srcId = static_cast<size_t>(currentId);
    const size_t destId = addNode(dest).second;
    adjList[srcId].insert(destId);
    revAdjList[destId].insert(srcId);
  }

  /// Mark \p node as the node currently being processed. If the node is not
  /// part of the graph, the current node is left unchanged.
  void setCurrentProcessingNode(const T& node) {
    const auto found = nodeMap.find(node);
    if (found != nodeMap.end())
      currentId = static_cast<int>(found->second.second);
  }

  /// Mark the current node as processed, queue all of its not yet processed
  /// successors for processing, and reset the current node.
  void markCurrentNodeProcessed() {
    if (currentId != -1) {
      const size_t current = static_cast<size_t>(currentId);
      nodeMap[nodes[current]].first = true;
      for (size_t destId : adjList[current])
        if (!nodeMap[nodes[destId]].first)
          toProcessQueue.push(destId);
    }
    currentId = -1;
  }

  /// Get a copy of the nodes in the graph, in insertion order.
  std::vector<T> getNodes() const { return nodes; }

  /// Check whether there is a direct edge from \p src to \p dest.
  ///
  /// \returns false if either node is not part of the graph or if no such
  /// edge exists.
  bool isConnected(const T& src, const T& dest) const {
    const auto srcIt = nodeMap.find(src);
    const auto destIt = nodeMap.find(dest);
    if (srcIt == nodeMap.end() || destIt == nodeMap.end())
      return false;
    const auto edges = adjList.find(srcIt->second.second);
    return edges != adjList.end() &&
           edges->second.count(destIt->second.second) != 0;
  }

  /// Print the graph to std::cout in a human-readable format: first the nodes
  /// in insertion order with their id and state, then the edges as
  /// "srcId -> destId".
  void print() const {
    // First print the nodes with their insertion order.
    for (size_t id = 0; id < nodes.size(); ++id) {
      const T& node = nodes[id];
      const auto info = nodeMap.find(node);
      const bool processed = info != nodeMap.end() && info->second.first;
      std::cout << static_cast<std::string>(node) << ": #" << id;
      if (sources.count(id) != 0)
        std::cout << " (source)";
      std::cout << (processed ? ", (done)\n" : ", (unprocessed)\n");
    }
    // Then print the edges.
    for (size_t id = 0; id < nodes.size(); ++id) {
      const auto edges = adjList.find(id);
      if (edges == adjList.end())
        continue;
      for (size_t dest : edges->second)
        std::cout << id << " -> " << dest << "\n";
    }
  }

  /// Topological sort of the directed graph, implemented with a recursive
  /// depth-first search helper that starts from the source nodes; nodes that
  /// are not reachable from a source are not part of the result. If the graph
  /// is not a DAG, the result will be a partial order.
  ///
  /// \param[in] reverseOrder If false, a->b implies that a comes before b in
  /// the result. If true, the result is in reverse topological order, i.e.
  /// a->b implies that b comes before a.
  std::vector<T> topologicalSort(bool reverseOrder = false) const {
    std::vector<T> res;
    std::unordered_set<size_t> visited;

    // Post-order traversal: a node is emitted after all of its successors.
    std::function<void(size_t)> dfs = [&](size_t node) -> void {
      visited.insert(node);
      const auto edges = adjList.find(node);
      if (edges != adjList.end())
        for (size_t dest : edges->second)
          if (visited.count(dest) == 0)
            dfs(dest);
      res.push_back(nodes[node]);
    };
    for (size_t source : sources)
      if (visited.count(source) == 0)
        dfs(source);

    // The post-order itself is the reverse topological order.
    if (!reverseOrder)
      std::reverse(res.begin(), res.end());
    return res;
  }

  /// Pop and return the next node from the queue of nodes to be processed.
  ///
  /// \returns A default-constructed T if the queue is empty.
  T getNextToProcessNode() {
    if (toProcessQueue.empty())
      return T();
    const size_t nextId = toProcessQueue.front();
    toProcessQueue.pop();
    return nodes[nextId];
  }
};
} // end namespace clad

#endif // CLAD_DYNAMICGRAPH_H
2 changes: 1 addition & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,8 +1168,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// into the queue.
pushforwardFnRequest.DeclarationOnly = false;
pushforwardFnRequest.DerivedFDPrototype = pushforwardFD;
plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest);
}
m_Builder.AddEdgeToGraph(std::move(pushforwardFnRequest));

if (pushforwardFD) {
if (baseDiff.getExpr()) {
Expand Down
37 changes: 29 additions & 8 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ using namespace clang;
namespace clad {

DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC)
const DerivedFnCollector& DFC,
clad::DynamicGraph<DiffRequest>& G)
: m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC),
m_DiffRequestGraph(G),
m_NodeCloner(new utils::StmtClone(m_Sema, m_Context)),
m_BuiltinDerivativesNSD(nullptr), m_NumericalDiffNSD(nullptr) {}

Expand Down Expand Up @@ -292,15 +294,25 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
assert(FD && "Must not be null.");
// If FD is only a declaration, try to find its definition.
if (!FD->getDefinition()) {
if (request.VerboseDiags)
diag(DiagnosticsEngine::Error,
request.CallContext ? request.CallContext->getBeginLoc() : noLoc,
"attempted differentiation of function '%0', which does not have a "
"definition", { FD->getNameAsString() });
return {};
// If only declaration is requested, allow this for non
// pullback/pushforward modes. For ex, this is required for Hessian -
// where we have forward mode followed by reverse mode, but we only need
// the declaration of the forward mode initially.
if (!request.DeclarationOnly ||
IsPullbackOrPushforwardMode(request.Mode)) {
if (request.VerboseDiags)
diag(DiagnosticsEngine::Error,
request.CallContext ? request.CallContext->getBeginLoc() : noLoc,
"attempted differentiation of function '%0', which does not "
"have a "
"definition",
{FD->getNameAsString()});
return {};
}
}

FD = FD->getDefinition();
if (!request.DeclarationOnly)
FD = FD->getDefinition();

// check if the function is non-differentiable.
if (clad::utils::hasNonDifferentiableAttribute(FD)) {
Expand Down Expand Up @@ -388,4 +400,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
return DFI.DerivedFn();
return nullptr;
}

// Adds an edge in the DiffRequest graph from the request that is currently
// being processed to `request`. This only forwards to
// DynamicGraph::addEdgeToCurrentNode, which does nothing when the graph has no
// current processing node set.
void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request) {
m_DiffRequestGraph.addEdgeToCurrentNode(request);
}

// Adds an edge `from` -> `to` in the DiffRequest graph. This only forwards to
// DynamicGraph::addEdge, which inserts either request as a node first if it is
// not part of the graph yet.
void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& from,
const DiffRequest& to) {
m_DiffRequestGraph.addEdge(from, to);
}
}// end namespace clad
Loading

0 comments on commit 22ec2d9

Please sign in to comment.