Skip to content

Commit

Permalink
use ListN in ArrayType literals
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 10, 2023
1 parent 3789d19 commit ba8615a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
22 changes: 21 additions & 1 deletion SciLean/Data/ArrayType/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import SciLean.Util.SorryProof
import SciLean.Data.Index
import SciLean.Data.ListN

namespace SciLean

Expand Down Expand Up @@ -131,7 +132,7 @@ theorem getElem_map [ArrayType Cont Idx Elem] [EnumType Idx] (f : Elem → Elem)

instance [ArrayType Cont Idx Elem] [ToString Elem] [EnumType Idx] : ToString (Cont) := ⟨λ x => Id.run do
let mut fst := true
let mut s := "["
let mut s := "["
for i in fullRange Idx do
if fst then
s := s ++ toString x[i]
Expand All @@ -140,6 +141,25 @@ instance [ArrayType Cont Idx Elem] [ToString Elem] [EnumType Idx] : ToString (Co
s := s ++ ", " ++ toString x[i]
s ++ "]"

/-- Converts array to ArrayType
WARNING: Does not do what expected for arrays of size bigger or equal then USize.size
For example, array of size USize.size is converted to an array of size zero
-/
def _root_.Array.toArrayType {n Elem} (Cont : Type u) [ArrayType Cont (SciLean.Idx n) Elem]
(a : Array Elem) (_h : n = a.size.toUSize) : Cont :=
introElem fun (i : SciLean.Idx n) => a[i.1]'sorry_proof

/-- Converts ListN to ArrayType
WARNING: Does not do what expected for lists of size bigger or equal then USize.size
For example, array of size USize.size is converted to an array of size zero
-/
def _root_.ListN.toArrayType {n Elem} (Cont : Type) [ArrayType Cont (SciLean.Idx (n.toUSize)) Elem]
(l : ListN Elem n) : Cont :=
introElem fun i => l.toArray[i.1.toNat]'sorry_proof


section Operations

variable [ArrayType Cont Idx Elem] [EnumType Idx]
Expand Down
25 changes: 14 additions & 11 deletions SciLean/Data/ArrayType/Notation.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import SciLean.Data.ArrayType.Basic

import SciLean.Data.ListN
import Qq
namespace SciLean

open Lean Parser
Expand Down Expand Up @@ -76,25 +77,27 @@ def unexpandIntroElemNotation : Lean.PrettyPrinter.Unexpander
-- Notation: ⊞[1,2,3] --
------------------------

/-- Converts array to the canonical ArrayType

WARNING: Does not do what expected for arrays of size bigger or equal then USize.size
For example, array of size USize.size is converted to an array of size zero
-/
def _root_.Array.toArrayType {Elem} (a : Array Elem) (n : USize) (_h : n = a.size.toUSize)
{Cont} [ArrayType Cont (Idx n) Elem] [ArrayTypeNotation Cont (Idx n) Elem]
: Cont := ⊞ (i : Idx n) => a[i.1]'sorry_proof
syntax (name:=arrayTypeLiteral) " ⊞[" term,* "] " : term

macro " ⊞[" xs:term,* "] " : term => do
let n := Syntax.mkNumLit (toString xs.getElems.size)
`(term| Array.toArrayType #[$xs,*] $n (by rfl))
open Lean Meta Elab Term Qq
macro_rules
| `(⊞[ $x:term, $xs:term,* ]) => do
let n := Syntax.mkNumLit (toString (xs.getElems.size + 1))
`(ListN.toArrayType (arrayTypeCont (Idx ($n).toUSize) (typeOf $x)) [$x,$xs,*]')
-- let n := Syntax.mkNumLit (toString xs.getElems.size)
-- `(term| ListN.toArrayType (arrayType #[$xs,*] $n (by rfl))

@[app_unexpander Array.toArrayType]
def unexpandArrayToArrayType : Lean.PrettyPrinter.Unexpander
| `($(_) #[$ys,*] $_*) =>
`(⊞[$ys,*])
| _ => throw ()

-- variable {CC : USize → Type} [∀ n, ArrayType (CC n) (Idx n) Float] [∀ n, ArrayTypeNotation (CC n) (Idx n) Float]
-- #check [1.0,2.0,3.0]'.toArrayType (CC 3)
-- #check ⊞[1.0,2.0,3.0]


-- Notation: Float ^ Idx n --
-----------------------------
Expand Down
6 changes: 5 additions & 1 deletion SciLean/Data/ListN.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Lean

inductive ListN (α : Type) : Nat → Type
inductive ListN (α : Type u) : Nat → Type u
| nil : ListN α 0
| cons {n} (x : α) (xs : ListN α n) : ListN α (n+1)

Expand All @@ -20,6 +20,8 @@ where
| .nil => a
| .cons x xs => go (a.push x) xs

instance [ToString α] : ToString (ListN α n) := ⟨fun l => toString l.toList ++ "'"


@[simp]
def map₂ (op : α → β → γ) (l : ListN α n) (l' : ListN β n) : ListN γ n :=
Expand Down Expand Up @@ -90,3 +92,5 @@ def unexpandListNCons : Lean.PrettyPrinter.Unexpander
| `($(_) $x [$xs',*]') =>
`([$x, $xs',*]')
| _ => throw ()


0 comments on commit ba8615a

Please sign in to comment.