diff --git a/examples/GaussianMixtureModel.lean b/examples/GaussianMixtureModel.lean index 60bb448d..426cf0ac 100644 --- a/examples/GaussianMixtureModel.lean +++ b/examples/GaussianMixtureModel.lean @@ -1,15 +1,72 @@ import SciLean import SciLean.Data.DataArray.Operations.Simps import SciLean.Lean.Meta.Basic +import SciLean.Tactic.StructureDecomposition -open SciLean Scalar +open SciLean Scalar SciLean.Meta + +open Lean +partial def asdf (e : Expr) : Bool := Id.run do + match e with + | .lam _ t b _ => + if t.isAppOfArity' ``Prod 2 then + return true + else + return asdf b + | .mdata _ e => asdf e + | _ => return false + + +open Lean Meta in +dsimproc_decl prodFunSimproc (_) := fun e => do + + -- unless asdf e do return .continue + unless e.isLambda do return .continue + + lambdaTelescope e fun xs b => do + -- check if lambda has been already processed + if let .letE _ _ (.mdata d _) _ _ := b then + if .some true = d.get? `prodFunSimproc then + return .continue + + let e' ← xs.foldrM (init:=b) fun x b => do + let a ← splitStructureElem x + let xs := a.1 + let mk := a.2 + + if xs.size = 1 then + mkLambdaFVars #[x] b + else + let xname ← x.fvarId!.getUserName + -- mark values with mdata to preven infinite loop + let data := MData.empty.set `prodFunSimproc true + let xs ← xs.mapM (fun x => do pure (Expr.mdata data x)) + + withLetDecls (xs.mapIdx (fun i _ => xname.appendAfter (toString i))) xs fun vars => do + let x' := mk.beta vars + let b' := b.replaceFVar x x' + -- let r ← Simp.simp b' + let vars := #[x] ++ vars + mkLambdaFVars vars b' + return .continue e' + + +#check (fun (x : ℝ×ℝ×ℝ) => x.1 + x.2.1 + x.2.2) rewrite_by simp -zeta only [prodFunSimproc] +#check (fun (i : ℕ) => i) rewrite_by lsimp [↓prodFunSimproc] +#check (fun (i : ℕ×ℕ) => i.1 + i.2) rewrite_by lsimp [prodFunSimproc] +#check (fun (j : ℕ×ℕ) (i : ℕ×ℕ) => i.1 + i.2 + j.1) rewrite_by lsimp [↓prodFunSimproc] +#check (fun (j : ℕ) (i : ℕ×ℕ) => i.1 + i.2) rewrite_by lsimp [prodFunSimproc] + +#check (fun (i : ℕ×ℕ) (j : ℕ) => i.1 + i.2) rewrite_by lsimp [prodFunSimproc] +#check (fun (x : ℕ×ℕ) (y : ℕ×ℕ) => x.1 + y.1) rewrite_by lsimp [prodFunSimproc] +#check (fun (i : ℕ) (x : ℕ×ℕ) (y : ℕ×ℕ) => x.1 + y.1) rewrite_by lsimp [prodFunSimproc] section Missing variable - {R} [RCLike R] - {X} [NormedAddCommGroup X] [NormedSpace R X] - {Y} [NormedAddCommGroup Y] [NormedSpace R Y] + {R : Type} [RCLike R] + {X : Type} [NormedAddCommGroup X] [NormedSpace R X] + {Y : Type} [NormedAddCommGroup Y] [NormedSpace R Y] @[fun_prop] theorem ite.arg_te.Differentiable_rule {c : Prop} [h : Decidable c] @@ -21,26 +78,24 @@ theorem dite.arg_te.Differentiable_rule {c : Prop} [h : Decidable c] (t : c → X → Y) (e : ¬c → X → Y) (ht : ∀ h, Differentiable R (t h)) (he : ∀ h, Differentiable R (e h)) : Differentiable R (fun x => dite c (t · x) (e · x)) := by split_ifs <;> aesop +@[simp, simp_core] +theorem oneHot_unit [Zero α] (i : Unit) (x : α) : oneHot i x = x := rfl + end Missing -variable {R} [RealScalar R] [PlainDataType R] +variable {R : Type} [RealScalar R] [PlainDataType R] set_default_scalar R -variable {D N K : ℕ} +#check (∂> (x : R×R×R), (let a := (x.1 + x.2.2); let b := a*x.1; a*b*x.2.1)) rewrite_by autodiff [↓ prodFunSimproc, ↑ prodFunSimproc] -namespace SciLean.MatrixOperations +variable {D N K : ℕ} -@[scoped simp, scoped simp_core] -theorem matrix_inverse_inverse {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) (hA : A.Invertible) : - (A⁻¹)⁻¹ = A := by simp[hA] -@[scoped simp, scoped simp_core] -theorem det_inv_eq_inv_det {I} [IndexType I] [DecidableEq I] (A : R^[I,I]) : - (A⁻¹).det = (A.det)⁻¹ := by simp +namespace SciLean.MatrixOperations -variable {I J K : Type*} [IndexType I] [IndexType J] [IndexType K] +variable {I J K : Type} [IndexType I] [IndexType J] [IndexType K] @[scoped simp, scoped simp_core] theorem inner_QQT (x y : R^[I]) (Q : R^[I,J]) : @@ -70,22 +125,6 @@ theorem gaussian_normalization_invQTQ {d : ℕ} (Q : R^[d,d]) : = (2 * π)^(-(d:R)/2) * Q.det := sorry --- -- not sure if is shoud be defined for `R^[I]` or `I → R` --- def logsumexp (x : R^[I]) : R:= --- let xmax := IndexType.maxD (x[·]) 0 --- log (∑ i, exp (x[i] - xmax)) - xmax - --- -- derivative of `logsumexp` is `softmax` --- -- related to `softmax` is `softmax' x y = ⟪softmax x, y⟫` --- def softmax' (x dx : R^[I]) : R := --- let xmax := IndexType.maxD (x[·]) 0 --- (∑ i, dx[i] * exp (x[i] - xmax)) / ∑ j, exp (x[j] - xmax) - --- -- gradient of `logsumexp` is `softmax` --- def softmax (x : R^[I]) : R^[I] := --- let xmax := IndexType.maxD (x[·]) 0 --- ⊞ i => exp (x[i] - xmax) / ∑ j, exp (x[j] - xmax) - theorem log_sum_exp (x : I → R) : log (∑ i, exp (x i)) = (⊞ i => x i).logsumexp := sorry end SciLean.MatrixOperations @@ -99,69 +138,33 @@ def likelihood (x : R^[D]^[N]) (w : R^[K]) (μ : R^[D]^[K]) (σ : R^[D,D]^[K]) : namespace Param -def lowerTriangularIndex (i j : Fin n) (h : i < j) : Fin (((n-1)*n)/2) := - ⟨i.1, sorry⟩ - -def Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : R^[D,D] := - ⊞ (i j : Fin D) => - if i = j then exp (q[i]) - else if i < j then l[lowerTriangularIndex i j sorry] - else 0 - +def Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : R^[D,D] := q.exp.diag + l.lowerTriangular D 1 -#check ite --- {α : Sort u} → (c : Prop) → [h : Decidable c] → α → α → α -example (l : R^[n]) : Differentiable R fun (q : R) (i : Fin m) => - if i.1 < 5 then q - else if h : i.1 < n then l[⟨i.1,h⟩] else 0 := by - set_option trace.Meta.Tactic.fun_prop true in - fun_prop +def_fun_prop Q in q l : Differentiable R +abbrev_fun_trans Q in q l : fwdFDeriv R by unfold Q; autodiff; lsimp only [↓prodFunSimproc] +abbrev_fun_trans Q in q l arg_subsets : revFDeriv R by unfold Q; autodiff; lsimp only [↓prodFunSimproc] +abbrev_fun_trans Q in q l arg_subsets : revFDerivProj R Unit by unfold Q; autodiff; lsimp only [↓prodFunSimproc] +abbrev_fun_trans Q in q l arg_subsets : revFDerivProjUpdate R Unit by unfold Q; autodiff; lsimp only [↓prodFunSimproc] -set_option trace.Meta.Tactic.fun_prop true in -example (l : R^[((D-1)*D)/2]) : Differentiable R fun (q : R^[D]) (p : Fin D × Fin D) => - if p.1 = p.2 then SciLean.Scalar.exp q[p.1] - else if h : p.1 < p.2 then l[Param.lowerTriangularIndex p.1 p.2 h] else 0 := by - fun_prop - -set_option trace.Meta.Tactic.fun_trans true in -def_fun_trans (l : R^[((D-1)*D)/2]) : fwdFDeriv R (fun q : R^[D] => Param.Q q l) by - unfold Q - simp[Function.HasUncurry.uncurry] - -- autodiff -- casuses panic - fun_trans - -#exit +variable (q : R^[D]) (l : R^[((D-1)*D)/2]) @[simp, simp_core] -theorem det_Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : (Q q l).det = exp q.sum := sorry - - - -@[simp, simp_core] -theorem det_QTQ (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * (Q q l)).det = exp (2 * q.sum) := sorry +theorem det_Q : (Q q l).det = exp q.sum := sorry @[simp, simp_core] -theorem QTQ_invertible (q : R^[D]) (l : R^[((D-1)*D)/2]) : ((Q q l)ᵀ * (Q q l)).Invertible := sorry +theorem det_QTQ : ((Q q l)ᵀ * (Q q l)).det = exp (2 * q.sum) := sorry @[simp, simp_core] -theorem trace_QTQ (q : R^[D]) (l : R^[((D-1)*D)/2]) : - ((Q q l)ᵀ * Q q l).trace - = ‖q.exp‖₂² + ‖l‖₂² := sorry +theorem QTQ_invertible : ((Q q l)ᵀ * (Q q l)).Invertible := sorry @[simp, simp_core] -theorem trace_QQT (q : R^[D]) (l : R^[((D-1)*D)/2]) : - (Q q l * (Q q l)ᵀ).trace - = ‖q.exp‖₂² + ‖l‖₂² := sorry +theorem trace_QTQ : ((Q q l)ᵀ * Q q l).trace = ‖q.exp‖₂² + ‖l‖₂² := sorry end Param - - - - open Param in noncomputable def likelihood' (x : R^[D]^[N]) (α : R^[K]) (μ : R^[D]^[K]) (q : R^[D]^[K]) (l : R^[((D-1)*D)/2]^[K]) : R := @@ -175,6 +178,8 @@ def prior (m : R) (σ : R^[D,D]^[K]) := ∏ k, /- C(D,m) -/ (σ[k].det)^m * exp theorem log_prod {I} [IndexType I] (x : I → R) : log (∏ i, x i) = ∑ i, log (x i) := sorry + + open Lean Meta /-- Take expression full of multiplications and divitions and split it into lists of multiplication and division factors. @@ -267,15 +272,14 @@ attribute [rsimp] SciLean.ArrayType.sum_ofFn theorem IndexType.sum_const {I} [IndexType I] (x : R) : (∑ (i : I), x) = (Size.size I : R) • x := sorry + @[simp_core] theorem neg_add_rev' {G : Type*} [SubtractionCommMonoid G] (a b : G) : -(a + b) = -a + -b := by simp[add_comm] -def sum (x : R^[I]) : R := ∑ i, x[i] - @[rsimp] -theorem sum_normalize (x : R^[I]) : ∑ i, x[i] = sum x := rfl +theorem sum_normalize (x : R^[I]) : ∑ i, x[i] = x.sum := rfl @[rsimp] theorem norm_normalize (x : R^[I]) : ∑ i, ‖x[i]‖₂² = ‖x‖₂² := rfl @@ -302,62 +306,38 @@ theorem isum_norm (x : R^[I]^[J]) : DataArrayN.norm2_def] rw[sum_over_prod] -open Param in + +open Param Scalar in noncomputable def loss (m : R) (x : R^[D]^[N]) (α : R^[K]) (μ : R^[D]^[K]) (q : R^[D]^[K]) (l : R^[((D-1)*D)/2]^[K]) : R := let σ := ⊞ k => ((Q q[k] l[k])ᵀ * Q q[k] l[k])⁻¹ - (- log (likelihood x (α.softmax) μ σ /-* prior m σ -/)) + (- log (likelihood x (α.softmax) μ σ * prior m σ)) rewrite_by unfold likelihood - simp only [DataArrayN.softmax_spec, DataArrayN.softmaxSpec] - simp only [simp_core, likelihood, prior, σ] + simp only [simp_core, likelihood, prior, σ, DataArrayN.softmax_def] simp only [simp_core, mul_pull_from_sum, refinedRewritePost, sum_push, log_mul, log_prod, mul_exp, log_sum_exp, log_pow, log_div, log_inv] - simp only [simp_core] ring_nf -#exit - -def_fun_trans loss in α : fwdFDeriv R by - unfold loss - autodiff -def_fun_trans loss in α : revFDeriv R by - unfold loss - autodiff +set_option pp.deepTerms.threshold 10000 +set_option maxHeartbeats 10000000 -def_fun_trans loss in μ : fwdFDeriv R by - unfold loss - autodiff +macro "cleanup_pass" : conv => `(conv| lsimp (config:={singlePass:=true}) only [simp_core, ↓prodFunSimproc]) -def_fun_trans loss in α μ : fwdFDeriv R by +def_fun_trans loss in α μ q l : fwdFDeriv R by unfold loss - autodiff - - --- attribute [-simp_core] revFDeriv_on_DataArrayN - --- def_fun_trans loss in μ : revFDeriv R by --- unfold loss --- autodiff - - --- def_fun_trans loss in q : fwdFDeriv R by --- unfold loss --- autodiff - - --- def_fun_trans loss in l : fwdFDeriv R by --- unfold loss --- autodiff - - + autodiff (config:={singlePass:=true}) + cleanup_pass; cleanup_pass; cleanup_pass; cleanup_pass; cleanup_pass; +def_fun_trans loss in α μ q l : revFDerivProj R Unit by + unfold loss + autodiff (config:={singlePass:=true}) + cleanup_pass; cleanup_pass; cleanup_pass; cleanup_pass; cleanup_pass -set_option pp.raw true in -#check (1 : ℝ) -variable (x : Fin 10 → R) -#check (∑ i : Fin 10, ((-1:R)/2) * x i) rewrite_by simp [mul_pull_from_sum] -#check (∑ i : Fin 10, -(2* x i)) rewrite_by simp [mul_pull_from_sum] +def_fun_trans loss in α μ q l : revFDerivProjUpdate R Unit by + unfold loss + autodiff (config:={singlePass:=true}) + cleanup_pass; cleanup_pass; cleanup_pass; cleanup_pass; cleanup_pass;