From b63767f4d6c196f4c4fe902c1abdfa72f61083b7 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Wed, 27 Nov 2024 16:02:47 -0500 Subject: [PATCH] clean up of array operations --- SciLean/Data/DataArray/Operations.lean | 8 +++++-- .../Data/DataArray/Operations/Logsumexp.lean | 3 +-- examples/GaussianMixtureModel.lean | 23 +++++++------------ 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/SciLean/Data/DataArray/Operations.lean b/SciLean/Data/DataArray/Operations.lean index 261ab794..9ad3d94f 100644 --- a/SciLean/Data/DataArray/Operations.lean +++ b/SciLean/Data/DataArray/Operations.lean @@ -240,13 +240,17 @@ set_default_scalar R open Scalar +/-- Softmax turns array into an array of values in (0,1) -/ def softmax (x : R^[I]) : R^[I] := - let xmax := x.maxD 0 + let xmax := x.max let w := ∑ i, exp (x[i] - xmax) ⊞ i => exp (x[i] - xmax) / w +/-- Logarithm of sum of exponentials, its derivative is softmax. + +Common when doing maximul likelihood. -/ def logsumexp (x : R^[I]) : R := - let xmax := IndexType.maxD (x[·]) 0 + let xmax := x.max log (∑ i, exp (x[i] - xmax)) + xmax /-- Elementwise exponential -/ diff --git a/SciLean/Data/DataArray/Operations/Logsumexp.lean b/SciLean/Data/DataArray/Operations/Logsumexp.lean index d6ccef31..735676fa 100644 --- a/SciLean/Data/DataArray/Operations/Logsumexp.lean +++ b/SciLean/Data/DataArray/Operations/Logsumexp.lean @@ -11,11 +11,10 @@ variable set_default_scalar R -/-- Softmax with awful numerical properties but nice for proving theorems. -/ +/-- Logsumexp with awful numerical properties but nice for proving theorems. -/ def logsumexpSpec (x : R^[I]) : R := Scalar.log (∑ i, Scalar.exp (x[i])) - theorem logsumexp_spec (x : R^[I]) : logsumexp x = logsumexpSpec x := sorry_proof def_fun_prop logsumexp in x : Differentiable R by diff --git a/examples/GaussianMixtureModel.lean b/examples/GaussianMixtureModel.lean index 8117f115..9cde7f72 100644 --- a/examples/GaussianMixtureModel.lean +++ b/examples/GaussianMixtureModel.lean @@ -10,20 +10,12 @@ set_default_scalar R variable {D N K : ℕ} -notation "π" => @RealScalar.pi defaultScalar% inferInstance - -@[app_unexpander RealScalar.pi] def unexpandPi : Lean.PrettyPrinter.Unexpander - | `($_) => `(π) - - -#check |(1:ℝ)| namespace SciLean.MatrixOperations - @[scoped simp, scoped simp_core] -theorem matrix_inverse_inverse {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) : - (A⁻¹)⁻¹ = A := sorry +theorem matrix_inverse_inverse {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) (hA : A.Invertible) : + (A⁻¹)⁻¹ = A := by simp[hA] @[scoped simp, scoped simp_core] theorem det_inv_eq_inv_det {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) : @@ -99,12 +91,16 @@ def Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : R^[D,D] := def w (α : R^[K]) : R^[K] := ⊞ i => exp α[i] / ∑ k, exp α[k] + @[simp, simp_core] theorem det_Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : (Q q l).det = exp q.sum := sorry @[simp, simp_core] theorem det_QTQ (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * (Q q l)).det = exp (2 * q.sum) := sorry +@[simp, simp_core] +theorem QTQ_invertible (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * (Q q l)).Invertible := sorry + @[simp, simp_core] theorem trace_QTQ (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * Q q l).trace @@ -235,17 +231,14 @@ theorem sum_normalize (x : R^[I]) : ∑ i, x[i] = sum x := rfl @[rsimp] theorem norm_normalize (x : R^[I]) : ∑ i, ‖x[i]‖₂² = ‖x‖₂² := rfl -theorem sum_over_prod {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J] - {f : I → J → R} : ∑ i j, f i j = ∑ (i : I×J), f i.1 i.2 := sorry +-- theorem sum_over_prod {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J] +-- {f : I → J → R} : ∑ i j, f i j = ∑ (i : I×J), f i.1 i.2 := sorry @[rsimp] theorem isum_sum (x : R^[I]^[J]) : ∑ i, x[i].sum = x.uncurry.sum := by simp[DataArrayN.uncurry_def,DataArrayN.sum,Function.HasUncurry.uncurry] rw[sum_over_prod] -theorem _root_.SciLean.DataArrayN.norm2_def {R : Type*} [RCLike R] {I} [IndexType I] {X} [PlainDataType X] [Inner R X] - (x : X^[I]) : ‖x‖₂²[R] = ∑ i, ‖x[i]‖₂²[R] := rfl - @[rsimp] theorem isum_norm_exp (x : R^[I]^[J]) : ∑ j, ‖x[j].exp‖₂² = ‖x.uncurry.exp‖₂² := by