Skip to content

Commit

Permalink
Add tests for non-differentiable attribute in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
MihailMihov committed Jul 5, 2024
1 parent 8749404 commit 98abc84
Show file tree
Hide file tree
Showing 2 changed files with 265 additions and 0 deletions.
223 changes: 223 additions & 0 deletions test/Gradient/NonDifferentiable.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// RUN: %cladclang %s -I%S/../../include -oNonDifferentiable.out 2>&1 | %filecheck %s
// RUN: ./NonDifferentiable.out | %filecheck_exec %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#define non_differentiable __attribute__((annotate("another_attribute"), annotate("non_differentiable")))

#include "clad/Differentiator/Differentiator.h"

extern "C" int printf(const char* fmt, ...);

class SimpleFunctions1 {
public:
SimpleFunctions1() noexcept : x(0), y(0) {}
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
non_differentiable double y;
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; }
non_differentiable double mem_fn_2(double i, double j) { return i * j; }
double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; }
double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; }
double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; }
SimpleFunctions1 operator+(const SimpleFunctions1& other) const {
return SimpleFunctions1(x + other.x, y + other.y);
}
};

double fn_s1_mem_fn(double i, double j) {
SimpleFunctions1 obj(2, 3);
return obj.mem_fn_1(i, j) + i * j;
}

double fn_s1_field(double i, double j) {
SimpleFunctions1 obj(2, 3);
return obj.x * obj.y + i * j;
}

double fn_s1_operator(double i, double j) {
SimpleFunctions1 obj1(2, 3);
SimpleFunctions1 obj2(3, 5);
return (obj1 + obj2).mem_fn_1(i, j);
}

class non_differentiable SimpleFunctions2 {
public:
SimpleFunctions2() noexcept : x(0), y(0) {}
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
double y;
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; }
SimpleFunctions2 operator+(const SimpleFunctions2& other) const {
return SimpleFunctions2(x + other.x, y + other.y);
}
};

double fn_s2_mem_fn(double i, double j) {
SimpleFunctions2 obj(2, 3);
return obj.mem_fn(i, j) + i * j;
}

double fn_s2_field(double i, double j) {
SimpleFunctions2 *obj0, obj(2, 3);
return obj.x * obj.y + i * j;
}

double fn_s2_operator(double i, double j) {
SimpleFunctions2 obj1(2, 3);
SimpleFunctions2 obj2(3, 5);
return (obj1 + obj2).mem_fn(i, j);
}

#define INIT_EXPR(classname) \
classname expr_1(2, 3); \
classname expr_2(3, 5);

#define TEST_CLASS(classname, name, i, j) \
auto d_##name = clad::gradient(&classname::name); \
double result_##name[2] = {}; \
d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \
printf("%.2f %.2f\n", result_##name[0], result_##name[1]);

#define TEST_FUNC(name, i, j) \
auto d_##name = clad::gradient(&name); \
double result_##name[2] = {}; \
d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \
printf("%.2f\n", result_##name[0], result_##name[1]);

int main() {
INIT_EXPR(SimpleFunctions1);

TEST_CLASS(SimpleFunctions1, mem_fn_1, 3, 5) // CHECK-EXEC: 30.00
// CHECK-EXEC: 33.00

TEST_CLASS(SimpleFunctions1, mem_fn_3, 3, 5) // CHECK-EXEC: 35.00
// CHECK-EXEC: 38.00

TEST_CLASS(SimpleFunctions1, mem_fn_4, 3, 5) // CHECK-EXEC: 5.00
// CHECK-EXEC: 5.00

TEST_CLASS(SimpleFunctions1, mem_fn_5, 3, 5) // CHECK-EXEC: 2700.00
// CHECK-EXEC: 2970.00

TEST_FUNC(fn_s1_mem_fn, 3, 5) // CHECK-EXEC: 35.00

TEST_FUNC(fn_s1_field, 3, 5) // CHECK-EXEC: 5.00

/*TEST_FUNC(fn_s1_operator, 3, 5) // CHECK-EXEC: 38.00*/

TEST_FUNC(fn_s2_mem_fn, 3, 5) // CHECK-EXEC: 5.00

/*TEST_FUNC(fn_s2_field, 3, 5) // CHECK-EXEC: 5.00*/

/*TEST_FUNC(fn_s2_operator, 3, 5) // CHECK-EXEC: 0.00*/

// CHECK: void mem_fn_1_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += 1 * i;
// CHECK-NEXT: (*_d_this).y += 1 * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * 1;
// CHECK-NEXT: *_d_i += 1 * j * j;
// CHECK-NEXT: *_d_j += i * 1 * j;
// CHECK-NEXT: *_d_j += i * j * 1;
// CHECK-NEXT: }

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j);

// CHECK: void mem_fn_3_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 *_t0;
// CHECK-NEXT: _t0 = this;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0->mem_fn_1_pullback(i, j, 1, &(*_d_this), &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_4_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 *_t0;
// CHECK-NEXT: _t0 = this;
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_5_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: SimpleFunctions1 *_t1;
// CHECK-NEXT: double _t2;
// CHECK-NEXT: SimpleFunctions1 *_t3;
// CHECK-NEXT: _t1 = this;
// CHECK-NEXT: _t2 = this->mem_fn_2(i, j);
// CHECK-NEXT: _t3 = this;
// CHECK-NEXT: _t0 = this->mem_fn_1(i, j);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: double _r2 = 0;
// CHECK-NEXT: double _r3 = 0;
// CHECK-NEXT: _t3->mem_fn_1_pullback(i, j, _t2 * 1 * i, &(*_d_this), &_r2, &_r3);
// CHECK-NEXT: *_d_i += _r2;
// CHECK-NEXT: *_d_j += _r3;
// CHECK-NEXT: *_d_i += _t2 * _t0 * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 _t0;
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: _t0 = obj;
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1);
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: _d_obj.x += 1 * obj.y;
// CHECK-NEXT: _d_obj.y += obj.x * 1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: SimpleFunctions2 _d_obj({});
// CHECK-NEXT: SimpleFunctions2 _t0;
// CHECK-NEXT: SimpleFunctions2 obj(2, 3);
// CHECK-NEXT: _t0 = obj;
// CHECK-NEXT: {
// CHECK-NEXT: *_d_i += _r0;
// CHECK-NEXT: *_d_j += _r1;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) {
// CHECK-NEXT: {
// CHECK-NEXT: (*_d_this).x += _d_y * i;
// CHECK-NEXT: (*_d_this).y += _d_y * i;
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y;
// CHECK-NEXT: *_d_i += _d_y * j * j;
// CHECK-NEXT: *_d_j += i * _d_y * j;
// CHECK-NEXT: *_d_j += i * j * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

}
42 changes: 42 additions & 0 deletions test/Gradient/NonDifferentiableError.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1

#define non_differentiable __attribute__((annotate("non_differentiable")))

#include "clad/Differentiator/Differentiator.h"

extern "C" int printf(const char* fmt, ...);

class non_differentiable SimpleFunctions2 {
public:
SimpleFunctions2() noexcept : x(0), y(0) {}
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {}
double x;
double y;
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } // expected-error {{attempted differentiation of method 'mem_fn' in class 'SimpleFunctions2', which is marked as non-differentiable}}
};

non_differentiable double fn_s2_mem_fn(double i, double j) {
SimpleFunctions2 obj(2, 3);
return obj.mem_fn(i, j) + i * j;
}

#define INIT_EXPR(classname) \
classname expr_1(2, 3); \
classname expr_2(3, 5);

#define TEST_CLASS(classname, name, i, j) \
auto d_##name = clad::differentiate(&classname::name, "i"); \
printf("%.2f\n", d_##name.execute(expr_1, i, j)); \
printf("%.2f\n", d_##name.execute(expr_2, i, j)); \
printf("\n");

#define TEST_FUNC(name, i, j) \
auto d_##name = clad::differentiate(&name, "i"); \
printf("%.2f\n", d_##name.execute(i, j)); \
printf("\n");

int main() {
INIT_EXPR(SimpleFunctions2);
TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5);
TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}}
}

0 comments on commit 98abc84

Please sign in to comment.