Skip to content

Commit

Permalink
warn fix
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 10, 2024
1 parent 8b20c3a commit 6797990
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
22 changes: 17 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
#include <clang/AST/OperationKinds.h>
#include <clang/Sema/Ownership.h>

#include "llvm/ADT/SmallString.h"
#include "llvm/Support/SaveAndRestore.h"
#include <llvm/Support/raw_ostream.h>

#include <algorithm>
#include <numeric>
Expand Down Expand Up @@ -4281,12 +4283,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CE->getConstructor(), primalArgs, reverseForwAdjointArgs,
/*baseExpr=*/nullptr)) {
if (RD->isAggregate()) {
diag(DiagnosticsEngine::Note, CE->getConstructor()->getBeginLoc(),
"No need to provide a custom constructor forward sweep for an "
"aggregate type.");
SmallString<128> Name_class;
llvm::raw_svector_ostream OS_class(Name_class);
RD->getNameForDiagnostic(OS_class, m_Context.getPrintingPolicy(),
/*qualified=*/true);
diag(DiagnosticsEngine::Warning, CE->getBeginLoc(),
"No need to provide a custom constructor forward sweep for an "
"aggregate type.");
"'%0' is an aggregate type and its constructor does not require a "
"user-defined forward sweep function",
{OS_class.str()});
const FunctionDecl* constr_forw =
cast<CallExpr>(customReverseForwFnCall)->getDirectCallee();
SmallString<128> Name_forw;
llvm::raw_svector_ostream OS_forw(Name_forw);
constr_forw->getNameForDiagnostic(
OS_forw, m_Context.getPrintingPolicy(), /*qualified=*/true);
diag(DiagnosticsEngine::Note, constr_forw->getBeginLoc(),
"'%0' is defined here", {OS_forw.str()});
}
Expr* callRes = StoreAndRef(customReverseForwFnCall);
Expr* val =
Expand Down
38 changes: 37 additions & 1 deletion test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s
// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out -Xclang -verify 2>&1 | %filecheck %s
// RUN: ./UserDefinedTypes.out | %filecheck_exec %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out
// RUN: ./UserDefinedTypes.out | %filecheck_exec %s
Expand Down Expand Up @@ -485,6 +485,40 @@ double fn14(double x, double y) {
// CHECK-NEXT: }
// CHECK-NEXT:}

template <typename T, std::size_t N>
struct SimpleArray {
T elements[N]; // Aggregate initialization
};

namespace clad {
namespace custom_derivatives {
namespace class_functions {
template<::std::size_t N>
::clad::ValueAndAdjoint<SimpleArray<double, N>, SimpleArray<double, N>> // expected-note {{'clad::custom_derivatives::class_functions::constructor_reverse_forw<2}}{{' is defined here}}
constructor_reverse_forw(::clad::ConstructorReverseForwTag<SimpleArray<double, N>>) {
SimpleArray<double, N> a;
SimpleArray<double, N> d_a;
return {a, d_a};
}
}}}

double fn15(double x, double y) {
SimpleArray<double, 2> arr; // expected-warning {{'SimpleArray<double, 2>' is an aggregate type and its constructor does not require a user-defined forward sweep function}}
return arr.elements[0];
}

// CHECK:void fn15_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: ::clad::ValueAndAdjoint< ::std::array<double, {{2U|2UL|2ULL}}>, ::std::array<double, {{2U|2UL|2ULL}}> > _t0 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag<array<double, 2> >());
// CHECK-NEXT: std::array<double, 2> _d_arr(_t0.adjoint);
// CHECK-NEXT: std::array<double, 2> arr(_t0.value);
// CHECK-NEXT: std::array<double, 2> _t1 = arr;
// CHECK-NEXT: clad::ValueAndAdjoint<reference, reference> _t2 = _t1.operator_subscript_forw(0, &_d_arr, 0);
// CHECK-NEXT: {
// CHECK-NEXT: size_type _r0 = {{0U|0UL|0ULL}};
// CHECK-NEXT: _t1.operator_subscript_pullback(0, 1, &_d_arr, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT:}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand Down Expand Up @@ -556,6 +590,8 @@ int main() {

INIT_GRADIENT(fn14);
TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00}

INIT_GRADIENT(fn15);
}

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
Expand Down

0 comments on commit 6797990

Please sign in to comment.