Skip to content

Commit

Permalink
clean up of array operations
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2024
1 parent 57c0023 commit b63767f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 19 deletions.
8 changes: 6 additions & 2 deletions SciLean/Data/DataArray/Operations.lean
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,17 @@ set_default_scalar R

open Scalar

/-- Softmax turns array into an array of values in (0,1) -/
def softmax (x : R^[I]) : R^[I] :=
let xmax := x.maxD 0
let xmax := x.max
let w := ∑ i, exp (x[i] - xmax)
⊞ i => exp (x[i] - xmax) / w

/-- Logarithm of sum of exponentials, its derivative is softmax.
Common when doing maximul likelihood. -/
def logsumexp (x : R^[I]) : R :=
let xmax := IndexType.maxD (x[·]) 0
let xmax := x.max
log (∑ i, exp (x[i] - xmax)) + xmax

/-- Elementwise exponential -/
Expand Down
3 changes: 1 addition & 2 deletions SciLean/Data/DataArray/Operations/Logsumexp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ variable

set_default_scalar R

/-- Softmax with awful numerical properties but nice for proving theorems. -/
/-- Logsumexp with awful numerical properties but nice for proving theorems. -/
def logsumexpSpec (x : R^[I]) : R :=
Scalar.log (∑ i, Scalar.exp (x[i]))


theorem logsumexp_spec (x : R^[I]) : logsumexp x = logsumexpSpec x := sorry_proof

def_fun_prop logsumexp in x : Differentiable R by
Expand Down
23 changes: 8 additions & 15 deletions examples/GaussianMixtureModel.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,12 @@ set_default_scalar R

variable {D N K : ℕ}

notation "π" => @RealScalar.pi defaultScalar% inferInstance

@[app_unexpander RealScalar.pi] def unexpandPi : Lean.PrettyPrinter.Unexpander
| `($_) => `(π)


#check |(1:ℝ)|

namespace SciLean.MatrixOperations


@[scoped simp, scoped simp_core]
theorem matrix_inverse_inverse {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) :
(A⁻¹)⁻¹ = A := sorry
theorem matrix_inverse_inverse {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) (hA : A.Invertible) :
(A⁻¹)⁻¹ = A := by simp[hA]

@[scoped simp, scoped simp_core]
theorem det_inv_eq_inv_det {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) :
Expand Down Expand Up @@ -99,12 +91,16 @@ def Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : R^[D,D] :=

def w (α : R^[K]) : R^[K] := ⊞ i => exp α[i] / ∑ k, exp α[k]


@[simp, simp_core]
theorem det_Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : (Q q l).det = exp q.sum := sorry

@[simp, simp_core]
theorem det_QTQ (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * (Q q l)).det = exp (2 * q.sum) := sorry

@[simp, simp_core]
theorem QTQ_invertible (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * (Q q l)).Invertible := sorry

@[simp, simp_core]
theorem trace_QTQ (q : R^[D]) (l : R^[((D-1)*D)/2]) :
((Q q l)ᵀ * Q q l).trace
Expand Down Expand Up @@ -235,17 +231,14 @@ theorem sum_normalize (x : R^[I]) : ∑ i, x[i] = sum x := rfl
@[rsimp]
theorem norm_normalize (x : R^[I]) : ∑ i, ‖x[i]‖₂² = ‖x‖₂² := rfl

theorem sum_over_prod {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
{f : I → J → R} : ∑ i j, f i j = ∑ (i : I×J), f i.1 i.2 := sorry
-- theorem sum_over_prod {R} [AddCommMonoid R] {I J : Type*} [IndexType I] [IndexType J]
-- {f : I → J → R} : ∑ i j, f i j = ∑ (i : I×J), f i.1 i.2 := sorry

@[rsimp]
theorem isum_sum (x : R^[I]^[J]) : ∑ i, x[i].sum = x.uncurry.sum := by
simp[DataArrayN.uncurry_def,DataArrayN.sum,Function.HasUncurry.uncurry]
rw[sum_over_prod]

theorem _root_.SciLean.DataArrayN.norm2_def {R : Type*} [RCLike R] {I} [IndexType I] {X} [PlainDataType X] [Inner R X]
(x : X^[I]) : ‖x‖₂²[R] = ∑ i, ‖x[i]‖₂²[R] := rfl

@[rsimp]
theorem isum_norm_exp (x : R^[I]^[J]) :
∑ j, ‖x[j].exp‖₂² = ‖x.uncurry.exp‖₂² := by
Expand Down

0 comments on commit b63767f

Please sign in to comment.