forked from vgvassilev/clad
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for non-differentiable attribute in reverse mode
- Loading branch information
1 parent
8749404
commit 98abc84
Showing
2 changed files
with
265 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
// RUN: %cladclang %s -I%S/../../include -oNonDifferentiable.out 2>&1 | %filecheck %s | ||
// RUN: ./NonDifferentiable.out | %filecheck_exec %s | ||
// CHECK-NOT: {{.*error|warning|note:.*}} | ||
|
||
#define non_differentiable __attribute__((annotate("another_attribute"), annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
extern "C" int printf(const char* fmt, ...); | ||
|
||
class SimpleFunctions1 { | ||
public: | ||
SimpleFunctions1() noexcept : x(0), y(0) {} | ||
SimpleFunctions1(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
non_differentiable double y; | ||
double mem_fn_1(double i, double j) { return (x + y) * i + i * j * j; } | ||
non_differentiable double mem_fn_2(double i, double j) { return i * j; } | ||
double mem_fn_3(double i, double j) { return mem_fn_1(i, j) + i * j; } | ||
double mem_fn_4(double i, double j) { return mem_fn_2(i, j) + i * j; } | ||
double mem_fn_5(double i, double j) { return mem_fn_2(i, j) * mem_fn_1(i, j) * i; } | ||
SimpleFunctions1 operator+(const SimpleFunctions1& other) const { | ||
return SimpleFunctions1(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s1_mem_fn(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.mem_fn_1(i, j) + i * j; | ||
} | ||
|
||
double fn_s1_field(double i, double j) { | ||
SimpleFunctions1 obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s1_operator(double i, double j) { | ||
SimpleFunctions1 obj1(2, 3); | ||
SimpleFunctions1 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn_1(i, j); | ||
} | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } | ||
SimpleFunctions2 operator+(const SimpleFunctions2& other) const { | ||
return SimpleFunctions2(x + other.x, y + other.y); | ||
} | ||
}; | ||
|
||
double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
double fn_s2_field(double i, double j) { | ||
SimpleFunctions2 *obj0, obj(2, 3); | ||
return obj.x * obj.y + i * j; | ||
} | ||
|
||
double fn_s2_operator(double i, double j) { | ||
SimpleFunctions2 obj1(2, 3); | ||
SimpleFunctions2 obj2(3, 5); | ||
return (obj1 + obj2).mem_fn(i, j); | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::gradient(&classname::name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(expr_1, i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f %.2f\n", result_##name[0], result_##name[1]); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::gradient(&name); \ | ||
double result_##name[2] = {}; \ | ||
d_##name.execute(i, j, &result_##name[0], &result_##name[1]); \ | ||
printf("%.2f\n", result_##name[0], result_##name[1]); | ||
|
||
int main() { | ||
INIT_EXPR(SimpleFunctions1); | ||
|
||
TEST_CLASS(SimpleFunctions1, mem_fn_1, 3, 5) // CHECK-EXEC: 30.00 | ||
// CHECK-EXEC: 33.00 | ||
|
||
TEST_CLASS(SimpleFunctions1, mem_fn_3, 3, 5) // CHECK-EXEC: 35.00 | ||
// CHECK-EXEC: 38.00 | ||
|
||
TEST_CLASS(SimpleFunctions1, mem_fn_4, 3, 5) // CHECK-EXEC: 5.00 | ||
// CHECK-EXEC: 5.00 | ||
|
||
TEST_CLASS(SimpleFunctions1, mem_fn_5, 3, 5) // CHECK-EXEC: 2700.00 | ||
// CHECK-EXEC: 2970.00 | ||
|
||
TEST_FUNC(fn_s1_mem_fn, 3, 5) // CHECK-EXEC: 35.00 | ||
|
||
TEST_FUNC(fn_s1_field, 3, 5) // CHECK-EXEC: 5.00 | ||
|
||
/*TEST_FUNC(fn_s1_operator, 3, 5) // CHECK-EXEC: 38.00*/ | ||
|
||
TEST_FUNC(fn_s2_mem_fn, 3, 5) // CHECK-EXEC: 5.00 | ||
|
||
/*TEST_FUNC(fn_s2_field, 3, 5) // CHECK-EXEC: 5.00*/ | ||
|
||
/*TEST_FUNC(fn_s2_operator, 3, 5) // CHECK-EXEC: 0.00*/ | ||
|
||
// CHECK: void mem_fn_1_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: (*_d_this).x += 1 * i; | ||
// CHECK-NEXT: (*_d_this).y += 1 * i; | ||
// CHECK-NEXT: *_d_i += (this->x + this->y) * 1; | ||
// CHECK-NEXT: *_d_i += 1 * j * j; | ||
// CHECK-NEXT: *_d_j += i * 1 * j; | ||
// CHECK-NEXT: *_d_j += i * j * 1; | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j); | ||
|
||
// CHECK: void mem_fn_3_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 *_t0; | ||
// CHECK-NEXT: _t0 = this; | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: double _r0 = 0; | ||
// CHECK-NEXT: double _r1 = 0; | ||
// CHECK-NEXT: _t0->mem_fn_1_pullback(i, j, 1, &(*_d_this), &_r0, &_r1); | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void mem_fn_4_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 *_t0; | ||
// CHECK-NEXT: _t0 = this; | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void mem_fn_5_grad(double i, double j, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: double _t0; | ||
// CHECK-NEXT: SimpleFunctions1 *_t1; | ||
// CHECK-NEXT: double _t2; | ||
// CHECK-NEXT: SimpleFunctions1 *_t3; | ||
// CHECK-NEXT: _t1 = this; | ||
// CHECK-NEXT: _t2 = this->mem_fn_2(i, j); | ||
// CHECK-NEXT: _t3 = this; | ||
// CHECK-NEXT: _t0 = this->mem_fn_1(i, j); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: double _r2 = 0; | ||
// CHECK-NEXT: double _r3 = 0; | ||
// CHECK-NEXT: _t3->mem_fn_1_pullback(i, j, _t2 * 1 * i, &(*_d_this), &_r2, &_r3); | ||
// CHECK-NEXT: *_d_i += _r2; | ||
// CHECK-NEXT: *_d_j += _r3; | ||
// CHECK-NEXT: *_d_i += _t2 * _t0 * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s1_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 _t0; | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: _t0 = obj; | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: double _r0 = 0; | ||
// CHECK-NEXT: double _r1 = 0; | ||
// CHECK-NEXT: _t0.mem_fn_1_pullback(i, j, 1, &_d_obj, &_r0, &_r1); | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s1_field_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions1 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions1 obj(2, 3); | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: _d_obj.x += 1 * obj.y; | ||
// CHECK-NEXT: _d_obj.y += obj.x * 1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void fn_s2_mem_fn_grad(double i, double j, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: SimpleFunctions2 _d_obj({}); | ||
// CHECK-NEXT: SimpleFunctions2 _t0; | ||
// CHECK-NEXT: SimpleFunctions2 obj(2, 3); | ||
// CHECK-NEXT: _t0 = obj; | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: *_d_i += _r0; | ||
// CHECK-NEXT: *_d_j += _r1; | ||
// CHECK-NEXT: *_d_i += 1 * j; | ||
// CHECK-NEXT: *_d_j += i * 1; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
// CHECK: void mem_fn_1_pullback(double i, double j, double _d_y, SimpleFunctions1 *_d_this, double *_d_i, double *_d_j) { | ||
// CHECK-NEXT: { | ||
// CHECK-NEXT: (*_d_this).x += _d_y * i; | ||
// CHECK-NEXT: (*_d_this).y += _d_y * i; | ||
// CHECK-NEXT: *_d_i += (this->x + this->y) * _d_y; | ||
// CHECK-NEXT: *_d_i += _d_y * j * j; | ||
// CHECK-NEXT: *_d_j += i * _d_y * j; | ||
// CHECK-NEXT: *_d_j += i * j * _d_y; | ||
// CHECK-NEXT: } | ||
// CHECK-NEXT: } | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | ||
|
||
#define non_differentiable __attribute__((annotate("non_differentiable"))) | ||
|
||
#include "clad/Differentiator/Differentiator.h" | ||
|
||
extern "C" int printf(const char* fmt, ...); | ||
|
||
class non_differentiable SimpleFunctions2 { | ||
public: | ||
SimpleFunctions2() noexcept : x(0), y(0) {} | ||
SimpleFunctions2(double p_x, double p_y) noexcept : x(p_x), y(p_y) {} | ||
double x; | ||
double y; | ||
double mem_fn(double i, double j) { return (x + y) * i + i * j * j; } // expected-error {{attempted differentiation of method 'mem_fn' in class 'SimpleFunctions2', which is marked as non-differentiable}} | ||
}; | ||
|
||
non_differentiable double fn_s2_mem_fn(double i, double j) { | ||
SimpleFunctions2 obj(2, 3); | ||
return obj.mem_fn(i, j) + i * j; | ||
} | ||
|
||
#define INIT_EXPR(classname) \ | ||
classname expr_1(2, 3); \ | ||
classname expr_2(3, 5); | ||
|
||
#define TEST_CLASS(classname, name, i, j) \ | ||
auto d_##name = clad::differentiate(&classname::name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(expr_1, i, j)); \ | ||
printf("%.2f\n", d_##name.execute(expr_2, i, j)); \ | ||
printf("\n"); | ||
|
||
#define TEST_FUNC(name, i, j) \ | ||
auto d_##name = clad::differentiate(&name, "i"); \ | ||
printf("%.2f\n", d_##name.execute(i, j)); \ | ||
printf("\n"); | ||
|
||
int main() { | ||
INIT_EXPR(SimpleFunctions2); | ||
TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5); | ||
TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}} | ||
} |