Skip to content

Commit

Permalink
Simplify handling of diff request options. NFC
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 16, 2024
1 parent cebc426 commit b7440b7
Showing 1 changed file with 167 additions and 142 deletions.
309 changes: 167 additions & 142 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,138 @@ namespace clad {
return found != m_ActivityRunInfo.ToBeRecorded.end();
}

///\returns true on error.
static bool ProcessInvocationArgs(Sema& S, SourceLocation endLoc,
const RequestOptions& ReqOpts,
const FunctionDecl* FD,
DiffRequest& request) {
const AnnotateAttr* A = FD->getAttr<AnnotateAttr>();
if (A->getAnnotation().equals("E")) {
// Error estimation has no options yet.
request.Mode = DiffMode::error_estimation;
return false;
}

if (A->getAnnotation().equals("D"))
request.Mode = DiffMode::forward;
else if (A->getAnnotation().equals("H"))
request.Mode = DiffMode::hessian;
else if (A->getAnnotation().equals("J"))
request.Mode = DiffMode::jacobian;
else if (A->getAnnotation().equals("G"))
request.Mode = DiffMode::reverse;
else {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc, "Unknown mode '%0'",
A->getAnnotation());
return true;

Check warning on line 680 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L678-L680

Added lines #L678 - L680 were not covered by tests
}

request.EnableTBRAnalysis = ReqOpts.EnableTBRAnalysis;
request.EnableVariedAnalysis = ReqOpts.EnableVariedAnalysis;

const TemplateArgumentList* TAL = FD->getTemplateSpecializationArgs();
if (!TAL)
return false; // We no extra configuration, we are done.

Check warning on line 688 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L688

Added line #L688 was not covered by tests

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
const auto template_arg = TAL->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg : TAL->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();

Check warning on line 698 in lib/Differentiator/DiffPlanner.cpp

View check run for this annotation

Codecov / codecov/patch

lib/Differentiator/DiffPlanner.cpp#L698

Added line #L698 was not covered by tests

bool enable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr);
bool disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
bool enable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_va);
bool disable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_va);

// Sanity checks.
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;
}
if (enable_va_in_req && disable_va_in_req) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Both enable and disable VA options are specified.");
return true;
}
if (enable_tbr_in_req && request.Mode == DiffMode::forward) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not meant for forward mode AD.");
return true;
}

// reverse vector mode is not yet supported.
if (request.Mode == DiffMode::reverse &&
clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Reverse vector mode is not yet supported.");
return true;
}

// Override the default value of TBR analysis.
if (enable_tbr_in_req || disable_tbr_in_req)
request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req;

// Override the default value of TBR analysis.
if (enable_va_in_req || disable_va_in_req)
request.EnableVariedAnalysis = enable_va_in_req && !disable_va_in_req;

// Check for clad::hessian<diagonal_only>.
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (request.Mode == DiffMode::hessian) {
request.Mode = DiffMode::hessian_diagonal;
return false;
}
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Diagonal only option is only valid for Hessian mode.");
return true;
}

if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;

if (request.Mode == DiffMode::forward) {
// Check for clad::differentiate<N>.
if (unsigned order = clad::GetDerivativeOrder(bitmasked_opts_value))
request.RequestedDerivativeOrder = order;

// Check for clad::differentiate<immediate_mode>.
if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode))
request.ImmediateMode = true;

// Check for clad::differentiate<vector_mode>.
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
request.Mode = DiffMode::vector_forward_mode;

// Currently only first order derivative is supported.
if (request.RequestedDerivativeOrder != 1) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Only first order derivative is supported for now "
"in vector forward mode.");
return true;
}

// We don't yet support enzyme with vector mode.
if (request.use_enzyme) {
utils::EmitDiag(S, DiagnosticsEngine::Error, endLoc,
"Enzyme's vector mode is not yet supported.");
return true;
}
}
}

return false;
}

bool DiffCollector::VisitCallExpr(CallExpr* E) {
// Check if we should look into this.
// FIXME: Generated code does not usually have valid source locations.
Expand All @@ -666,153 +798,46 @@ namespace clad {
FunctionDecl* FD = E->getDirectCallee();
if (!FD)
return true;

// We need to find our 'special' diff annotated such:
// clad::differentiate(...) __attribute__((annotate("D")))
// TODO: why not check for its name? clad::differentiate/gradient?
const AnnotateAttr* A = FD->getAttr<AnnotateAttr>();
if (A &&
(A->getAnnotation().equals("D") || A->getAnnotation().equals("G") ||
A->getAnnotation().equals("H") || A->getAnnotation().equals("J") ||
A->getAnnotation().equals("E"))) {
// A call to clad::differentiate or clad::gradient was found.
DeclRefExpr* DRE = getArgFunction(E, m_Sema);
if (!DRE)
return true;
DiffRequest request{};

// bitmask_opts is a template pack of unsigned integers, so we need to
// do bitwise or of all the values to get the final value.
unsigned bitmasked_opts_value = 0;
bool enable_tbr_in_req = false;
bool disable_tbr_in_req = false;
bool enable_va_in_req = false;
bool disable_va_in_req = false;
if (!A->getAnnotation().equals("E") &&
FD->getTemplateSpecializationArgs()) {
const auto template_arg = FD->getTemplateSpecializationArgs()->get(0);
if (template_arg.getKind() == TemplateArgument::Pack)
for (const auto& arg :
FD->getTemplateSpecializationArgs()->get(0).pack_elements())
bitmasked_opts_value |= arg.getAsIntegral().getExtValue();
else
bitmasked_opts_value = template_arg.getAsIntegral().getExtValue();

// Set option for TBR analysis.
enable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_tbr);
disable_tbr_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_tbr);
// Set option for Activity analysis.
enable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::enable_va);
disable_va_in_req =
clad::HasOption(bitmasked_opts_value, clad::opts::disable_va);
if (enable_tbr_in_req && disable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable TBR options are specified.");
return true;
}
if (enable_va_in_req && disable_va_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Both enable and disable VA options are specified.");
return true;
}
if (enable_tbr_in_req || disable_tbr_in_req) {
// override the default value of TBR analysis.
request.EnableTBRAnalysis = enable_tbr_in_req && !disable_tbr_in_req;
} else {
request.EnableTBRAnalysis = m_Options.EnableTBRAnalysis;
}
if (enable_va_in_req || disable_va_in_req) {
// override the default value of TBR analysis.
request.EnableVariedAnalysis = enable_va_in_req && !disable_va_in_req;
} else {
request.EnableVariedAnalysis = m_Options.EnableVariedAnalysis;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only)) {
if (!A->getAnnotation().equals("H")) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Diagonal only option is only valid for Hessian "
"mode.");
return true;
}
}
}
if (!A)
return true;
if (!A->getAnnotation().equals("D") && !A->getAnnotation().equals("G") &&
!A->getAnnotation().equals("H") && !A->getAnnotation().equals("J") &&
!A->getAnnotation().equals("E"))
return true;

if (A->getAnnotation().equals("D")) {
request.Mode = DiffMode::forward;
unsigned derivative_order =
clad::GetDerivativeOrder(bitmasked_opts_value);
if (derivative_order == 0) {
derivative_order = 1; // default to first order derivative.
}
request.RequestedDerivativeOrder = derivative_order;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
if (clad::HasOption(bitmasked_opts_value, clad::opts::immediate_mode))
request.ImmediateMode = true;
if (enable_tbr_in_req) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"TBR analysis is not meant for forward mode AD.");
return true;
}
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
request.Mode = DiffMode::vector_forward_mode;

// currently only first order derivative is supported.
if (derivative_order != 1) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Only first order derivative is supported for now "
"in vector forward mode.");
return true;
}
// we don't yet support enzyme with vector mode.
if (request.use_enzyme) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Enzyme's vector mode is not yet supported.");
return true;
}
}
} else if (A->getAnnotation().equals("H")) {
if (clad::HasOption(bitmasked_opts_value, clad::opts::diagonal_only))
request.Mode = DiffMode::hessian_diagonal;
else
request.Mode = DiffMode::hessian;
} else if (A->getAnnotation().equals("J")) {
request.Mode = DiffMode::jacobian;
} else if (A->getAnnotation().equals("G")) {
request.Mode = DiffMode::reverse;
if (clad::HasOption(bitmasked_opts_value, clad::opts::use_enzyme))
request.use_enzyme = true;
// reverse vector mode is not yet supported.
if (clad::HasOption(bitmasked_opts_value, clad::opts::vector_mode)) {
utils::EmitDiag(m_Sema, DiagnosticsEngine::Error, endLoc,
"Reverse vector mode is not yet supported.");
return true;
}
} else {
request.Mode = DiffMode::error_estimation;
}
request.CallContext = E;
request.CallUpdateRequired = true;
request.VerboseDiags = true;
request.Args = E->getArg(1);
auto derivedFD = cast<FunctionDecl>(DRE->getDecl());
request.Function = derivedFD;
request.BaseFunctionName = utils::ComputeEffectiveFnName(request.Function);

if (isCallOperator(m_Sema.getASTContext(), request.Function)) {
request.Functor = cast<CXXMethodDecl>(request.Function)->getParent();
}
// FIXME: add support for nested calls to clad::differentiate/gradient
// inside differentiated functions
assert(!m_TopMostFD &&
"nested clad::differentiate/gradient are not yet supported");
llvm::SaveAndRestore<const FunctionDecl*> saveTopMost = m_TopMostFD;
m_TopMostFD = FD;
TraverseDecl(derivedFD);
m_DiffRequestGraph.addNode(request, /*isSource=*/true);
}
// A call to clad::differentiate or clad::gradient was found.
DeclRefExpr* DRE = getArgFunction(E, m_Sema);
if (!DRE)
return true;

DiffRequest request;

if (ProcessInvocationArgs(m_Sema, endLoc, m_Options, FD, request))
return true;

request.CallContext = E;
request.CallUpdateRequired = true;
request.VerboseDiags = true;
request.Args = E->getArg(1);
auto derivedFD = cast<FunctionDecl>(DRE->getDecl());
request.Function = derivedFD;
request.BaseFunctionName = utils::ComputeEffectiveFnName(request.Function);

if (isCallOperator(m_Sema.getASTContext(), request.Function))
request.Functor = cast<CXXMethodDecl>(request.Function)->getParent();
// FIXME: add support for nested calls to clad::differentiate/gradient
// inside differentiated functions
assert(!m_TopMostFD &&
"nested clad::differentiate/gradient are not yet supported");
llvm::SaveAndRestore<const FunctionDecl*> saveTopMost = m_TopMostFD;
m_TopMostFD = FD;
TraverseDecl(derivedFD);
m_DiffRequestGraph.addNode(request, /*isSource=*/true);
/*else if (m_TopMostFD) {
// If another function is called inside differentiated function,
// this will be handled by Forward/ReverseModeVisitor::Derive.
Expand Down

0 comments on commit b7440b7

Please sign in to comment.