Skip to content

Commit

Permalink
some modifications to fprop in fvar app case, partial solution to #23
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Sep 27, 2023
1 parent 60e7a81 commit d18a62b
Show file tree
Hide file tree
Showing 10 changed files with 474 additions and 39 deletions.
4 changes: 2 additions & 2 deletions SciLean/Core/FunctionPropositions/HasAdjDiff.lean
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ by

@[fprop]
theorem SciLean.EnumType.sum.arg_f.HasAdjDiff_rule
(f : X → ι → Y) (hf : HasAdjDiff K f)
(f : X → ι → Y) (hf : ∀ i, HasAdjDiff K (f · i))
: HasAdjDiff K (fun x => ∑ i, f x i) :=
by
have ⟨_,_⟩ := hf
have := fun i => (hf i).1
constructor; fprop; ftrans; fprop


Expand Down
5 changes: 3 additions & 2 deletions SciLean/Core/FunctionPropositions/HasAdjDiffAt.lean
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,11 @@ by

@[fprop]
theorem SciLean.EnumType.sum.arg_f.HasAdjDiffAt_rule
(f : X → ι → Y) (x : X) (hf : HasAdjDiffAt K f x)
(f : X → ι → Y) (x : X) (hf : ∀ i, HasAdjDiffAt K (f · i) x)
: HasAdjDiffAt K (fun x => ∑ i, f x i) x :=
by
have ⟨_,_⟩ := hf
have := fun i => (hf i).1
have := fun i => (hf i).2
constructor; fprop; ftrans; fprop


Expand Down
4 changes: 2 additions & 2 deletions SciLean/Core/FunctionTransformations/Adjoint.lean
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ by

open BigOperators in
theorem pi_rule
(f : X → (i : ι) → E i) (hf : IsContinuousLinearMap K f)
(f : X → (i : ι) → E i) (hf : ∀ i, IsContinuousLinearMap K (f · i))
: ((fun (x : X) =>L[K] fun (i : ι) => f x i) : X →L[K] PiLp 2 E)†
=
(fun x' =>L[K] ∑ i, (fun x =>L[K] f x i)† (x' i))
Expand Down Expand Up @@ -463,7 +463,7 @@ by
open BigOperators in
@[ftrans]
theorem Finset.sum.arg_f.adjoint_rule
(f : X → ι → Y) (hf : IsContinuousLinearMap K f)
(f : X → ι → Y) (hf : ∀ i, IsContinuousLinearMap K (f · i))
: (fun x =>L[K] ∑ i, f x i)†
=
(fun y =>L[K] ∑ i, (fun x =>L[K] f x i)† y) :=
Expand Down
12 changes: 8 additions & 4 deletions SciLean/Core/Meta/GenerateFwdCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def generateFwdCDeriv (constName : Name) (mainNames trailingNames : Array Name)
mkLocalDecls (n:=TermElabM)
(mainNames.map (fun n => n.appendBefore "h"))
.default
(← mainArgs.mapM (fun arg => mkAppM ``IsDifferentiable #[K,arg]))
(← mainArgs.mapM (fun arg => do
lambdaTelescope (← etaExpand arg) fun xs b => do
let f := (← mkLambdaFVars #[xs[0]!] b).eta
let prop ← mkAppM ``IsDifferentiable #[K,f]
mkForallFVars xs[1:] prop))

withLocalDecls decls fun mainArgProps => do

Expand All @@ -73,11 +77,11 @@ def generateFwdCDeriv (constName : Name) (mainNames trailingNames : Array Name)
let (rhs, proof) ← elabConvRewrite rhs conv
let isDiffProof ← elabProof isDiff tac

let .lam _ _ (.lam _ _ rhsBody _) _ := rhs
| throwError "unexpected result after function transformation, expecting `fun w dw => ...` but got\n{←ppExpr rhs}"
-- let .lam _ _ (.lam _ _ rhsBody _) _ := rhs
-- | throwError "unexpected result after function transformation, expecting `fun w dw => ...` but got\n{←ppExpr rhs}"

withLocalDecl `dw .default W fun dw => do
let rhsBody := rhsBody.instantiate #[dw,w]
let rhsBody := (mkAppN rhs #[w,dw]).headBeta -- rhsBody.instantiate #[dw,w]
let dargs ← mainArgs.mapM (fun arg => mkAppM ``fwdCDeriv #[K,arg,w,dw])

let fwdDerivFun ←
Expand Down
7 changes: 6 additions & 1 deletion SciLean/Core/Meta/GenerateRevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ def generateRevCDeriv (constName : Name) (mainNames trailingNames : Array Name)
mkLocalDecls (n:=TermElabM)
(mainNames.map (fun n => n.appendBefore "h"))
.default
(← mainArgs.mapM (fun arg => mkAppM ``HasAdjDiff #[K,arg]))
(← mainArgs.mapM (fun arg => do
lambdaTelescope (← etaExpand arg) fun xs b => do
let f := (← mkLambdaFVars #[xs[0]!] b).eta
let prop ← mkAppM ``HasAdjDiff #[K,f]
mkForallFVars xs[1:] prop))


withLocalDecls decls fun mainArgProps => do

Expand Down
10 changes: 10 additions & 0 deletions SciLean/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def getExplicitArgs (e : Expr) : MetaM (Option (Name×Array Expr)) := do
return (funName, explicitArgs)


/-- Eta expansion, but adds at most `n` binders
-/
def etaExpandN (e : Expr) (n : Nat) : MetaM Expr :=
withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs[0:n] (mkAppN e xs[0:n])

/-- Eta expansion, it also beta reduces the body
-/
def etaExpand' (e : Expr) : MetaM Expr :=
withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs (mkAppN e xs).headBeta


/--
Same as `mkAppM` but does not leave trailing implicit arguments.
Expand Down
Loading

0 comments on commit d18a62b

Please sign in to comment.