From d4e926402bc74a018f9e1b9f82cc6a94390b29dd Mon Sep 17 00:00:00 2001 From: lecopivo Date: Tue, 29 Aug 2023 11:27:54 -0400 Subject: [PATCH] fix issues caused by the previous commit --- SciLean/DoodleRevCDeriv.lean | 39 ++++++++++++++++--- .../OdeSolvers/Solvers.lean | 3 +- SciLean/Util/SolveFun.lean | 3 +- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/SciLean/DoodleRevCDeriv.lean b/SciLean/DoodleRevCDeriv.lean index cf43b55c..b9deb5d5 100644 --- a/SciLean/DoodleRevCDeriv.lean +++ b/SciLean/DoodleRevCDeriv.lean @@ -25,9 +25,19 @@ theorem cderiv.uncurry_rule (f : X → Y → Z) + cderiv K (fun y => f xy.1 y) xy.2 dxy.2 := sorry_proof -set_option trace.Meta.Tactic.simp.rewrite true in -theorem pi_add (f g : α → β) [Add β] (a :α) - : (f + g) a = f a + g a := by simp +theorem revCDeriv.prod_rule + (f : X → Y → X×Y → Z) (hf : HasAdjDiff K (fun x : X×Y×(X×Y) => f x.1 x.2.1 x.2.2)) + : revCDeriv K (fun xy : X × Y => f xy.1 xy.2 xy) + = + fun x => + let ydf := revCDeriv K (fun x : X×Y×(X×Y) => f x.1 x.2.1 x.2.2) (x.1,x.2,x) + (ydf.1, + fun dz => + let dxy := ydf.2 dz + (dxy.1, dxy.2.1) + dxy.2.2) := +by + -- simp [cderiv.uncurry_rule _ hf] + sorry_proof open BigOperators theorem revCDeriv.pi_rule_v1 @@ -61,8 +71,6 @@ by - - theorem revCDeriv.pi_rule_v1' (f : (i : ι) → X → (ι → X) → Y) (hf : ∀ i, HasAdjDiff K (fun x : X×(ι→X) => f i x.1 x.2)) : (revCDeriv K fun (x : ι → X) (i : ι) => f i (x i) x) @@ -89,6 +97,27 @@ by ftrans sorry_proof + + +example + : revCDeriv K (fun xy : X×Y => (xy.1,xy.2)) + = + fun xy => + (xy, fun dxy => dxy) := +by + conv => + lhs; autodiff; enter[x]; let_normalize + + +example + : revCDeriv K (fun xy : X×Y => (xy.2,xy.1)) + = + fun xy => + ((xy.2,xy.1), fun dxy => (dxy.2, dxy.1)) := +by + conv => + lhs; autodiff; enter[x]; let_normalize + #eval 0 #check diff --git a/SciLean/Modules/DifferentialEquations/OdeSolvers/Solvers.lean b/SciLean/Modules/DifferentialEquations/OdeSolvers/Solvers.lean index bb495880..7cbbddc3 100644 --- a/SciLean/Modules/DifferentialEquations/OdeSolvers/Solvers.lean +++ b/SciLean/Modules/DifferentialEquations/OdeSolvers/Solvers.lean @@ -100,4 +100,5 @@ by solve_for p' from 1 := sorry_proof solve_as_inv solve_as_inv - ftrans; ftrans; ftrans + unfold hold + ftrans; ftrans; ftrans; diff --git a/SciLean/Util/SolveFun.lean b/SciLean/Util/SolveFun.lean index 1fb51828..5e8c0c69 100644 --- a/SciLean/Util/SolveFun.lean +++ b/SciLean/Util/SolveFun.lean @@ -8,6 +8,7 @@ import SciLean.Lean.Meta.Basic import SciLean.Lean.Array import SciLean.Tactic.LetNormalize import SciLean.Util.SorryProof +import SciLean.Util.Hold namespace SciLean @@ -174,7 +175,7 @@ def solveForFrom (e : Expr) (is js : Array Nat) : MetaM (Expr×Expr×MVarId) := let Q₁body ← mkAppFoldrM ``And Qs₁ let Q₁ ← mkLambdaFVars zs Q₁body - let zs'Val ← mkLambdaFVars ys (← mkAppM ``solveFun #[Q₁]) + let zs'Val ← mkAppM ``solveFun #[Q₁] >>= mkLambdaFVars ys >>= (mkAppM ``hold #[·]) -- (← mkAppM ``solveFun #[Q₁]) withLetDecl (zsName.appendAfter "'") (← inferType zs'Val) zs'Val fun zs'Var => do let zs' ← mkProdSplitElem (← mkAppM' zs'Var ys) zs.size