Skip to content

Commit

Permalink
Change differentiation schedule for forward mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 1, 2024
1 parent 5cb3c18 commit ed695ad
Show file tree
Hide file tree
Showing 19 changed files with 749 additions and 545 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace clad {
class CladPlugin;
clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P,
DiffRequest& request);
void AddRequestToSchedule(CladPlugin& P, const DiffRequest& request);
} // namespace plugin

} // namespace clad
Expand Down
133 changes: 71 additions & 62 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,68 +20,77 @@ namespace clang {

namespace clad {

/// A struct containing information about request to differentiate a function.
struct DiffRequest {
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
/// Name of the base function to be differentiated. Can be different from
/// function->getNameAsString() when higher-order derivatives are computed.
std::string BaseFunctionName = {};
/// Current derivative order to be computed.
unsigned CurrentDerivativeOrder = 1;
/// Highest requested derivative order.
unsigned RequestedDerivativeOrder = 1;
/// Context in which the function is being called, or a call to
/// clad::gradient/differentiate, where function is the first arg.
clang::CallExpr* CallContext = nullptr;
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
/// the call must be updated and the first arg replaced by the derivative.
bool CallUpdateRequired = false;
/// A flag to enable/disable diag warnings/errors during differentiation.
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
/// Functor type to be differentiated, if any.
///
/// It is required because we cannot always determine if we are
/// differentiating a call operator using the function to be
/// differentiated, for example, when we are computing higher
/// order derivatives.
const clang::CXXRecordDecl* Functor = nullptr;

/// Stores differentiation parameters information. Stored information
/// includes info on indices range for array parameters, and nested data
/// member information for record (class) type parameters.
DiffInputVarsInfo DVI;

// A flag to enable the use of enzyme for backend instead of clad
bool use_enzyme = false;

/// Recomputes `DiffInputVarsInfo` using the current values of data members.
///
/// Differentiation parameters info is computed by parsing the argument
/// expression for the clad differentiation function calls. The argument is
/// used to specify independent parameter(s) for differentiation. There are
/// three valid options for the argument expression:
/// 1) A string literal, containing comma-separated names of function's
/// parameters, as defined in function's definition. If any of the
/// parameters are of array or pointer type the indexes of the array
/// that needs to be differentiated can also be specified, e.g.
/// "arr[1]" or "arr[2:5]". The function will be differentiated w.r.t.
/// all the specified parameters.
/// 2) A numeric literal. The function will be differentiated w.r.t. to
/// the parameter corresponding to literal's value index.
/// 3) If no argument is provided, a default argument is used. The
/// function will be differentiated w.r.t. to its every parameter.
void UpdateDiffParamsInfo(clang::Sema& semaRef);
};
/// A struct containing information about request to differentiate a function.
struct DiffRequest {
/// Function to be differentiated.
const clang::FunctionDecl* Function = nullptr;
/// Name of the base function to be differentiated. Can be different from
/// function->getNameAsString() when higher-order derivatives are computed.
std::string BaseFunctionName = {};
/// Current derivative order to be computed.
unsigned CurrentDerivativeOrder = 1;
/// Highest requested derivative order.
unsigned RequestedDerivativeOrder = 1;
/// Context in which the function is being called, or a call to
/// clad::gradient/differentiate, where function is the first arg.
clang::CallExpr* CallContext = nullptr;
/// Args provided to the call to clad::gradient/differentiate.
const clang::Expr* Args = nullptr;
/// Requested differentiation mode, forward or reverse.
DiffMode Mode = DiffMode::unknown;
/// If function appears in the call to clad::gradient/differentiate,
/// the call must be updated and the first arg replaced by the derivative.
bool CallUpdateRequired = false;
/// A flag to enable/disable diag warnings/errors during differentiation.
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
/// Functor type to be differentiated, if any.
///
/// It is required because we cannot always determine if we are
/// differentiating a call operator using the function to be
/// differentiated, for example, when we are computing higher
/// order derivatives.
const clang::CXXRecordDecl* Functor = nullptr;

/// Stores differentiation parameters information. Stored information
/// includes info on indices range for array parameters, and nested data
/// member information for record (class) type parameters.
DiffInputVarsInfo DVI;

// A flag to enable the use of enzyme for backend instead of clad
bool use_enzyme = false;

/// A pointer to keep track of the prototype of the derived function.
/// This will be particularly useful for pushforward and pullback functions.
clang::FunctionDecl* DerivedFDPrototype = nullptr;

/// A boolean to indicate if only the declaration of the derived function
/// is required (and not the definition or body).
/// This will be particularly useful for pushforward and pullback functions.
bool DeclarationOnly = false;

/// Recomputes `DiffInputVarsInfo` using the current values of data members.
///
/// Differentiation parameters info is computed by parsing the argument
/// expression for the clad differentiation function calls. The argument is
/// used to specify independent parameter(s) for differentiation. There are
/// three valid options for the argument expression:
/// 1) A string literal, containing comma-separated names of function's
/// parameters, as defined in function's definition. If any of the
/// parameters are of array or pointer type the indexes of the array
/// that needs to be differentiated can also be specified, e.g.
/// "arr[1]" or "arr[2:5]". The function will be differentiated w.r.t.
/// all the specified parameters.
/// 2) A numeric literal. The function will be differentiated w.r.t. to
/// the parameter corresponding to literal's value index.
/// 3) If no argument is provided, a default argument is used. The
/// function will be differentiated w.r.t. to its every parameter.
void UpdateDiffParamsInfo(clang::Sema& semaRef);
};

using DiffSchedule = llvm::SmallVector<DiffRequest, 16>;
using DiffInterval = std::vector<clang::SourceRange>;
Expand Down
31 changes: 23 additions & 8 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,17 +492,23 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();
if (!request.DeclarationOnly) {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();

// execute the functor inside the function body.
ExecuteInsidePushforwardFunctionBlock();
// execute the functor inside the function body.
ExecuteInsidePushforwardFunctionBlock();

Stmt* derivativeBody = endBlock();
m_Derivative->setBody(derivativeBody);
Stmt* derivativeBody = endBlock();
m_Derivative->setBody(derivativeBody);

endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
}

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope
Expand Down Expand Up @@ -1136,9 +1142,18 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// pushforwardFnRequest.RequestedDerivativeOrder = m_DerivativeOrder;
// Silence diag outputs in nested derivation process.
pushforwardFnRequest.VerboseDiags = false;

// Derive declaration of the pushforward function.
pushforwardFnRequest.DeclarationOnly = true;
FunctionDecl* pushforwardFD =
plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest);

// Add the request to derive the definition of the pushforward function
// into the queue.
pushforwardFnRequest.DeclarationOnly = false;
pushforwardFnRequest.DerivedFDPrototype = pushforwardFD;
plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest);

if (pushforwardFD) {
if (baseDiff.getExpr()) {
callDiff =
Expand Down
8 changes: 5 additions & 3 deletions test/FirstDerivative/CallArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ float f_const_helper(const float x) {
return x * x;
}

// CHECK: clad::ValueAndPushforward<float, float> f_const_helper_pushforward(const float x, const float _d_x) {
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }
// CHECK: clad::ValueAndPushforward<float, float> f_const_helper_pushforward(const float x, const float _d_x);

float f_const_args_func_7(const float x, const float y) {
return f_const_helper(x) + f_const_helper(y) - y;
Expand Down Expand Up @@ -168,5 +166,9 @@ int main () { // expected-no-diagnostics
printf("f8_darg0=%f\n", f8.execute(f8x,2.F));
//CHECK-EXEC: f8_darg0=2.000000

// CHECK: clad::ValueAndPushforward<float, float> f_const_helper_pushforward(const float x, const float _d_x) {
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }

return 0;
}
16 changes: 10 additions & 6 deletions test/FirstDerivative/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ double test_7(double i, double j) {
return res;
}

// CHECK: void increment_pushforward(int &i, int &_d_i) {
// CHECK-NEXT: ++i;
// CHECK-NEXT: }
// CHECK: void increment_pushforward(int &i, int &_d_i);

// CHECK: double test_7_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
Expand All @@ -154,9 +152,7 @@ double test_8(double x) {
return func_with_enum(x, e);
}

// CHECK: clad::ValueAndPushforward<double, double> func_with_enum_pushforward(double x, E e, double _d_x) {
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }
// CHECK: clad::ValueAndPushforward<double, double> func_with_enum_pushforward(double x, E e, double _d_x);

// CHECK: double test_8_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
Expand All @@ -178,4 +174,12 @@ int main () {
clad::differentiate<clad::opts::enable_tbr>(test_8); // expected-error {{TBR analysis is not meant for forward mode AD.}}
clad::differentiate<clad::opts::enable_tbr, clad::opts::disable_tbr>(test_8); // expected-error {{Both enable and disable TBR options are specified.}}
return 0;

// CHECK: void increment_pushforward(int &i, int &_d_i) {
// CHECK-NEXT: ++i;
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> func_with_enum_pushforward(double x, E e, double _d_x) {
// CHECK-NEXT: return {x * x, _d_x * x + x * _d_x};
// CHECK-NEXT: }
}
Loading

0 comments on commit ed695ad

Please sign in to comment.