Skip to content

Commit

Permalink
import ML module by default
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 6, 2023
1 parent 3a7aaf4 commit bdc875c
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 7 deletions.
1 change: 1 addition & 0 deletions SciLean.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import SciLean.Data.ArrayType
import SciLean.Data.DataArray

import SciLean.Modules.DifferentialEquations
import SciLean.Modules.ML

import SciLean.Tactic.LSimp2.Elab

Expand Down
5 changes: 4 additions & 1 deletion SciLean/Data/Idx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@ instance : HSub (Idx n) Int64 (Idx n) := ⟨λ x y => ⟨(x.1 - (y.1 + n))%n, so
instance : VAdd Int64 (Idx n) := ⟨λ x y => y + x⟩

def toFin {n} (i : Idx n) : Fin n.toNat := ⟨i.1.toNat, sorry_proof⟩
def toFloat {n} (i : Idx n) : Float := i.1.toNat.toFloat
def toFin' {n : Nat} (i : Idx n.toUSize) : Fin n := ⟨i.1.toNat, sorry_proof⟩

@[extern c inline "(double)#1"]
def _root_.USize.toFloat (n : USize) : Float := n.toNat.toFloat
def toFloat {n} (i : Idx n) : Float := i.1.toFloat

def shiftPos (x : Idx n) (s : USize) := x + s
def shiftNeg (x : Idx n) (s : USize) := x - s
def shift (x : Idx n) (s : Int) :=
Expand Down
9 changes: 5 additions & 4 deletions SciLean/Data/Index.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export Index (toIdx fromIdx)

namespace Index

@[macro_inline]
-- @[macro_inline]
instance : Index Empty where
size := 0
isValid := true
Expand All @@ -34,7 +34,7 @@ instance : Index Empty where
fromIdx_toIdx := sorry_proof
toIdx_fromIdx := sorry_proof

@[macro_inline]
-- @[macro_inline]
instance : Index Unit where
size := 1
isValid := true
Expand All @@ -45,7 +45,7 @@ instance : Index Unit where
fromIdx_toIdx := sorry_proof
toIdx_fromIdx := sorry_proof

@[macro_inline]
-- @[macro_inline]
instance : Index (Idx n) where
size := n
isValid := true
Expand All @@ -56,6 +56,7 @@ instance : Index (Idx n) where
fromIdx_toIdx := by simp
toIdx_fromIdx := by simp

-- @[macro_inline]
instance : Index (Idx' a b) where
size := let n := b - a; if 0 < n then n.toUSize else 0
isValid := true
Expand Down Expand Up @@ -88,7 +89,7 @@ instance [Index ι] [Index κ] : Index (ι×κ) where


-- Row major ordering, this respects `<` defined on `ι × κ`
@[macro_inline]
-- @[macro_inline]
instance [Index ι] [Index κ] : Index (ι×ₗκ) where
size := (min ((size ι).toNat * (size κ).toNat) (USize.size -1)).toUSize
isValid :=
Expand Down
4 changes: 2 additions & 2 deletions SciLean/Modules/ML/Convolution.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import SciLean.Data.Prod

namespace SciLean.ML

set_option synthInstance.maxSize 2000

variable
{R : Type} [RealScalar R] [PlainDataType R]

Expand All @@ -19,5 +21,3 @@ def conv2d
prop_by unfold conv2d; fprop
trans_by unfold conv2d; ftrans

#exit

1 change: 1 addition & 0 deletions SciLean/Modules/ML/Loss.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import SciLean.Core
import SciLean.Data.DataArray
import SciLean.Data.ArrayType
import SciLean.Data.Prod
import SciLean.Core.Functions.Exp

Expand Down

0 comments on commit bdc875c

Please sign in to comment.