
Commit d55aae9
Pass the final error argument by pointer rather than by reference and start using overloads in error estimation
PetroZarytskyi committed Apr 30, 2024
1 parent 0889e54 commit d55aae9
Showing 13 changed files with 192 additions and 186 deletions.
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/FloatSum.cpp
@@ -106,7 +106,7 @@ int main() {
finalError = 0;
unsigned int dn = 0;
// First execute the derived function.
df.execute(x, n, &ret[0], &dn, finalError);
df.execute(x, n, &ret[0], &dn, &finalError);

double kahanResult = kahanSum(x, n);
double vanillaResult = vanillaSum(x, n);
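For orientation, a minimal usage sketch of the calling convention after this change; the summation function, names, and values below are illustrative rather than taken from the demo, and the only point is that the final error is now passed by address:

#include "clad/Differentiator/Differentiator.h"

// Illustrative function to differentiate; not part of this commit.
float mySum(float* x, unsigned n) {
  float s = 0;
  for (unsigned i = 0; i < n; i++)
    s += x[i];
  return s;
}

int main() {
  auto df = clad::estimate_error(mySum);
  float x[3] = {1.f, 2.f, 3.f};
  float dx[3] = {0.f, 0.f, 0.f};
  unsigned dn = 0;
  double finalError = 0;
  // The final error argument is now passed by pointer, not by reference.
  df.execute(x, 3u, dx, &dn, &finalError);
}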
2 changes: 1 addition & 1 deletion demos/ErrorEstimation/PrintModel/test.cpp
@@ -29,5 +29,5 @@ int main() {
// Calculate the error
float dx, dy;
double error;
df.execute(2, 3, &dx, &dy, error);
df.execute(2, 3, &dx, &dy, &error);
}
10 changes: 5 additions & 5 deletions include/clad/Differentiator/Differentiator.h
@@ -494,14 +494,14 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {

template <typename ArgSpec = const char*, typename F,
typename DerivedFnType = GradientDerivedEstFnTraits_t<F>>
CladFunction<DerivedFnType> __attribute__((annotate("E")))
CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true> __attribute__((
annotate("E")))
estimate_error(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<
DerivedFnType>(derivedFn /* will be replaced by estimation code*/,
code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by estimation code*/, code);
}

// Gradient Structure for Reverse Mode Enzyme
7 changes: 3 additions & 4 deletions include/clad/Differentiator/FunctionTraits.h
@@ -482,8 +482,7 @@ namespace clad {
// GradientDerivedEstFnTraits specializations for pure function pointer types
template <class ReturnType, class... Args>
struct GradientDerivedEstFnTraits<ReturnType (*)(Args...)> {
using type = void (*)(Args..., OutputParamType_t<Args, Args>...,
double&);
using type = void (*)(Args..., OutputParamType_t<Args, void>..., void*);
};

/// These macro expansions are used to cover all possible cases of
@@ -498,8 +497,8 @@ namespace clad {
#define GradientDerivedEstFnTraits_AddSPECS(var, cv, vol, ref, noex) \
template <typename R, typename C, typename... Args> \
struct GradientDerivedEstFnTraits<R (C::*)(Args...) cv vol ref noex> { \
using type = void (C::*)(Args..., OutputParamType_t<Args, Args>..., \
double&) cv vol ref noex; \
using type = void (C::*)(Args..., OutputParamType_t<Args, void>..., \
void*) cv vol ref noex; \
};

#if __cpp_noexcept_function_type > 0
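As a rough illustration of the effect on the derived-function type (a hand-written expansion for a hypothetical signature; the OutputParamType_t part is left abstract):

// Hypothetical primal function pointer type.
using PrimalFn = float (*)(float, float);

// Before this commit the estimation-gradient type ended in a reference:
//   void (*)(float, float, /* output params..., */ double&);
// After it, the final-error slot is an opaque pointer, matching the shape
// of the other output parameters so a generic overload can forward it:
//   void (*)(float, float, /* output params..., */ void*);
using DerivedFn = clad::GradientDerivedEstFnTraits_t<PrimalFn>;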
4 changes: 3 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
@@ -436,7 +436,9 @@ namespace clad {
/// Builds an overload for the gradient function that has derived params for
/// all the arguments of the requested function and it calls the original
/// gradient function internally
clang::FunctionDecl* CreateGradientOverload();
/// \param[in] numExtraParam The number of extra parameters requested by an
/// external source (e.g. the final error in error estimation).
clang::FunctionDecl* CreateGradientOverload(unsigned numExtraParam = 0);

/// Returns the type that should be used to represent the derivative of a
/// variable of type `yType` with respect to a parameter variable of type
12 changes: 7 additions & 5 deletions lib/Differentiator/ErrorEstimator.cpp
@@ -285,7 +285,7 @@ void ErrorEstimationHandler::ActAfterCreatingDerivedFnParamTypes(
// If we are performing error estimation, our gradient function
// will have an extra argument which will hold the final error value
paramTypes.push_back(
m_RMV->m_Context.getLValueReferenceType(m_RMV->m_Context.DoubleTy));
m_RMV->m_Context.getPointerType(m_RMV->m_Context.DoubleTy));
}

void ErrorEstimationHandler::ActAfterCreatingDerivedFnParams(
@@ -307,7 +307,8 @@ void ErrorEstimationHandler::ActAfterCreatingDerivedFnParams(

void ErrorEstimationHandler::ActBeforeCreatingDerivedFnBodyScope() {
// Reference to the final error statement
SetFinalErrorExpr(m_RMV->BuildDeclRef(m_Params->back()));
DeclRefExpr* DRE = m_RMV->BuildDeclRef(m_Params->back());
SetFinalErrorExpr(m_RMV->BuildOp(UO_Deref, DRE));
}

void ErrorEstimationHandler::ActOnEndOfDerivedFnBody() {
@@ -468,12 +469,13 @@ void ErrorEstimationHandler::ActBeforeFinalizingDifferentiateSingleExpr(
void ErrorEstimationHandler::ActBeforeDifferentiatingCallExpr(
llvm::SmallVectorImpl<clang::Expr*>& pullbackArgs,
llvm::SmallVectorImpl<Stmt*>& ArgDecls, bool hasAssignee) {
auto errorRef =
VarDecl* errorRef =
m_RMV->BuildVarDecl(m_RMV->m_Context.DoubleTy, "_t",
m_RMV->getZeroInit(m_RMV->m_Context.DoubleTy));
ArgDecls.push_back(m_RMV->BuildDeclStmt(errorRef));
auto finErr = m_RMV->BuildDeclRef(errorRef);
pullbackArgs.push_back(finErr);
Expr* finErr = m_RMV->BuildDeclRef(errorRef);
Expr* arg = m_RMV->BuildOp(UO_AddrOf, finErr);
pullbackArgs.push_back(arg);
if (hasAssignee) {
if (m_NestedFuncError)
m_NestedFuncError = m_RMV->BuildOp(BO_Add, m_NestedFuncError, finErr);
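Taken together, these handler changes give the generated estimation gradients roughly the following shape (a hand-written sketch; the names, the nested pullback call, and the epsilon factor are illustrative):

#include <cmath>

// Sketch of a generated estimation gradient after this change.
void f_grad(float x, float* _d_x, double* _final_error) {
  const double eps = 1e-7; // stand-in for the factor clad actually emits
  // A nested call gets a local accumulator whose address is passed to the
  // callee's pullback (ActBeforeDifferentiatingCallExpr).
  double _t0 = 0;
  // g_pullback(x, 1.F, _d_x, &_t0);  // hypothetical nested pullback call
  // The final error is updated through a dereference of the pointer
  // parameter, which SetFinalErrorExpr now wraps in UO_Deref.
  *_final_error += std::abs(*_d_x * x * eps);
  *_final_error += _t0; // nested error folded in (sketched as a direct add)
}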
21 changes: 12 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
@@ -112,16 +112,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

FunctionDecl* ReverseModeVisitor::CreateGradientOverload() {
FunctionDecl*
ReverseModeVisitor::CreateGradientOverload(unsigned numExtraParam) {
auto gradientParams = m_Derivative->parameters();
auto gradientNameInfo = m_Derivative->getNameInfo();
// Calculate the total number of parameters that would be required for
// automatic differentiation in the derived function if all args are
// requested.
// FIXME: Here we are assuming all function parameters are of differentiable
// type. Ideally, we should not make any such assumption.
std::size_t totalDerivedParamsSize = m_Function->getNumParams() * 2;
std::size_t numOfDerivativeParams = m_Function->getNumParams();
std::size_t totalDerivedParamsSize =
m_Function->getNumParams() * 2 + numExtraParam;
std::size_t numOfDerivativeParams =
m_Function->getNumParams() + numExtraParam;

// Account for the this pointer.
if (isa<CXXMethodDecl>(m_Function) && !utils::IsStaticMethod(m_Function))
@@ -273,7 +276,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
if (args.empty())
// If there are no parameters to differentiate with respect to, don't
// generate the gradient. However, if an external source is attached, the
// gradient function can serve another purpose.
if (args.empty() && !m_ExternalSource)
return {};

if (m_ExternalSource)
@@ -336,9 +342,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If reverse mode differentiates only part of the arguments it needs to
// generate an overload that can take in all the diff variables
bool shouldCreateOverload = false;
// FIXME: Gradient overload doesn't know how to handle additional parameters
// added by the plugins yet.
if (!isVectorValued && numExtraParam == 0)
if (!isVectorValued)
shouldCreateOverload = true;
if (request.DerivedFDPrototype)
// If the overload is already created, we don't need to create it again.
@@ -452,8 +456,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

FunctionDecl* gradientOverloadFD = nullptr;
if (shouldCreateOverload) {
gradientOverloadFD =
CreateGradientOverload();
gradientOverloadFD = CreateGradientOverload(numExtraParam);
}

return DerivativeAndOverload{result.first, gradientOverloadFD};
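The practical outcome is that error-estimation gradients now also receive a forwarding overload. A rough sketch of the pair clad generates (hand-written; parameter names and cast style are illustrative):

// The fully typed gradient generated for the user's function.
void func_grad(float x, float y, float* _d_x, float* _d_y,
               double* _final_error);

// The generic overload now accounts for the extra final-error parameter as
// well (numExtraParam); it casts the opaque pointers and forwards.
void func_grad(float x, float y, void* _d_x, void* _d_y,
               void* _final_error) {
  func_grad(x, y, static_cast<float*>(_d_x), static_cast<float*>(_d_y),
            static_cast<double*>(_final_error));
}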
46 changes: 23 additions & 23 deletions test/ErrorEstimation/Assignments.C
@@ -12,7 +12,7 @@ float func(float x, float y) {
return y;
}

//CHECK: void func_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK: void func_grad(float x, float y, float *_d_x, float *_d_y, double *_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: _t0 = x;
@@ -29,47 +29,47 @@ float func(float x, float y) {
//CHECK-NEXT: *_d_x += _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x += _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: _final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: }

float func2(float x, int y) {
x = y * x + x * x;
return x;
}

//CHECK: void func2_grad_0(float x, int y, float *_d_x, double &_final_error) {
//CHECK: void func2_grad_0(float x, int y, float *_d_x, double *_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y * x + x * x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_x += y * _r_d0;
//CHECK-NEXT: *_d_x += _r_d0 * x;
//CHECK-NEXT: *_d_x += x * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: }

float func3(int x, int y) {
x = y;
return y;
}

//CHECK: void func3_grad(int x, int y, double &_final_error) {
//CHECK: void func3_grad(int x, int y, double *_final_error) {
//CHECK-NEXT: int _t0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
@@ -85,7 +85,7 @@ float func4(float x, float y) {
return x;
}

//CHECK: void func4_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK: void func4_grad(float x, float y, float *_d_x, float *_d_y, double *_final_error) {
//CHECK-NEXT: double _d_z = 0;
//CHECK-NEXT: float _t0;
//CHECK-NEXT: double z = y;
@@ -95,16 +95,16 @@ float func4(float x, float y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: _d_z += _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: *_d_y += _d_z;
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: _final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: }

float func5(float x, float y) {
@@ -113,7 +113,7 @@ float func5(float x, float y) {
return x;
}

//CHECK: void func5_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK: void func5_grad(float x, float y, float *_d_x, float *_d_y, double *_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: int z = 56;
//CHECK-NEXT: _t0 = x;
@@ -122,28 +122,28 @@ float func5(float x, float y) {
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = *_d_x;
//CHECK-NEXT: *_d_x -= _r_d0;
//CHECK-NEXT: *_d_y += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: _final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: }

float func6(float x) { return x; }

//CHECK: void func6_grad(float x, float *_d_x, double &_final_error) {
//CHECK: void func6_grad(float x, float *_d_x, double *_final_error) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: }

float func7(float x, float y) { return (x * y); }

//CHECK: void func7_grad(float x, float y, float *_d_x, float *_d_y, double &_final_error) {
//CHECK: void func7_grad(float x, float y, float *_d_x, float *_d_y, double *_final_error) {
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: _ret_value0 = (x * y);
//CHECK-NEXT: goto _label0;
@@ -152,17 +152,17 @@ float func7(float x, float y) { return (x * y); }
//CHECK-NEXT: *_d_x += 1 * y;
//CHECK-NEXT: *_d_y += x * 1;
//CHECK-NEXT: }
//CHECK-NEXT: _final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: _final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: _final_error += std::abs(1. * _ret_value0 * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_x * x * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(*_d_y * y * {{.+}});
//CHECK-NEXT: *_final_error += std::abs(1. * _ret_value0 * {{.+}});
//CHECK-NEXT: }

float func8(int x, int y) {
x = y * y;
return x;
}

//CHECK: void func8_grad(int x, int y, double &_final_error) {
//CHECK: void func8_grad(int x, int y, double *_final_error) {
//CHECK-NEXT: int _t0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y * y;