Skip to content

Commit

Permalink
derivative rules for matrix solve
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2024
1 parent d52889b commit e96f325
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 0 deletions.
8 changes: 8 additions & 0 deletions SciLean/Data/DataArray/Operations.lean
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def det {R} [RealScalar R] [PlainDataType R] (A : R^[I,I]) : R :=
let f := LinearMap.mk' R (fun x : R^[I] => (⊞ i => ∑ j, A[i,j] * x[j])) sorry_proof
LinearMap.det f


namespace Matrix

instance : HMul (R^[I,J]) (R^[J,K]) (R^[I,K]) where
Expand All @@ -153,6 +154,13 @@ instance : HPow (R^[I,I]) ℤ (R^[I,I]) where

end Matrix

noncomputable
def solve (A : R^[I,I]) (b : R^[I]) := A⁻¹ * b

noncomputable
def solve' (A : R^[I,I]) (B : R^[I,J]) := A⁻¹ * B


set_default_scalar R

open Scalar
Expand Down
101 changes: 101 additions & 0 deletions SciLean/Data/DataArray/Operations/Solve.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import SciLean.Data.DataArray.Operations.Inv
import SciLean.Data.DataArray.Operations.Vecmul

namespace SciLean

variable
{I : Type*} [IndexType I] [DecidableEq I]
{J : Type*} [IndexType J] [DecidableEq J]
{R : Type*} [RealScalar R] [PlainDataType R]
{X : Type*} [NormedAddCommGroup X] [NormedSpace R X]
{U : Type*} [NormedAddCommGroup U] [AdjointSpace R U] [CompleteSpace U]

namespace DataArrayN

set_option linter.unusedVariables false

variable
(A : X → R^[I,I]) (b : X → R^[I])
(hA : Differentiable R A) (hA' : ∀ x, (A x).Invertible)
(hb : Differentiable R b)


@[fun_prop]
theorem solve.arg_Ab.IsContinuousLinearMap_rule (A : R^[I,I]) (hb : IsContinuousLinearMap R b) :
IsContinuousLinearMap R (fun x => A.solve (b x)) := by unfold solve; fun_prop

include hA hA' hb in
@[fun_prop]
theorem solve.arg_Ab.Differentiable_rule :
Differentiable R (fun x => (A x).solve (b x)) := by unfold solve; fun_prop (disch:=apply hA')


include hA hA' hb in
@[fun_trans]
theorem solve.arg_A.fderiv_rule :
fderiv R (fun x => (A x).solve (b x))
=
fun x => fun dx =>L[R]
let dA := fderiv R A x dx
let db := fderiv R b x dx
let b := b x
let A := A x
(- A.solve (dA * A.solve b) + A.solve db) := by
unfold solve
conv =>
lhs
fun_trans (disch:=apply hA') only -- no idea why it does not work properly
sorry_proof



set_option trace.Meta.Tactic.fun_trans.rewrite true in
include hA hA' hb in
@[fun_trans]
theorem solve.arg_A.fwdFDeriv_rule :
fwdFDeriv R (fun x => (A x).solve (b x))
=
fun x dx =>
let AdA := fwdFDeriv R A x dx
let bdb := fwdFDeriv R b x dx
let A := AdA.1; let dA := AdA.2
let b := bdb.1; let db := bdb.2
let A' := A.inv
(A.solve b, - A.solve (dA * A.solve b) + A.solve db) := by
unfold solve; funext x dx
fun_trans (disch:=apply hA')
cases fwdFDeriv R A x dx; cases fwdFDeriv R b x dx;
dsimp
sorry_proof -- done up tomodulo associativity


@[fun_trans]
theorem solve.arg_Ab.revFDeriv_rule_matrix (A : U → R^[I,I]) (b : U → R^[I])
(hA : Differentiable R A) (hA' : ∀ x, (A x).Invertible)
(hb : Differentiable R b) :
revFDeriv R (fun x => (A x).solve (b x))
=
fun x =>
let AdA := revFDeriv R A x
let bdb := revFDeriv R b x
let A := AdA.1; let dA := AdA.2
let b := bdb.1; let db := bdb.2
let A' := Aᵀ
(A.solve b, fun y : R^[I] =>
let b' := A.solve b
let y' := A'.solve y
let du₁ := - dA (y'.outerprod b')
let du₂ := db y'
du₁ + du₂) := by

funext x
conv =>
lhs; unfold revFDeriv; dsimp; enter[2]
fun_trans (disch:=apply hA')
unfold solve
fun_trans

simp only [revFDeriv.revFDeriv_fst, Prod.mk.injEq, true_and]
funext y
have h : ∀ (A : R^[I,I]), A⁻¹ᵀ = Aᵀ⁻¹ := sorry_proof
simp[revFDeriv,solve,h]
61 changes: 61 additions & 0 deletions SciLean/Data/DataArray/Operations/Vecmul.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,64 @@ def_fun_trans vecmul in A x : revFDeriv R by
rw[revFDeriv_wrt_prod (by fun_prop)]
unfold revFDeriv
autodiff




----------------------------------------------------------------------------------------------------

#check zero_mul
#check mul_zero
namespace DataArrayN

variable
{I : Type*} [IndexType I] [DecidableEq I]
{J : Type*} [IndexType J] [DecidableEq J]
{K : Type*} [IndexType K] [DecidableEq K]
{R : Type*} [RealScalar R] [PlainDataType R]

@[simp, simp_core]
theorem identity_vecmul (x : R^[I]) : Matrix.identity (R:=R) (I:=I) * x = A := sorry

@[simp, simp_core]
theorem zero_vecmul (b : R^[J]) : (0 : R^[I,J]) * b = 0 := sorry

@[simp, simp_core]
theorem vecmul_zero (A : R^[I,J]) : A * (0 : R^[J]) = 0 := sorry

end DataArrayN


variable
{I : Type*} [IndexType I] [DecidableEq I]
{J : Type*} [IndexType J] [DecidableEq J]
{K : Type*} [IndexType K] [DecidableEq K]
{R : Type*} [RealScalar R] [PlainDataType R]


def_fun_prop _matrixvec (b : R^[J]) : IsContinuousLinearMap R (fun A : R^[I,J] => A * b) by
simp[HMul.hMul]; fun_prop

def_fun_prop _matrixvec (A : R^[I,J]) : IsContinuousLinearMap R (fun b : R^[J] => A * b) by
simp[HMul.hMul]; fun_prop

def_fun_prop _matrixvec : Differentiable R (fun Ab : R^[I,J] × R^[J] => Ab.1 * Ab.2) by
simp[HMul.hMul]; fun_prop

abbrev_fun_trans _matrixvec : fderiv R (fun Ab : R^[I,J] × R^[J] => Ab.1 * Ab.2) by
rw[fderiv_wrt_prod]
fun_trans

abbrev_fun_trans _matrixvec : fwdFDeriv R (fun Ab : R^[I,J] × R^[J] => Ab.1 * Ab.2) by
rw[fwdFDeriv_wrt_prod]; unfold fwdFDeriv; autodiff

abbrev_fun_trans _matrixvec (b : R^[J]) : adjoint R (fun A : R^[I,J] => A * b) by
equals (fun c => c.outerprod b) =>
simp[HMul.hMul]; fun_trans; rfl

abbrev_fun_trans _matrixvec (A : R^[I,J]) : adjoint R (fun b : R^[J] => A * b) by
equals (fun c => Aᵀ * c) =>
simp[HMul.hMul]; fun_trans; rfl

abbrev_fun_trans _matrixvec : revFDeriv R (fun Ab : R^[I,J] × R^[J] => Ab.1 * Ab.2) by
rw[revFDeriv_wrt_prod]; unfold revFDeriv; autodiff

0 comments on commit e96f325

Please sign in to comment.