From cdb19904a61381dfb23610001c71dd2115ee4a04 Mon Sep 17 00:00:00 2001 From: kchristin Date: Sun, 3 Nov 2024 16:55:47 +0200 Subject: [PATCH] Revert not skipping cuda host functions with const args --- lib/Differentiator/ReverseModeVisitor.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 009cad9da..956c6c5e1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1821,8 +1821,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If all arguments are constant literals, then this does not contribute to // the gradient. // FIXME: revert this when this is integrated in the activity analysis pass. - if (!isa(CE) && !isa(CE) && - CE->getCallReturnType(m_Context).getAsString() != "cudaError_t") { + if (!isa(CE) && !isa(CE)) { bool allArgsAreConstantLiterals = true; for (const Expr* arg : CE->arguments()) { // if it's of type MaterializeTemporaryExpr, then check its @@ -1846,8 +1845,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // derived function. In the case of member functions, `implicit` // this object is always passed by reference. if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(CE) && !isa(CE) && - CE->getCallReturnType(m_Context).getAsString() != "cudaError_t") { + !isa(CE) && !isa(CE)) { for (const Expr* Arg : CE->arguments()) { StmtDiff ArgDiff = Visit(Arg, dfdx()); CallArgs.push_back(ArgDiff.getExpr());