Skip to content

Commit

Permalink
derivative rules for ArrayType.ofFn
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2024
1 parent b63767f commit 1d70152
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 9 deletions.
18 changes: 18 additions & 0 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import SciLean.Analysis.Calculus.RevFDerivProj
import SciLean.Analysis.Calculus.FwdFDeriv

import SciLean.Meta.GenerateAddGroupHomSimp
import SciLean.Meta.GenerateFunTrans

namespace SciLean

Expand Down Expand Up @@ -161,6 +162,15 @@ theorem ArrayType.modify.arg_contf.IsContinuousLinearMap_rule
-- bacause of this reason it can't apply `IsContinuousLinearMap.continuous`
sorry_proof

abbrev_fun_trans : fderiv K (fun f : Idx → Elem => ArrayType.ofFn (Cont:=Cont) f) by
fun_trans

@[fun_trans]
theorem ArrayType.ofFn.arg_f.fwdFDeriv_rule :
fwdFDeriv K (fun f : Idx → Elem => ArrayType.ofFn (Cont:=Cont) f)
=
fun f df => (ArrayType.ofFn (Cont:=Cont) f, ArrayType.ofFn (Cont:=Cont) df) := by fun_trans

-- TODO: add Differentiable, ContDiff for `modify` function

end OnNormedSpaces
Expand Down Expand Up @@ -262,6 +272,7 @@ theorem ArrayType.ofFn.arg_f.adjoint_rule :
=
fun c i => ArrayType.get c i := by sorry_proof


end OnAdjointSpace


Expand Down Expand Up @@ -308,6 +319,13 @@ theorem ArrayType.get.arg_cont.revFDerivProjUpdate_rule (i : Idx)
(ArrayType.get xi.1 i, fun (j : I) (de : E j) dw =>
xi.2 (i,j) de dw) := by unfold revFDerivProjUpdate; fun_trans

@[fun_trans]
theorem ArrayType.ofFn.arg_f.revFDeriv_rule :
revFDeriv K (fun f : Idx → Elem => ArrayType.ofFn f)
=
fun f =>
(ArrayType.ofFn (Cont:=Cont) f, fun (dx : Cont) i => ArrayType.get dx i) := by
unfold revFDeriv; fun_trans

@[fun_trans]
theorem ArrayType.ofFn.arg_f.revFDerivProj_rule_unit_simple :
Expand Down
107 changes: 98 additions & 9 deletions examples/GaussianMixtureModel.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@ import SciLean.Lean.Meta.Basic

open SciLean Scalar

section Missing

variable
{R} [RCLike R]
{X} [NormedAddCommGroup X] [NormedSpace R X]
{Y} [NormedAddCommGroup Y] [NormedSpace R Y]

@[fun_prop]
theorem ite.arg_te.Differentiable_rule {c : Prop} [h : Decidable c]
(t e : X → Y) (ht : Differentiable R t) (he : Differentiable R e) :
Differentiable R (fun x => ite c (t x) (e x)) := by split_ifs <;> assumption

@[fun_prop]
theorem dite.arg_te.Differentiable_rule {c : Prop} [h : Decidable c]
(t : c → X → Y) (e : ¬c → X → Y) (ht : ∀ h, Differentiable R (t h)) (he : ∀ h, Differentiable R (e h)) :
Differentiable R (fun x => dite c (t · x) (e · x)) := by split_ifs <;> aesop

end Missing

variable {R} [RealScalar R] [PlainDataType R]

set_default_scalar R
Expand Down Expand Up @@ -80,21 +99,47 @@ def likelihood (x : R^[D]^[N]) (w : R^[K]) (μ : R^[D]^[K]) (σ : R^[D,D]^[K]) :

namespace Param

def lowerTriangularIndex (i j : Fin n) (h : i < j) : Fin (((n-1)*n)/2) := sorry
def lowerTriangularIndex (i j : Fin n) (h : i < j) : Fin (((n-1)*n)/2) :=
⟨i.1, sorry

noncomputable
def Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : R^[D,D] :=
⊞ i j =>
if i = j then exp (q[i])
else if h : i < j then l[lowerTriangularIndex i j h]
(i j : Fin D) =>
if i = j then exp (q[i])
else if i < j then l[lowerTriangularIndex i j sorry]
else 0

def w (α : R^[K]) : R^[K] := ⊞ i => exp α[i] / ∑ k, exp α[k]

#check ite
-- {α : Sort u} → (c : Prop) → [h : Decidable c] → α → α → α


example (l : R^[n]) : Differentiable R fun (q : R) (i : Fin m) =>
if i.1 < 5 then q
else if h : i.1 < n then l[⟨i.1,h⟩] else 0 := by
set_option trace.Meta.Tactic.fun_prop true in
fun_prop


set_option trace.Meta.Tactic.fun_prop true in
example (l : R^[((D-1)*D)/2]) : Differentiable R fun (q : R^[D]) (p : Fin D × Fin D) =>
if p.1 = p.2 then SciLean.Scalar.exp q[p.1]
else if h : p.1 < p.2 then l[Param.lowerTriangularIndex p.1 p.2 h] else 0 := by
fun_prop

set_option trace.Meta.Tactic.fun_trans true in
def_fun_trans (l : R^[((D-1)*D)/2]) : fwdFDeriv R (fun q : R^[D] => Param.Q q l) by
unfold Q
simp[Function.HasUncurry.uncurry]
-- autodiff -- casuses panic
fun_trans

#exit

@[simp, simp_core]
theorem det_Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : (Q q l).det = exp q.sum := sorry



@[simp, simp_core]
theorem det_QTQ (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * (Q q l)).det = exp (2 * q.sum) := sorry

Expand All @@ -113,10 +158,14 @@ theorem trace_QQT (q : R^[D]) (l : R^[((D-1)*D)/2]) :

end Param





open Param in
noncomputable
def likelihood' (x : R^[D]^[N]) (α : R^[K]) (μ : R^[D]^[K]) (q : R^[D]^[K]) (l : R^[((D-1)*D)/2]^[K]) : R :=
likelihood x (w α) μ (⊞ k => ((Q q[k] l[k])ᵀ * Q q[k] l[k])⁻¹)
likelihood x (α.softmax) μ (⊞ k => ((Q q[k] l[k])ᵀ * Q q[k] l[k])⁻¹)
rewrite_by
unfold likelihood
simp
Expand Down Expand Up @@ -257,13 +306,53 @@ open Param in
noncomputable
def loss (m : R) (x : R^[D]^[N]) (α : R^[K]) (μ : R^[D]^[K]) (q : R^[D]^[K]) (l : R^[((D-1)*D)/2]^[K]) : R :=
let σ := ⊞ k => ((Q q[k] l[k])ᵀ * Q q[k] l[k])⁻¹
(- log (likelihood x (w α) μ σ * prior m σ))
(- log (likelihood x (α.softmax) μ σ /-* prior m σ -/))
rewrite_by
unfold likelihood
simp only [simp_core, likelihood, prior, σ, w]
simp only [DataArrayN.softmax_spec, DataArrayN.softmaxSpec]
simp only [simp_core, likelihood, prior, σ]
simp only [simp_core, mul_pull_from_sum, refinedRewritePost, sum_push,
log_mul, log_prod, mul_exp, log_sum_exp, log_pow, log_div, log_inv]
simp only [simp_core]
ring_nf

#exit

def_fun_trans loss in α : fwdFDeriv R by
unfold loss
autodiff

def_fun_trans loss in α : revFDeriv R by
unfold loss
autodiff


def_fun_trans loss in μ : fwdFDeriv R by
unfold loss
autodiff


def_fun_trans loss in α μ : fwdFDeriv R by
unfold loss
autodiff


-- attribute [-simp_core] revFDeriv_on_DataArrayN

-- def_fun_trans loss in μ : revFDeriv R by
-- unfold loss
-- autodiff


-- def_fun_trans loss in q : fwdFDeriv R by
-- unfold loss
-- autodiff


-- def_fun_trans loss in l : fwdFDeriv R by
-- unfold loss
-- autodiff




Expand Down

0 comments on commit 1d70152

Please sign in to comment.