Skip to content

Commit

Permalink
custom simp set for ftrans
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Oct 25, 2023
1 parent 37e2f84 commit 7123cab
Show file tree
Hide file tree
Showing 18 changed files with 104 additions and 74 deletions.
10 changes: 5 additions & 5 deletions SciLean/Core/FloatAsReal.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import SciLean.Core.Objects.Scalar
namespace SciLean

instance : CommRing Float where
zero_mul := by intros; apply isomorph.ext `FloatToReal; simp; ftrans
zero_mul := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; simp
mul_zero := by intros; apply isomorph.ext `FloatToReal; simp; ftrans
mul_comm := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; rw[mul_comm]
left_distrib := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; ftrans; rw[mul_add]
right_distrib := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; ftrans; rw[add_mul]
mul_one := by intros; apply isomorph.ext `FloatToReal; simp; ftrans
one_mul := by intros; apply isomorph.ext `FloatToReal; simp; ftrans
left_distrib := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; simp; ftrans; simp; rw[mul_add]
right_distrib := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; simp; ftrans; simp; rw[add_mul]
mul_one := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; simp
one_mul := by intros; apply isomorph.ext `FloatToReal; simp; ftrans; simp
npow n x := x.pow (n.toFloat) --- TODO: change this implementation
npow_zero n := sorry_proof
npow_succ n x := sorry_proof
Expand Down
2 changes: 1 addition & 1 deletion SciLean/Core/FunctionPropositions/Diffeomorphism.lean
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ by
funext x dx
have H : ((cderiv K (fun x => invFun (f x) ∘ (f x)) x dx) ∘ (invFun (f x)))
=
0 := by simp[invFun_comp (hf _).1.1]; ftrans
0 := by simp[invFun_comp (hf _).1.1]; ftrans; simp
rw[← sub_zero (cderiv K (fun x => Function.invFun (f x)) x dx)]
rw[← H]
simp_rw[comp.arg_fg_a0.cderiv_rule (K:=K) (fun x => invFun (f x)) f (by fprop) (by fprop)]
Expand Down
2 changes: 1 addition & 1 deletion SciLean/Core/FunctionPropositions/HasAdjDiffAt.lean
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ theorem HSMul.hSMul.arg_a1.HasAdjDiffAt_rule
: HasAdjDiffAt K (fun x => c • g x) x :=
by
have ⟨_,_⟩ := hg
constructor; fprop; ftrans; fprop
constructor; fprop; ftrans; simp; fprop



Expand Down
3 changes: 2 additions & 1 deletion SciLean/Core/FunctionPropositions/IsContinuousLinearMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Mathlib.Analysis.InnerProductSpace.Basic

import SciLean.Tactic.FProp.Basic
import SciLean.Tactic.FProp.Notation
import SciLean.Tactic.FTrans.Init

namespace SciLean

Expand All @@ -27,7 +28,7 @@ def ContinuousLinearMap.mk'
: X →L[R] Y :=
⟨⟨⟨f, hf.linear.map_add⟩, hf.linear.map_smul⟩, hf.cont⟩

@[simp]
@[simp, ftrans_simp]
theorem ContinuousLinearMap.mk'_eval
(x : X) (f : X → Y) (hf : IsContinuousLinearMap R f)
: mk' R f hf x = f x := by rfl
Expand Down
4 changes: 2 additions & 2 deletions SciLean/Core/FunctionTransformations/CDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ def scalarCDeriv (f : K → X) (t : K) : X := cderiv K f t 1
-- Basic identities ------------------------------------------------------------
--------------------------------------------------------------------------------

@[simp]
@[simp, ftrans_simp]
theorem cderiv_apply
(f : X → Y → Z) (x dx : X) (y : Y)
: cderiv K f x dx y
=
cderiv K (fun x' => f x' y) x dx := sorry_proof

@[simp]
@[simp, ftrans_simp]
theorem cderiv_zero
(f : X → Y) (x : X)
: cderiv K f x 0 = 0 := by sorry_proof
Expand Down
6 changes: 3 additions & 3 deletions SciLean/Core/FunctionTransformations/RevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revCDeriv
funext _; ftrans; ftrans; simp
funext _; ftrans; ftrans

theorem comp_rule'
(f : Y → Z) (g : X → Y)
Expand All @@ -163,7 +163,7 @@ by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revCDeriv
funext _; simp; ftrans
funext _; simp; ftrans;


theorem let_rule
Expand All @@ -183,7 +183,7 @@ by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revCDeriv
funext _; ftrans; ftrans; simp
funext _; ftrans; ftrans


theorem let_rule'
Expand Down
5 changes: 3 additions & 2 deletions SciLean/Core/FunctionTransformations/RevDerivUpdate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ theorem proj_rule (i : ι)
(x i, fun dxi k dx j => if h : i=j then dx j + k • h ▸ dxi else dx j) :=
by
unfold revDerivUpdate
funext _; ftrans; ftrans; funext dxi k dx j; simp; sorry_proof
funext _; ftrans; ftrans;
simp; funext dxi k dx j; simp; sorry_proof
variable {E}

theorem comp_rule
Expand Down Expand Up @@ -149,7 +150,7 @@ by
have _ := fun i => (hf i).1
have _ := fun i => (hf i).2
unfold revDerivUpdate
funext _; ftrans; ftrans; -- simp
funext _; ftrans; ftrans; simp
funext dy dx
sorry_proof

Expand Down
5 changes: 3 additions & 2 deletions SciLean/Core/FunctionTransformations/RevFDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ theorem id_rule
by
unfold revFDeriv
funext _
ftrans; ftrans; ext; simp
ftrans; ftrans; ext; simp; simp



theorem const_rule (y : Y)
: revFDeriv K (fun _ : X => y) = fun x => (y, fun _ => 0) :=
by
unfold revFDeriv
funext _
ftrans; ftrans; ext; simp
ftrans; ftrans; ext; simp; simp
variable{X}

variable(E)
Expand Down
18 changes: 10 additions & 8 deletions SciLean/Core/Monads/StateT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ instance (S : Type _) [Vec K S] : FwdDerivMonad K (StateT S m) (StateT (S×S) m'
simp at hf; simp at hg
simp[fwdCDeriv, bind, StateT.bind, StateT.bind.match_1]
ftrans
simp

fwdDerivM_pair f hf :=
by
Expand Down Expand Up @@ -90,6 +91,7 @@ theorem _root_.getThe.arg.fwdDerivValM_rule
by
simp[getThe, MonadStateOf.get, StateT.get,fwdDerivValM, fwdDerivM, pure, StateT.pure]
ftrans
simp

-- MonadState.get --------------------------------------------------------------
--------------------------------------------------------------------------------
Expand All @@ -111,7 +113,7 @@ theorem _root_.MonadState.get.arg.fwdDerivValM_rule
by
simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,fwdDerivValM, fwdDerivM]
ftrans

simp

-- -- setThe ----------------------------------------------------------------------
-- --------------------------------------------------------------------------------
Expand Down Expand Up @@ -162,7 +164,7 @@ theorem _root_.MonadStateOf.set.arg_a0.fwdDerivM_rule
pure ((),())) :=
by
simp[set, StateT.set,fwdDerivM, bind,Bind.bind, StateT.bind]
ftrans; congr
ftrans; congr; simp; rfl


-- modifyThe ----------------------------------------------------------------------
Expand All @@ -189,7 +191,7 @@ theorem _root_.modifyThe.arg_f.fwdDerivM_rule
pure ((),())) :=
by
simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,fwdDerivM,bind,Bind.bind, StateT.bind]
ftrans; congr
ftrans; congr; simp; rfl


-- modify ----------------------------------------------------------------------
Expand All @@ -216,7 +218,7 @@ theorem _root_.modify.arg_f.fwdDerivM_rule
pure ((),())) :=
by
simp[modify, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,fwdDerivM,bind,Bind.bind, StateT.bind]
ftrans; congr
ftrans; congr; simp; rfl


end FwdDerivMonad
Expand Down Expand Up @@ -306,7 +308,7 @@ theorem _root_.getThe.arg.revDerivValM_rule
pure ((← getThe S), fun ds => modifyThe S (fun ds' => ds + ds'))) :=
by
simp[getThe, MonadStateOf.get, StateT.get,revDerivValM, revDerivM, pure, StateT.pure, bind, StateT.bind, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet]
ftrans
ftrans; simp

-- MonadState.get --------------------------------------------------------------
--------------------------------------------------------------------------------
Expand All @@ -328,7 +330,7 @@ theorem _root_.MonadState.get.arg.revDerivValM_rule
pure ((← get), fun ds => modify (fun ds' => ds + ds'))) :=
by
simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,revDerivValM, revDerivM, pure, StateT.pure, bind, StateT.bind, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet, modifyGet]
ftrans
ftrans; simp


-- -- setThe ----------------------------------------------------------------------
Expand Down Expand Up @@ -386,7 +388,7 @@ theorem _root_.MonadStateOf.set.arg_a0.revDerivM_rule
pure dx)) :=
by
simp[set, StateT.set, revDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure, get]
ftrans
ftrans; simp

-- -- modifyThe ----------------------------------------------------------------------
-- --------------------------------------------------------------------------------
Expand Down Expand Up @@ -445,7 +447,7 @@ theorem _root_.modify.arg_f.revDerivM_rule
pure dxs.1)) :=
by
simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,revDerivM, bind, StateT.bind, getThe, MonadStateOf.get, StateT.get, set, StateT.set, get, pure, StateT.pure, modify]
ftrans
ftrans; simp

end RevDerivMonad

Expand Down
3 changes: 2 additions & 1 deletion SciLean/Core/Notation/Autodiff.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ macro "autodiff" : conv => do
`(conv|
(lsimp (config := {failIfUnchanged := false, zeta := false, singlePass := true}) only [cderiv_as_fwdCDeriv, scalarGradient, gradient, scalarCDeriv,revCDerivEval]
ftrans only
lsimp (config := {failIfUnchanged := false, zeta := false}) [uncurryN, UncurryN.uncurry, curryN, CurryN.curry]))
simp (config := {zeta:=false}) only [uncurryN, UncurryN.uncurry, CurryN.curry, curryN]
lsimp (config := {failIfUnchanged := false, zeta := false})))

macro "autodiff" : tactic => do
`(tactic| conv => autodiff)
3 changes: 2 additions & 1 deletion SciLean/Core/Objects/Scalar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Mathlib.Analysis.SpecialFunctions.Pow.Complex
import Mathlib.Analysis.SpecialFunctions.Pow.Real

import SciLean.Util.SorryProof
import SciLean.Tactic.FTrans.Init

namespace SciLean

Expand Down Expand Up @@ -93,7 +94,7 @@ instance {R K} [Scalar R K] : HPow K K K := ⟨fun x y => Scalar.pow x y⟩

open ComplexConjugate

@[simp]
@[simp, ftrans_simp]
theorem conj_for_real_scalar {R} [RealScalar R] (r : R)
: conj r = r := sorry_proof

Expand Down
23 changes: 15 additions & 8 deletions SciLean/Tactic/FTrans/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Mathlib.Tactic.NormNum.Core

import SciLean.Lean.Meta.Basic
import SciLean.Tactic.FTrans.Init
import SciLean.Tactic.FTrans.Simp
import SciLean.Tactic.LSimp2.Main
import SciLean.Tactic.StructuralInverse
import SciLean.Tactic.AnalyzeLambda
Expand Down Expand Up @@ -584,9 +585,9 @@ mutual
partial def methods : Simp.Methods :=
if useSimp then {
pre := fun e ↦ do
Simp.andThen (← Simp.preDefault e discharge) (fun e' => tryFTrans? e' discharge)
Simp.andThen (← withTraceNode `lsimp (fun _ => do pure s!"lsimp default pre") do Simp.preDefault e discharge) (fun e' => tryFTrans? e' discharge)
post := fun e ↦ do
Simp.andThen (← Simp.postDefault e discharge) (fun e' => tryFTrans? e' discharge (post := true))
Simp.andThen (← withTraceNode `lsimp (fun _ => do pure s!"lsimp default post") do Simp.postDefault e discharge) (fun e' => tryFTrans? e' discharge (post := true))
discharge? := discharge
} else {
pre := fun e ↦ do
Expand Down Expand Up @@ -652,11 +653,15 @@ def fTransAt (g : MVarId) (ctx : Simp.Context) (fvarIdsToSimp : Array FVarId)

open Qq Lean Meta Elab Tactic Term

def getFTransTheorems : CoreM SimpTheorems := do
let ext ← Lean.Meta.getSimpExtension? "ftrans_simp"
ext.get!.getTheorems

/-- Constructs a simp context from the simp argument syntax. -/
def getSimpContext (args : Syntax) (simpOnly := false) :
def getFTransContext (args : Syntax) (simpOnly := false) :
TacticM Simp.Context := do
let simpTheorems ←
if simpOnly then simpOnlyBuiltins.foldlM (·.addConst ·) {} else getSimpTheorems
if simpOnly then simpOnlyBuiltins.foldlM (·.addConst ·) {} else getFTransTheorems
let mut { ctx, starArg } ← elabSimpArgs args (eraseLocal := false) (kind := .simp)
{ simpTheorems := #[simpTheorems], congrTheorems := ← getSimpCongrTheorems }
unless starArg do return ctx
Expand All @@ -679,8 +684,8 @@ Elaborates a call to `fun_trans only? [args]` or `norm_num1`.
-- FIXME: had to inline a bunch of stuff from `mkSimpContext` and `simpLocation` here
def elabFTrans (args : Syntax) (loc : Syntax)
(simpOnly := false) (useSimp := true) : TacticM Unit := do
let ctx ← getSimpContext args (!useSimp || simpOnly)
let ctx := {ctx with config := {ctx.config with iota := true, zeta := false, singlePass := true, autoUnfold := true}}
let ctx ← getFTransContext args (!useSimp || simpOnly)
let ctx := {ctx with config := {ctx.config with iota := true, zeta := false, singlePass := true, dsimp := false, decide := false}}
let g ← getMainGoal
let res ← match expandOptLocation loc with
| .targets hyps simplifyTarget => fTransAt g ctx (← getFVarIds hyps) simplifyTarget useSimp
Expand All @@ -699,8 +704,10 @@ open Lean Elab Tactic Lean.Parser.Tactic

syntax (name := fTransConv) "ftrans" &" only"? (simpArgs)? : conv


/-- Elaborator for `norm_num` conv tactic. -/
@[tactic fTransConv] def elabFTransConv : Tactic := fun stx ↦ withMainContext do
let ctx ← getSimpContext stx[2] !stx[1].isNone
let ctx := {ctx with config := {ctx.config with iota := true, zeta := false, singlePass := true}}
let ctx ← getFTransContext stx[2] !stx[1].isNone
let ctx := {ctx with config := {ctx.config with iota := true, zeta := false, singlePass := true, dsimp := false, decide := false}}
Conv.applySimpResult (← deriveSimp ctx (← instantiateMVars (← Conv.getLhs)) (useSimp := true))

9 changes: 1 addition & 8 deletions SciLean/Tactic/FTrans/Init.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,7 @@ initialize registerTraceClass `Meta.Tactic.ftrans.unify
initialize registerOption `linter.ftransDeclName { defValue := true, descr := "suggests declaration name for ftrans rule" }
-- initialize registerTraceClass `Meta.Tactic.ftrans.lambda_special_cases


-- /-- Simp attribute to mark function transformation rules.
-- -/
-- register_simp_attr ftrans_simp

-- macro "ftrans" : attr => `(attr| ftrans_simp ↓)

register_simp_attr ftrans_simp

open Meta Simp

Expand Down Expand Up @@ -260,7 +254,6 @@ private def FTransRules.merge! (_ : Name) (fp fp' : FTransRules) : FTransRules
initialize FTransRulesExt : MergeMapDeclarationExtension FTransRules
← mkMergeMapDeclarationExtension ⟨FTransRules.merge!, sorry_proof⟩


open Lean Qq Meta Elab Term in
initialize funTransRuleAttr : TagAttribute ←
registerTagAttribute
Expand Down
10 changes: 10 additions & 0 deletions SciLean/Tactic/FTrans/Simp.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import SciLean.Tactic.FTrans.Init
import Mathlib.Algebra.Group.Prod
import Mathlib.GroupTheory.GroupAction.Prod
import Mathlib.Algebra.SMulWithZero

namespace SciLean

attribute [ftrans_simp] Prod.mk_add_mk Prod.mk_mul_mk Prod.smul_mk Prod.mk_sub_mk Prod.neg_mk Prod.vadd_mk add_zero zero_add sub_zero zero_sub sub_self neg_zero mul_zero zero_mul zero_smul smul_zero smul_eq_mul smul_neg eq_self iff_self

attribute [ftrans_simp] Equiv.invFun_as_coe Equiv.symm_symm
Loading

0 comments on commit 7123cab

Please sign in to comment.