From 852b4dca430a13f08500ab084d28f0db51086d4e Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 9 Dec 2024 22:16:03 +0200 Subject: [PATCH] warn fix --- lib/Differentiator/ReverseModeVisitor.cpp | 22 +++++++++++---- test/Gradient/UserDefinedTypes.C | 33 ++++++++++++++++++++++- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ec4862b3f..a7d4dc6cb 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -34,7 +34,9 @@ #include #include +#include "llvm/ADT/SmallString.h" #include "llvm/Support/SaveAndRestore.h" +#include #include #include @@ -4281,12 +4283,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CE->getConstructor(), primalArgs, reverseForwAdjointArgs, /*baseExpr=*/nullptr)) { if (RD->isAggregate()) { - diag(DiagnosticsEngine::Note, CE->getConstructor()->getBeginLoc(), - "No need to provide a custom constructor forward sweep for an " - "aggregate type."); + SmallString<128> Name_class; + llvm::raw_svector_ostream OS_class(Name_class); + RD->getNameForDiagnostic(OS_class, m_Context.getPrintingPolicy(), + /*qualified=*/true); diag(DiagnosticsEngine::Warning, CE->getBeginLoc(), - "No need to provide a custom constructor forward sweep for an " - "aggregate type."); + "'%0' is an aggregate type and its constructor does not require a " + "user-defined forward sweep function", + {OS_class.str()}); + const FunctionDecl* constr_forw = + cast(customReverseForwFnCall)->getDirectCallee(); + SmallString<128> Name_forw; + llvm::raw_svector_ostream OS_forw(Name_forw); + constr_forw->getNameForDiagnostic( + OS_forw, m_Context.getPrintingPolicy(), /*qualified=*/true); + diag(DiagnosticsEngine::Note, constr_forw->getBeginLoc(), + "'%0' is defined here", {OS_forw.str()}); } Expr* callRes = StoreAndRef(customReverseForwFnCall); Expr* val = diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index 6e2580a01..9fb5b14a2 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s +// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out -Xclang -verify 2>&1 | %filecheck %s // RUN: ./UserDefinedTypes.out | %filecheck_exec %s // RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out // RUN: ./UserDefinedTypes.out | %filecheck_exec %s @@ -485,6 +485,35 @@ double fn14(double x, double y) { // CHECK-NEXT: } // CHECK-NEXT:} +template +struct SimpleArray { + T elements[N]; // Aggregate initialization +}; + +namespace clad { +namespace custom_derivatives { +namespace class_functions { +template<::std::size_t N> +::clad::ValueAndAdjoint, SimpleArray> // expected-note {{'clad::custom_derivatives::class_functions::constructor_reverse_forw<2}}{{' is defined here}} +constructor_reverse_forw(::clad::ConstructorReverseForwTag>) { + SimpleArray a; + SimpleArray d_a; + return {a, d_a}; +} +}}} + +double fn15(double x, double y) { + SimpleArray arr; // expected-warning {{'SimpleArray' is an aggregate type and its constructor does not require a user-defined forward sweep function}} + return arr.elements[0]; +} + +// CHECK:void fn15_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: ::clad::ValueAndAdjoint, SimpleArray > _t0 = clad::custom_derivatives::class_functions::constructor_reverse_forw(clad::ConstructorReverseForwTag >()); +// CHECK-NEXT: SimpleArray _d_arr(_t0.adjoint); +// CHECK-NEXT: SimpleArray arr(_t0.value); +// CHECK-NEXT: _d_arr.elements[0] += 1; +// CHECK-NEXT:} + void print(const Tangent& t) { for (int i = 0; i < 5; ++i) { printf("%.2f", t.data[i]); @@ -556,6 +585,8 @@ int main() { INIT_GRADIENT(fn14); TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {30.00, 22.00} + + INIT_GRADIENT(fn15); } // CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {