Skip to content

Commit

Permalink
basic operations on vectors and matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2024
1 parent cf41293 commit d52889b
Show file tree
Hide file tree
Showing 14 changed files with 889 additions and 94 deletions.
3 changes: 3 additions & 0 deletions SciLean/Data/DataArray/DataArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ def DataArrayN.curry [Inhabited α] (x : DataArrayN α (ι×κ)) : DataArrayN (D
def DataArrayN.uncurry [Inhabited α] (x : DataArrayN (DataArrayN α κ) ι) : DataArrayN α (ι×κ) :=
⟨⟨x.data.byteData, Size.size ι, sorry_proof⟩, sorry_proof⟩

theorem DataArrayN.uncurry_def [Inhabited α] (x : DataArrayN (DataArrayN α κ) ι) :
x.uncurry = ⊞ i j => x[i][j] := sorry_proof

set_option linter.dupNamespace false in
open Lean in
private partial def parseDimProd (s : Syntax) : TSyntaxArray `dimSpec :=
Expand Down
92 changes: 62 additions & 30 deletions SciLean/Data/DataArray/Operations.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,75 +60,99 @@ abbrev reshape5 (x : X^[I]) (n₁ n₂ n₃ n₄ n₅ : ℕ)


variable [DecidableEq I]
[Add X] [Sub X] [Mul X] [Zero X] [One X]

def _root_.SciLean.Matrix.identity : X^[I,I] :=
variable {R : Type*} [inst : RealScalar R] [PlainDataType R]
-- [Add X] [Sub X] [Mul X] [Zero X] [One X]

def _root_.SciLean.Matrix.identity : R^[I,I] :=
⊞ (i j : I) => if i = j then 1 else 0

def multiply (x y : X^[I]) : X^[I] :=
def multiply (x y : R^[I]) : R^[I] :=
x.mapIdxMono (fun i xi => xi * y[i])

def diag (x : X^[I]) : X^[I,I] :=
def diag (x : R^[I]) : R^[I,I] :=
⊞ i j => if i = j then x[i] else 0

def kronprod (x : X^[I]) (y : X^[J]) : X^[I,J] :=
def diagonal (x : R^[I,I]) : R^[I] :=
⊞ i => x[i,i]

def outerprod (x : R^[I]) (y : R^[J]) : R^[I,J] :=
⊞ i j => x[i]*y[j]

-- todo: maybe add complex conjugate
def dot (x y : X^[I]) : X := ∑ i, x[i]*y[i]
/-- Sum all elements of a vector, matrix, tensor: `x.sum = ∑ i, x[i]`-/
def sum (x : R^[I]) : R := ∑ i, x[i]

/-- Matrix transpose -/
def transpose (A : R^[I,J]) : R^[J,I] := ⊞ j i => A[i,j]

@[inherit_doc transpose]
postfix:max "ᵀ" => transpose

def vecmul (A : X^[I,J]) (x : X^[J]) : X^[I] := ⊞ i => ∑ j, A[i,j] * x[j]
/-- Matrix trace: `A.trace = ∑ i, A[i,i]` -/
def trace (A : R^[I,I]) : R := ∑ i, A[i,i]

/-- Dot product between vectors, matrices, tensors: `x.dot y = ∑ i, x[i] * y[i]` -/
def dot (x y : R^[I]) : R := ∑ i, x[i]*y[i]

/-- Matrix × vector multiplication: `A.vecmul x = ⊞ i => ∑ j, A[i,j] * x[j]` -/
def vecmul (A : R^[I,J]) (x : R^[J]) : R^[I] := ⊞ i => ∑ j, A[i,j] * x[j]

/-- Matrix × matrix multiplication: `A.vecmul B = ⊞ i k => ∑ j, A[i,j] * B[j,k]` -/
def matmul (A : R^[I,J]) (B : R^[J,K]) : R^[I,K] := ⊞ i k => ∑ j, A[i,j] * B[j,k]

def matmul (A : X^[I,J]) (B : X^[J,K]) : X^[I,K] := ⊞ i k => ∑ j, A[i,j] * B[j,k]

noncomputable
def inv (A : X^[I,I]) : X^[I,I] :=
(fun B : X^[I,I] => A.matmul B).invFun Matrix.identity

def npow (A : X^[I,I]) (n : ℕ) : X^[I,I] :=
if h : n = 0 then
Matrix.identity
else if _ : n = 1 then
A
else
have : n.log2 < n := by apply (Nat.log2_lt h).2; exact Nat.lt_two_pow n
def inv (A : R^[I,I]) : R^[I,I] :=
(fun B : R^[I,I] => A.matmul B).invFun Matrix.identity

/-- Invertible matrix proposition -/
def Invertible (A : R^[I,I]) : Prop := (fun B : R^[I,I] => A.matmul B).Bijective

def npow (A : R^[I,I]) (n : ℕ) : R^[I,I] :=
match n with
| 0 => Matrix.identity
| 1 => A
| n+2 =>
if n % 2 = 0 then
npow (A.matmul A) (n/2)
npow (A.matmul A) (n/2+1)
else
(npow (A.matmul A) (n/2)).matmul A
(npow (A.matmul A) (n/2+1)).matmul A


noncomputable
def zpow (A : X^[I,I]) (n : ℤ) : X^[I,I] :=
def zpow (A : R^[I,I]) (n : ℤ) : R^[I,I] :=
if 0 ≤ n then
A.npow n.toNat
else
A.inv.npow (-n).toNat

/-- Matrix determinant -/
noncomputable
def det {R} [RealScalar R] [PlainDataType R] (A : R^[I,I]) : R :=
let f := LinearMap.mk' R (fun x : R^[I] => (⊞ i => ∑ j, A[i,j] * x[j])) sorry_proof
LinearMap.det f

namespace Matrix

variable [Add X] [Mul X] [Sub X] [Zero X] [One X]

instance : HMul (X^[I,J]) (X^[J,K]) (X^[I,K]) where
instance : HMul (R^[I,J]) (R^[J,K]) (R^[I,K]) where
hMul A B := A.matmul B

instance : HMul (X^[I,J]) (X^[J]) (X^[I]) where
instance : HMul (R^[I,J]) (R^[J]) (R^[I]) where
hMul A x := A.vecmul x

instance : HPow (X^[I,I]) ℕ (X^[I,I]) where
instance : HPow (R^[I,I]) ℕ (R^[I,I]) where
hPow A n := A.npow n

noncomputable
instance : Inv (X^[I,I]) where
instance : Inv (R^[I,I]) where
inv A := A.inv

noncomputable
instance : HPow (X^[I,I]) ℤ (X^[I,I]) where
instance : HPow (R^[I,I]) ℤ (R^[I,I]) where
hPow A n := A.zpow n

end Matrix

variable {R : Type*} [RealScalar R] [PlainDataType R] [DecidableEq I]
set_default_scalar R

open Scalar
Expand All @@ -147,3 +171,11 @@ def softmax' (x dx : R^[I]) : R :=
def logsumexp (x : R^[I]) : R :=
let xmax := IndexType.maxD (x[·]) 0
log (∑ i, exp (x[i] - xmax)) - xmax

/-- Elementwise exponential -/
def exp (x : R^[I]) : R^[I] :=
x.mapMono (fun xi => Scalar.exp xi)

/-- Elementwise logarithm -/
def log (x : R^[I]) : R^[I] :=
x.mapMono (fun xi => Scalar.log xi)
60 changes: 60 additions & 0 deletions SciLean/Data/DataArray/Operations/Diag.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import SciLean.Data.DataArray.Operations.Multiply

namespace SciLean

section Missing


theorem sum_over_prod' {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
{f : I → J → R} : ∑ i j, f i j = ∑ (i : I×J), f i.1 i.2 := sorry

theorem sum_over_prod {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
{f : I×J → R} : ∑ i, f i = ∑ i j, f (i,j) := sorry

@[rsimp]
theorem sum_ite {R} [AddCommMonoid R] {I : Type*} [IndexType I] [DecidableEq I]
{f : I → R} (j : I) : (∑ i, if i = j then f i else 0) = f j := sorry

@[rsimp]
theorem sum_ite' {R} [AddCommMonoid R] {I : Type*} [IndexType I] [DecidableEq I]
{f : I → R} (j : I) : (∑ i, if j = i then f i else 0) = f j := sorry

theorem sum_swap {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
{f : I → J → R} : ∑ i j, f i j = ∑ j i, f i j := sorry

end Missing


variable
{I : Type*} [IndexType I] [DecidableEq I]
{R : Type*} [RealScalar R] [PlainDataType R]


open DataArrayN

def_fun_prop diag in x
with_transitive : IsContinuousLinearMap R

#generate_linear_map_simps DataArrayN.diag.arg_x.IsLinearMap_rule

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : fderiv R by autodiff

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : fwdFDeriv R by autodiff

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : adjoint R by
equals (fun x' => x'.diagonal) =>
funext x
apply AdjointSpace.ext_inner_left R
intro z
rw[← adjoint_ex _ (by fun_prop)]
simp[DataArrayN.inner_def,Function.HasUncurry.uncurry,
DataArrayN.diag,DataArrayN.diagonal,
sum_over_prod, sum_ite']

-- todo: change to abbrev_def_trans
def_fun_trans diag in x : revFDeriv R by
unfold revFDeriv
autodiff
42 changes: 42 additions & 0 deletions SciLean/Data/DataArray/Operations/Diagonal.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import SciLean.Data.DataArray.Operations.Diag

namespace SciLean



variable
{I : Type*} [IndexType I] [DecidableEq I]
{R : Type*} [RealScalar R] [PlainDataType R]

open DataArrayN


def_fun_prop diagonal in x
with_transitive
[RealScalar R] : IsContinuousLinearMap R

#generate_linear_map_simps DataArrayN.diagonal.arg_x.IsLinearMap_rule

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [RealScalar R] : fderiv R by
fun_trans

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [RealScalar R] : fwdFDeriv R by
autodiff

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [DecidableEq I] [RealScalar R] : adjoint R by
equals (fun x' => x'.diag) =>
funext x
apply AdjointSpace.ext_inner_left R
intro z
rw[← adjoint_ex _ (by fun_prop)]
simp[DataArrayN.inner_def,Function.HasUncurry.uncurry,
DataArrayN.diagonal,DataArrayN.diag,
sum_over_prod, sum_ite']

-- todo: change to abbrev_def_trans
def_fun_trans diagonal in x [DecidableEq I] [RealScalar R] : revFDeriv R by
unfold revFDeriv
autodiff
51 changes: 51 additions & 0 deletions SciLean/Data/DataArray/Operations/Dot.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import SciLean.Data.DataArray.Operations.Diag

namespace SciLean

open DataArrayN

def_fun_prop dot in x with_transitive : IsContinuousLinearMap R
def_fun_prop dot in y with_transitive : IsContinuousLinearMap R
def_fun_prop dot in x y with_transitive : Differentiable R

#generate_linear_map_simps DataArrayN.dot.arg_x.IsLinearMap_rule
#generate_linear_map_simps DataArrayN.dot.arg_y.IsLinearMap_rule


-- todo: change to abbrev_def_trans
def_fun_trans dot in x y : fderiv R by
rw[fderiv_wrt_prod (by fun_prop)]
fun_trans

-- todo: change to abbrev_def_trans
def_fun_trans dot in x y : fwdFDeriv R by
rw[fwdFDeriv_wrt_prod (by fun_prop)]
autodiff

-- todo: change to abbrev_def_trans
def_fun_trans dot in x : adjoint R by
equals (fun z => z•y) =>
funext x
apply AdjointSpace.ext_inner_left R
intro z
rw[← adjoint_ex _ (by fun_prop)]
simp[DataArrayN.inner_def, DataArrayN.dot,
sum_over_prod, Function.HasUncurry.uncurry, sum_pull]
ac_rfl

-- todo: change to abbrev_def_trans
def_fun_trans dot in y : adjoint R by
equals (fun z => z•x) =>
funext y
apply AdjointSpace.ext_inner_left R
intro z
rw[← adjoint_ex _ (by fun_prop)]
simp[DataArrayN.inner_def, DataArrayN.dot,
sum_over_prod, Function.HasUncurry.uncurry, sum_pull]
ac_rfl

-- todo: change to abbrev_def_trans
def_fun_trans dot in x y : revFDeriv R by
rw[revFDeriv_wrt_prod (by fun_prop)]
unfold revFDeriv
autodiff
Loading

0 comments on commit d52889b

Please sign in to comment.