-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
52 changed files
with
3,539 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import SciLean | ||
|
||
|
||
-- there are some linker issues :( | ||
#exit | ||
|
||
open SciLean | ||
|
||
|
||
variable (R : Type) [RealScalar R] | ||
|
||
set_default_scalar R | ||
|
||
|
||
/-- | ||
info: fun x => | ||
let zdz := x * x; | ||
let ydy := zdz ^ 3; | ||
let ydy := ydy + ydy; | ||
fun dx => | ||
let zdz_1 := x * dx + dx * x; | ||
let ydy_1 := 3 * zdz_1 * zdz ^ 2; | ||
let ydy_2 := ydy_1 + ydy_1; | ||
(ydy, ydy_2) : R → R → R × R | ||
-/ | ||
#guard_msgs in | ||
#check (∂> (x : R), Id'.run do | ||
let mut x := x | ||
x := x*x | ||
x := x^3 | ||
x := x+x | ||
return x) rewrite_by autodiff | ||
|
||
|
||
variable (a b : R) | ||
|
||
|
||
#check (fwdFDeriv R fun x : R => Id'.run do | ||
let mut x := x | ||
if h : a < b then | ||
x := x^3 | ||
if 0 < b then | ||
x := x^4 | ||
if 0 < a then | ||
x := x^5 | ||
return x) rewrite_by | ||
fun_trans (config:={zeta:=false}) only [simp_core] | ||
let_normalize (config:={removeLambdaLet:=false,removeNoFVarLet:=true}) | ||
fun_trans (config:={zeta:=false}) only [simp_core] | ||
lsimp (config:={singlePass:=true}) | ||
lsimp (config:={singlePass:=true}) | ||
|
||
|
||
#check (fwdFDeriv R (fun x : R => | ||
let f := fun y : R => y | ||
f x)) rewrite_by fun_trans (config:={zeta:=false}) -- nothing happens :( | ||
|
||
|
||
set_option linter.unusedTactic false | ||
example : Differentiable R ((fun x : R => | ||
let f := fun y : R => y + x | ||
f x)) := by (try fun_prop); sorry_proof -- fun_prop does not work :( | ||
|
||
|
||
-- #check (fwdFDeriv R fun x : R => Id'.run do | ||
-- let mut x := x | ||
-- x := x^2 | ||
-- if a < b then | ||
-- x := x^3 | ||
-- if 0 < b then | ||
-- x := x^4 | ||
-- if 0 < a then | ||
-- x := x^5 | ||
-- return x) rewrite_by autodiff (disch:=simp only [simp_core]) | ||
|
||
-- #guard_msgs in | ||
-- #check (<∂ (x : R), Id'.run do | ||
-- let mut x := x | ||
-- x := x*x | ||
-- x := x^3 | ||
-- x := x+x | ||
-- return x) rewrite_by autodiff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import SciLean | ||
|
||
open SciLean | ||
|
||
variable {R : Type} [RealScalar R] [PlainDataType R] | ||
{n : Nat} | ||
|
||
|
||
set_default_scalar R | ||
|
||
|
||
/-- | ||
info: fun x => | ||
IndexType.foldl | ||
(fun dx i => | ||
let dx := ArrayType.modify dx i fun xi => xi + 1; | ||
dx) | ||
0 : R^[n] → R^[n] | ||
-/ | ||
#guard_msgs in | ||
#check (∇ (x : R^[n]), ∑ i, x[i]) rewrite_by autodiff | ||
|
||
|
||
/-- | ||
info: fun x => | ||
IndexType.foldl | ||
(fun dx i => | ||
let ydf := x[i]; | ||
let y' := 2 * ydf; | ||
let dx := ArrayType.modify dx i fun xi => xi + y'; | ||
dx) | ||
0 : R^[n] → R^[n] | ||
-/ | ||
#guard_msgs in | ||
#check (∇ (x : R^[n]), ∑ i, x[i]^2) rewrite_by autodiff | ||
|
||
|
||
variable (A : R^[n,n]) | ||
|
||
|
||
/-- | ||
info: fun x => | ||
IndexType.foldl | ||
(fun dx i => | ||
let zdg := x[i]; | ||
let dx := | ||
IndexType.foldl | ||
(fun dx i_1 => | ||
let ydf := A[i, i_1]; | ||
let ydf_1 := ydf * zdg; | ||
let zdg := x[i_1]; | ||
let dy₂ := ydf * zdg; | ||
let dx := ArrayType.modify dx i_1 fun xi => xi + ydf_1; | ||
let dx := ArrayType.modify dx i fun xi => xi + dy₂; | ||
dx) | ||
dx; | ||
dx) | ||
0 : R^[n] → R^[n] | ||
-/ | ||
#guard_msgs in | ||
#check (∇ (x : R^[n]), ∑ i j, A[i,j]*x[i]*x[j]) rewrite_by autodiff |
Oops, something went wrong.