Skip to content

Commit

Permalink
removed priority on forall instance for SemiInnerProductSpace
Browse files Browse the repository at this point in the history
 - it felt unnecessary but at the same time it altered gradient computation for
 2d convolution ... very odd
  • Loading branch information
lecopivo committed Sep 28, 2023
1 parent f3a70d6 commit b3d6e63
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 19 deletions.
22 changes: 9 additions & 13 deletions SciLean/Core/Objects/FinVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ The class `FinVec ι K X` guarantees that any element `x : X` can be writtens as
```
-/
class Basis (ι : outParam $ Type v) (K : outParam $ Type w)(X : Type u) where
basis : ι X
proj : ι → X → K
basis (i : ι) : X
proj (i : ι) (x : X) : K

/-- Dual basis of the space `X` over the field `K` indexed by `ι`
Expand All @@ -27,16 +27,16 @@ and that it is dual to the normal basis
```
-/
class DualBasis (ι : outParam $ Type v) (K : outParam $ Type w) (X : Type u) where
dualBasis : ι X
dualProj : ι → X → K
dualBasis (i : ι) : X
dualProj (i : ι) (x : X) : K

/-- This should somehow relate to raising and lowering indices but I forgot how.
TODO: add explanation why this is useful
-/
class BasisDuality (X : Type u) where
toDual : X X -- transforms basis vectors to dual basis vectors
fromDual : X X -- transforma dual basis vectors to basis vectors
toDual (x : X) : X -- transforms basis vectors to dual basis vectors
fromDual (x : X) : X -- transforma dual basis vectors to basis vectors

section Basis

Expand Down Expand Up @@ -177,11 +177,6 @@ theorem proj_basis (i j : ι)
: ℼ i (ⅇ[X] j) = if i=j then 1 else 0 :=
by simp only [←inner_dualBasis_proj, inner_basis_dualBasis, eq_comm]; done

@[simp]
theorem proj_zero
: ℼ i (0 : X) = 0 :=
by sorry_proof

@[simp]
theorem dualProj_dualBasis (i j : ι)
: ℼ' i (ⅇ'[X] j) = if i=j then 1 else 0 :=
Expand Down Expand Up @@ -211,16 +206,17 @@ instance [EnumType ι] [EnumType κ] [Zero X] [Basis κ K X] [OrthonormalBasis
is_orthonormal := by simp[Inner.inner, Basis.basis]; sorry_proof


instance (priority:=high) {ι K} [EnumType ι] [IsROrC K]
instance (priority:=high) {ι : Type} {K : Type v} [EnumType ι] [IsROrC K]
: FinVec ι K (ι → K) where
is_basis := sorry_proof
duality := sorry_proof
to_dual := sorry_proof
from_dual := sorry_proof

instance {ι κ K X} [EnumType ι] [EnumType κ] [IsROrC K] [FinVec κ K X]
instance {ι κ : Type} {K X : Type _} [EnumType ι] [EnumType κ] [IsROrC K] [FinVec κ K X]
: FinVec (ι×κ) K (ι → X) where
is_basis := sorry_proof
duality := sorry_proof
to_dual := sorry_proof
from_dual := sorry_proof

2 changes: 1 addition & 1 deletion SciLean/Core/Objects/SemiInnerProductSpace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,6 @@ instance (X Y) [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] : SemiInn

-- instance (X) [SemiInnerProductSpace K X] (ι) [Fintype ι] : SemiInnerProductSpace K (ι → X) := SemiInnerProductSpace.mkSorryProofs
-- instance (X) [SemiInnerProductSpace K X] (ι) [EnumType ι] : SemiInnerProductSpace K (ι → X) := SemiInnerProductSpace.mkSorryProofs
instance (priority:=low) (ι) (X : ι → Type _) [∀ i, SemiInnerProductSpace K (X i)] [EnumType ι] : SemiInnerProductSpace K ((i : ι) → X i)
instance (ι) (X : ι → Type _) [∀ i, SemiInnerProductSpace K (X i)] [EnumType ι] : SemiInnerProductSpace K ((i : ι) → X i)
:= SemiInnerProductSpace.mkSorryProofs

6 changes: 4 additions & 2 deletions test/basic_gradients.lean
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,11 @@ example (w : K ^ (Idx' (-5) 5 × Idx' (-5) 5))
: (∇ (x : K ^ (Idx 10 × Idx 10)), ⊞ (i : Idx 10 × Idx 10) => ∑ j, w[j] * x[(j.1.1 +ᵥ i.1, j.2.1 +ᵥ i.2)])
=
fun _x dy =>
⊞ i => ∑ j, w[j] * dy[(-j.fst.1 +ᵥ i.fst, -j.snd.1 +ᵥ i.snd)] :=
-- ⊞ i => ∑ j, w[j] * dy[(-j.fst.1 +ᵥ i.fst, -j.snd.1 +ᵥ i.snd)] :=
⊞ i => ∑ (j : (Idx' (-5) 5 × Idx' (-5) 5)), w[(j.2,j.1)] * dy[(-j.2.1 +ᵥ i.fst, -j.1.1 +ᵥ i.snd)] :=
by
conv => lhs; autodiff; autodiff; simp
conv =>
lhs; autodiff; autodiff; simp



6 changes: 3 additions & 3 deletions test/generate_ftrans.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,16 @@ info: mymul.arg_xy.IsDifferentiable_rule.{w, u} {K : Type u} [instK : IsROrC K]

variable
{K : Type u} [RealScalar K]
{X : Type u} [SemiInnerProductSpace K X]
{ι : Type v} {κ : Type v'} [EnumType ι] [EnumType κ]
{X : Type v} [SemiInnerProductSpace K X]
{ι : Type} {κ : Type} [EnumType ι] [EnumType κ]

set_default_scalar K

def matmul (A : ι → κ → K) (x : κ → K) (i : ι) : K := ∑ j, A i j * x j

#generate_revCDeriv matmul A x
prop_by unfold matmul; fprop
trans_by unfold matmul; autodiff; autodiff
trans_by unfold matmul; ftrans only; autodiff; autodiff

#generate_revCDeriv matmul A x | i
prop_by unfold matmul; fprop
Expand Down

0 comments on commit b3d6e63

Please sign in to comment.