Skip to content

Commit

Permalink
properties of vector,matrix,tensor functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2024
1 parent e96f325 commit f5e257f
Show file tree
Hide file tree
Showing 20 changed files with 653 additions and 210 deletions.
52 changes: 46 additions & 6 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ section OnNormedSpaces
variable [NormedAddCommGroup Elem] [NormedSpace K Elem]
{W : Type*} [NormedAddCommGroup W] [NormedSpace K W]

theorem ArrayType.differentiable_elemwise
(cont : W → Cont) :
(∀ i, Differentiable K (fun w => ArrayType.get (cont w) i))
Differentiable K (fun w => cont w) := sorry_proof

theorem ArrayType.fwdFDeriv_elemwise
(cont : W → Cont) :
Expand All @@ -374,13 +379,43 @@ theorem ArrayType.fwdFDeriv_elemwise
(cont w,
ArrayType.ofFn (Elem:=Elem) (Cont:=Cont) fun i =>
let xdx := fwdFDeriv K (fun w => ArrayType.get (cont w) i) w dw
xdx.2) := sorry
xdx.2) := sorry_proof


@[fun_prop]
theorem ArrayType.mapIdxMono.arg_fcont.IsContinuousLinearMap_rule
(cont : W → Cont) (hcont : IsContinuousLinearMap K cont)
(f : W → Idx → Elem → Elem) (hf : ∀ i, IsContinuousLinearMap K ↿(f · i ·)) :
(IsContinuousLinearMap K fun w : W => mapIdxMono (f w) (cont w)) := sorry_proof

theorem DataArrayN.mapMono.arg_fcont.fwdFDeriv_rule
-- todo: add `DifferentiableAt` version
@[fun_prop]
theorem ArrayType.mapMono.arg_fcont.Differentiable_rule
(cont : W → Cont) (hcont : Differentiable K cont)
(f : W → Elem → Elem) (hf : Differentiable K ↿f) :
Differentiable K fun w : W => mapMono (f w) (cont w) := by
apply ArrayType.differentiable_elemwise
simp; fun_prop

@[fun_trans]
theorem ArrayType.mapMono.arg_fcont.fderiv_rule
(cont : W → Cont) (hcont : Differentiable K cont)
(f : W → Elem → Elem) (hf : Differentiable K fun (w,x) => f w x) :
(fwdFDeriv K fun w : W => ArrayType.mapMono (f w) (cont w) )
(f : W → Elem → Elem) (hf : Differentiable K ↿f) :
(fderiv K fun w : W => mapMono (f w) (cont w) )
=
fun w => ContinuousLinearMap.mk' K (hf:=sorry_proof) fun dw =>
let c := cont w
let dc := fderiv K cont w dw
ArrayType.mapIdxMono (cont:=dc) (fun i dxi =>
let xi := ArrayType.get c i
let ydy := fwdFDeriv K (↿f) (w,xi) (dw,dxi)
ydy.2) := sorry_proof

@[fun_trans]
theorem ArrayType.mapMono.arg_fcont.fwdFDeriv_rule
(cont : W → Cont) (hcont : Differentiable K cont)
(f : W → Elem → Elem) (hf : Differentiable K ↿f) :
(fwdFDeriv K fun w : W => mapMono (f w) (cont w) )
=
fun w dw =>
let cdc := fwdFDeriv K cont w dw
Expand All @@ -393,8 +428,13 @@ theorem DataArrayN.mapMono.arg_fcont.fwdFDeriv_rule

funext w dw
rw[ArrayType.fwdFDeriv_elemwise]
fun_trans[Function.HasUncurry.uncurry]
constructor <;> (apply ArrayType.ext (Idx:=Idx); intro i; simp[fwdFDeriv])
simp
constructor
· apply ArrayType.ext (Idx:=Idx); intro i; rfl
· apply ArrayType.ext (Idx:=Idx); intro i
fun_trans [fwdFDeriv]
rfl



end OnNormedSpaces
Expand Down
46 changes: 39 additions & 7 deletions SciLean/Data/DataArray/Operations.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ abbrev reshape5 (x : X^[I]) (n₁ n₂ n₃ n₄ n₅ : ℕ)
x.reshape (Fin n₁ × Fin n₂ × Fin n₃ × Fin n₄ × Fin n₅) (by simp[h]; ac_rfl)


----------------------------------------------------------------------------------------------------
-- Basic Linear Algebra Operations -----------------------------------------------------------------
----------------------------------------------------------------------------------------------------

variable [DecidableEq I]

variable {R : Type*} [inst : RealScalar R] [PlainDataType R]
Expand Down Expand Up @@ -154,12 +158,46 @@ instance : HPow (R^[I,I]) ℤ (R^[I,I]) where

end Matrix

/-- Inverse of transpose matrix `A⁻ᵀ = Aᵀ⁻¹`
Tranpose and inversion commute, i.e. `Aᵀ⁻¹ = A⁻¹ᵀ`, we prefer `Aᵀ⁻¹` and `simp` by default rewrites
`A⁻¹ᵀ` to `Aᵀ⁻¹`. -/
macro:max A:term "⁻ᵀ" :term => `($Aᵀ⁻¹)

@[app_unexpander Inv.inv]
def _root_.Inv.inv.unexpander : Lean.PrettyPrinter.Unexpander
| `($_ $A) =>
match A with
| `($Aᵀ) => `($A⁻ᵀ)
| _ => `($A⁻¹)
| _ => throw ()


noncomputable
def solve (A : R^[I,I]) (b : R^[I]) := A⁻¹ * b

noncomputable
def solve' (A : R^[I,I]) (B : R^[I,J]) := A⁻¹ * B

/-- Rank polymorphic solve -/
class Solve (R : Type*) (I : Type*) (J : Type*)
[RealScalar R] [PlainDataType R] [IndexType I] [IndexType J] where
/-- Linear system solve that accepts either vector or matrix as right hand side. -/
solve (A : R^[I,I]) (b : R^[J]) : R^[J]

noncomputable
instance : Solve R I I where
solve A b := A.solve b

noncomputable
instance : Solve R I (I×J) where
solve A B := A.solve' B



----------------------------------------------------------------------------------------------------
-- Commong Nonlinear Operations --------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

set_default_scalar R

Expand All @@ -170,15 +208,9 @@ def softmax (x : R^[I]) : R^[I] :=
let w := ∑ i, exp (x[i] - xmax)
⊞ i => exp (x[i] - xmax) / w

def softmax' (x dx : R^[I]) : R :=
let xmax := x.maxD 0
let w := ∑ i, exp (x[i] - xmax)
let z := ∑ i, dx[i] * exp (x[i] - xmax)
z / w

def logsumexp (x : R^[I]) : R :=
let xmax := IndexType.maxD (x[·]) 0
log (∑ i, exp (x[i] - xmax)) - xmax
log (∑ i, exp (x[i] - xmax)) + xmax

/-- Elementwise exponential -/
def exp (x : R^[I]) : R^[I] :=
Expand Down
20 changes: 20 additions & 0 deletions SciLean/Data/DataArray/Operations/Det.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import SciLean.Data.DataArray.Operations.Trace
import SciLean.Data.DataArray.Operations.Matmul

namespace SciLean

open DataArrayN

def_fun_prop det in A
with_transitive : Differentiable R by sorry_proof

abbrev_fun_trans det in A [DecidableEq I] : fderiv R by
-- Jacobi's formula: https://en.wikipedia.org/wiki/Jacobi%27s_formula
equals (fun A => fun dA =>L[R] A.det * (A⁻¹ * dA).trace) =>
sorry_proof

abbrev_fun_trans det in A [DecidableEq I] : fwdFDeriv R by unfold fwdFDeriv; autodiff

abbrev_fun_trans det in A [DecidableEq I] : revFDeriv R by
unfold revFDeriv
fun_trans
34 changes: 4 additions & 30 deletions SciLean/Data/DataArray/Operations/Diag.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,6 @@ import SciLean.Data.DataArray.Operations.Multiply

namespace SciLean

section Missing


theorem sum_over_prod' {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
{f : I → J → R} : ∑ i j, f i j = ∑ (i : I×J), f i.1 i.2 := sorry

theorem sum_over_prod {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
{f : I×J → R} : ∑ i, f i = ∑ i j, f (i,j) := sorry

@[rsimp]
theorem sum_ite {R} [AddCommMonoid R] {I : Type*} [IndexType I] [DecidableEq I]
{f : I → R} (j : I) : (∑ i, if i = j then f i else 0) = f j := sorry

@[rsimp]
theorem sum_ite' {R} [AddCommMonoid R] {I : Type*} [IndexType I] [DecidableEq I]
{f : I → R} (j : I) : (∑ i, if j = i then f i else 0) = f j := sorry

theorem sum_swap {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
{f : I → J → R} : ∑ i j, f i j = ∑ j i, f i j := sorry

end Missing


variable
{I : Type*} [IndexType I] [DecidableEq I]
Expand All @@ -37,14 +15,11 @@ def_fun_prop diag in x

#generate_linear_map_simps DataArrayN.diag.arg_x.IsLinearMap_rule

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : fderiv R by autodiff
abbrev_fun_trans diag in x : fderiv R by autodiff

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : fwdFDeriv R by autodiff
abbrev_fun_trans diag in x : fwdFDeriv R by autodiff

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : adjoint R by
abbrev_fun_trans diag in x : adjoint R by
equals (fun x' => x'.diagonal) =>
funext x
apply AdjointSpace.ext_inner_left R
Expand All @@ -54,7 +29,6 @@ def_fun_trans diag in x : adjoint R by
DataArrayN.diag,DataArrayN.diagonal,
sum_over_prod, sum_ite']

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : revFDeriv R by
abbrev_fun_trans diag in x : revFDeriv R by
unfold revFDeriv
autodiff
12 changes: 4 additions & 8 deletions SciLean/Data/DataArray/Operations/Diagonal.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@ def_fun_prop diagonal in x

#generate_linear_map_simps DataArrayN.diagonal.arg_x.IsLinearMap_rule

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [RealScalar R] : fderiv R by
abbrev_fun_trans diagonal in x [RealScalar R] : fderiv R by
fun_trans

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [RealScalar R] : fwdFDeriv R by
abbrev_fun_trans diagonal in x [RealScalar R] : fwdFDeriv R by
autodiff

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [DecidableEq I] [RealScalar R] : adjoint R by
abbrev_fun_trans diagonal in x [DecidableEq I] [RealScalar R] : adjoint R by
equals (fun x' => x'.diag) =>
funext x
apply AdjointSpace.ext_inner_left R
Expand All @@ -36,7 +33,6 @@ def_fun_trans diagonal in x [DecidableEq I] [RealScalar R] : adjoint R by
DataArrayN.diagonal,DataArrayN.diag,
sum_over_prod, sum_ite']

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [DecidableEq I] [RealScalar R] : revFDeriv R by
abbrev_fun_trans diagonal in x [DecidableEq I] [RealScalar R] : revFDeriv R by
unfold revFDeriv
autodiff
15 changes: 5 additions & 10 deletions SciLean/Data/DataArray/Operations/Dot.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,15 @@ def_fun_prop dot in x y with_transitive : Differentiable R
#generate_linear_map_simps DataArrayN.dot.arg_y.IsLinearMap_rule


-- todo: change to abbrev_def_trans
def_fun_trans dot in x y : fderiv R by
abbrev_fun_trans dot in x y : fderiv R by
rw[fderiv_wrt_prod (by fun_prop)]
fun_trans

-- todo: change to abbrev_def_trans
def_fun_trans dot in x y : fwdFDeriv R by
abbrev_fun_trans dot in x y : fwdFDeriv R by
rw[fwdFDeriv_wrt_prod (by fun_prop)]
autodiff

-- todo: change to abbrev_def_trans
def_fun_trans dot in x : adjoint R by
abbrev_fun_trans dot in x : adjoint R by
equals (fun z => z•y) =>
funext x
apply AdjointSpace.ext_inner_left R
Expand All @@ -33,8 +30,7 @@ def_fun_trans dot in x : adjoint R by
sum_over_prod, Function.HasUncurry.uncurry, sum_pull]
ac_rfl

-- todo: change to abbrev_def_trans
def_fun_trans dot in y : adjoint R by
abbrev_fun_trans dot in y : adjoint R by
equals (fun z => z•x) =>
funext y
apply AdjointSpace.ext_inner_left R
Expand All @@ -44,8 +40,7 @@ def_fun_trans dot in y : adjoint R by
sum_over_prod, Function.HasUncurry.uncurry, sum_pull]
ac_rfl

-- todo: change to abbrev_def_trans
def_fun_trans dot in x y : revFDeriv R by
abbrev_fun_trans dot in x y : revFDeriv R by
rw[revFDeriv_wrt_prod (by fun_prop)]
unfold revFDeriv
autodiff
29 changes: 29 additions & 0 deletions SciLean/Data/DataArray/Operations/Exp.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import SciLean.Data.DataArray.Operations.Multiply
import SciLean.Data.ArrayType.Properties
import SciLean.Analysis.SpecialFunctions.Exp

namespace SciLean.DataArrayN

open Scalar

variable
{R : Type*} [RealScalar R] [PlainDataType R]
{I : Type*} [IndexType I]

set_default_scalar R


def_fun_prop exp in x : Differentiable R by
unfold exp; fun_prop

abbrev_fun_trans DataArrayN.exp in x : fderiv R by
equals (fun x => fun dx =>L[R] dx.multiply x.exp) =>
fun_trans[multiply,exp,Function.HasUncurry.uncurry]

abbrev_fun_trans DataArrayN.exp in x : fwdFDeriv R by
unfold fwdFDeriv
autodiff

abbrev_fun_trans DataArrayN.exp in x [DecidableEq I] : revFDeriv R by
unfold revFDeriv
autodiff
33 changes: 33 additions & 0 deletions SciLean/Data/DataArray/Operations/Log.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import SciLean.Data.DataArray.Operations.Multiply
import SciLean.Data.ArrayType.Properties
import SciLean.Analysis.SpecialFunctions.Log

namespace SciLean.DataArrayN

open Scalar

variable
{R : Type*} [RealScalar R] [PlainDataType R]
{W} [NormedAddCommGroup W] [NormedSpace R W]
{U} [NormedAddCommGroup U] [AdjointSpace R U] [CompleteSpace U]
{I : Type*} [IndexType I]

set_default_scalar R

def_fun_prop (x : W → R^[I]) (hx : Differentiable R x) (hx' : ∀ w i, (x w)[i] ≠ 0) :
Differentiable R (fun w => (x w).log) by
unfold log
intro x
fun_prop (disch:=sorry_proof)

-- abbrev_fun_trans DataArrayN.log in x : fderiv R by
-- equals (fun x => fun dx =>L[R] dx.multiply x.exp) =>
-- fun_trans[multiply,exp,Function.HasUncurry.uncurry]

-- abbrev_fun_trans DataArrayN.exp in x : fwdFDeriv R by
-- unfold fwdFDeriv
-- autodiff

-- abbrev_fun_trans DataArrayN.exp in x [DecidableEq I] : revFDeriv R by
-- unfold revFDeriv
-- autodiff
Loading

0 comments on commit f5e257f

Please sign in to comment.