Skip to content

Commit

Permalink
fix issues caused by the previous commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Aug 29, 2023
1 parent 32dacb2 commit d4e9264
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
39 changes: 34 additions & 5 deletions SciLean/DoodleRevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,19 @@ theorem cderiv.uncurry_rule (f : X → Y → Z)
+
cderiv K (fun y => f xy.1 y) xy.2 dxy.2 := sorry_proof

set_option trace.Meta.Tactic.simp.rewrite true in
theorem pi_add (f g : α → β) [Add β] (a :α)
: (f + g) a = f a + g a := by simp
theorem revCDeriv.prod_rule
(f : X → Y → X×Y → Z) (hf : HasAdjDiff K (fun x : X×Y×(X×Y) => f x.1 x.2.1 x.2.2))
: revCDeriv K (fun xy : X × Y => f xy.1 xy.2 xy)
=
fun x =>
let ydf := revCDeriv K (fun x : X×Y×(X×Y) => f x.1 x.2.1 x.2.2) (x.1,x.2,x)
(ydf.1,
fun dz =>
let dxy := ydf.2 dz
(dxy.1, dxy.2.1) + dxy.2.2) :=
by
-- simp [cderiv.uncurry_rule _ hf]
sorry_proof

open BigOperators
theorem revCDeriv.pi_rule_v1
Expand Down Expand Up @@ -61,8 +71,6 @@ by





theorem revCDeriv.pi_rule_v1'
(f : (i : ι) → X → (ι → X) → Y) (hf : ∀ i, HasAdjDiff K (fun x : X×(ι→X) => f i x.1 x.2))
: (revCDeriv K fun (x : ι → X) (i : ι) => f i (x i) x)
Expand All @@ -89,6 +97,27 @@ by
ftrans
sorry_proof



example
: revCDeriv K (fun xy : X×Y => (xy.1,xy.2))
=
fun xy =>
(xy, fun dxy => dxy) :=
by
conv =>
lhs; autodiff; enter[x]; let_normalize


example
: revCDeriv K (fun xy : X×Y => (xy.2,xy.1))
=
fun xy =>
((xy.2,xy.1), fun dxy => (dxy.2, dxy.1)) :=
by
conv =>
lhs; autodiff; enter[x]; let_normalize

#eval 0

#check
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,5 @@ by
solve_for p' from 1 := sorry_proof
solve_as_inv
solve_as_inv
ftrans; ftrans; ftrans
unfold hold
ftrans; ftrans; ftrans;
3 changes: 2 additions & 1 deletion SciLean/Util/SolveFun.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import SciLean.Lean.Meta.Basic
import SciLean.Lean.Array
import SciLean.Tactic.LetNormalize
import SciLean.Util.SorryProof
import SciLean.Util.Hold

namespace SciLean

Expand Down Expand Up @@ -174,7 +175,7 @@ def solveForFrom (e : Expr) (is js : Array Nat) : MetaM (Expr×Expr×MVarId) :=
let Q₁body ← mkAppFoldrM ``And Qs₁
let Q₁ ← mkLambdaFVars zs Q₁body

let zs'Val ← mkLambdaFVars ys (← mkAppM ``solveFun #[Q₁])
let zs'Val ← mkAppM ``solveFun #[Q₁] >>= mkLambdaFVars ys >>= (mkAppM ``hold #[·]) -- (← mkAppM ``solveFun #[Q₁])
withLetDecl (zsName.appendAfter "'") (← inferType zs'Val) zs'Val fun zs'Var => do

let zs' ← mkProdSplitElem (← mkAppM' zs'Var ys) zs.size
Expand Down

0 comments on commit d4e9264

Please sign in to comment.