Skip to content

Commit

Permalink
overhaul of ListN
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 10, 2023
1 parent 0dc2c29 commit 3789d19
Showing 1 changed file with 52 additions and 30 deletions.
82 changes: 52 additions & 30 deletions SciLean/Data/ListN.lean
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
import Lean

structure ListNn) where
val : List α
property : val.length = n
inductive ListN: Type) : Nat → Type
| nil : ListN α 0
| cons {n} (x : α) (xs : ListN α n) : ListN α (n+1)

namespace ListN

def toList (l : ListN α n) : List α :=
match l with
| .nil => []
| .cons x xs => x :: xs.toList

def toArray (l : ListN α n) : Array α := Id.run do
let mut a : Array α := .mkEmpty n
go a l
where
go {m} (a : Array α) (l : ListN α m) : Array α :=
match l with
| .nil => a
| .cons x xs => go (a.push x) xs


@[simp]
def map₂ (op : α → β → γ) (l : ListN α n) (l' : ListN β n) : ListN γ n :=
match n, l, l' with
| 0, ⟨[],_⟩, ⟨[], _⟩ => ⟨[], rfl⟩
| n+1, ⟨x :: xs, hx⟩, ⟨y :: ys, hy⟩ =>
let xs : ListN α n := ⟨xs, by simp at hx; exact hx⟩
let ys : ListN β n := ⟨ys, by simp at hy; exact hy⟩
let zs := map₂ op xs ys
⟨op x y :: zs.val, by simp; exact zs.property⟩
match l, l' with
| .nil, .nil => .nil
| .cons x xs, .cons y ys => .cons (op x y) (map₂ op xs ys)


instance [Add α] : Add (ListN α n) := ⟨fun x y => x.map₂ (·+·) y⟩
Expand All @@ -26,45 +36,57 @@ instance [Div α] : Div (ListN α n) := ⟨fun x y => x.map₂ (·/·) y⟩

@[simp]
theorem add_elemwise {n : Nat} [Add α]
(x y : α) (xs ys : List α) (hx : xs.length = n) (hy : ys.length = n)
: (ListN.mk (n:=n+1) (x :: xs) (by simp; exact hx)) + (ListN.mk (y :: ys) (by simp; exact hy))
(x y : α) (xs ys : ListN α n)
: (ListN.cons x xs) + (ListN.cons y ys)
=
ListN.mk ((x + y) :: (ListN.mk xs hx + ListN.mk ys hy).1) (by simp; exact (ListN.mk xs hx + ListN.mk ys hy).2) := by rfl
(ListN.cons (x+y) (xs + ys)) := by rfl

@[simp]
@[simp]
theorem sub_elemwise {n : Nat} [Sub α]
(x y : α) (xs ys : List α) (hx : xs.length = n) (hy : ys.length = n)
: (ListN.mk (n:=n+1) (x :: xs) (by simp; exact hx)) - (ListN.mk (y :: ys) (by simp; exact hy))
(x y : α) (xs ys : ListN α n)
: (ListN.cons x xs) - (ListN.cons y ys)
=
ListN.mk ((x - y) :: (ListN.mk xs hx - ListN.mk ys hy).1) (by simp; exact (ListN.mk xs hx - ListN.mk ys hy).2) := by rfl
(ListN.cons (x-y) (xs - ys)) := by rfl

@[simp]
theorem mul_elemwise {n : Nat} [Mul α]
(x y : α) (xs ys : List α) (hx : xs.length = n) (hy : ys.length = n)
: (ListN.mk (n:=n+1) (x :: xs) (by simp; exact hx)) * (ListN.mk (y :: ys) (by simp; exact hy))
(x y : α) (xs ys : ListN α n)
: (ListN.cons x xs) * (ListN.cons y ys)
=
ListN.mk ((x * y) :: (ListN.mk xs hx * ListN.mk ys hy).1) (by simp; exact (ListN.mk xs hx * ListN.mk ys hy).2) := by rfl
(ListN.cons (x*y) (xs * ys)) := by rfl

@[simp]
theorem div_elemwise {n : Nat} [Div α]
(x y : α) (xs ys : List α) (hx : xs.length = n) (hy : ys.length = n)
: (ListN.mk (n:=n+1) (x :: xs) (by simp; exact hx)) / (ListN.mk (y :: ys) (by simp; exact hy))
(x y : α) (xs ys : ListN α n)
: (ListN.cons x xs) / (ListN.cons y ys)
=
ListN.mk ((x / y) :: (ListN.mk xs hx / ListN.mk ys hy).1) (by simp; exact (ListN.mk xs hx / ListN.mk ys hy).2) := by rfl
(ListN.cons (x/y) (xs / ys)) := by rfl


--------------------------------------------------------------------------------
-- Notation --------------------------------------------------------------------
--------------------------------------------------------------------------------

open Lean in
/-- Notation for list literals with list lenght in its type. -/
macro "[" xs:term,* "]'" : term => do
syntax "[" term,* "]'" : term

open Lean in
macro_rules
| `(term| []') => `(ListN.nil)
| `(term| [$x:term]') => `(ListN.cons $x .nil)
| `(term| [$x:term, $xs:term,*]') => do
let n := Syntax.mkNumLit (toString xs.getElems.size)
`(ListN.mk (n:=$n) [$xs,*] rfl)
`(ListN.cons (n:=$n) $x [$xs,*]')

@[app_unexpander ListN.mk]
def unexpandListNMk : Lean.PrettyPrinter.Unexpander
| `($(_) [$xs,*] $_*) =>
`([$xs,*]')
@[app_unexpander ListN.nil]
def unexpandListNNil : Lean.PrettyPrinter.Unexpander
| `($(_)) =>
`([]')

@[app_unexpander ListN.cons]
def unexpandListNCons : Lean.PrettyPrinter.Unexpander
| `($(_) $x []') =>
`([$x]')
| `($(_) $x [$xs',*]') =>
`([$x, $xs',*]')
| _ => throw ()

0 comments on commit 3789d19

Please sign in to comment.