From 71369201003389c6790e23f38ecc5b78ba03319c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sun, 6 Dec 2020 17:34:46 +0000 Subject: [PATCH 1/4] add codepoint, Eq and Ord for Char --- examples/eval-tests.dx | 15 +++++++++++++++ prelude.dx | 13 +++++++++++-- src/lib/Autodiff.hs | 4 +++- src/lib/Imp.hs | 2 ++ src/lib/Syntax.hs | 4 +++- src/lib/Type.hs | 3 +++ 6 files changed, 37 insertions(+), 4 deletions(-) diff --git a/examples/eval-tests.dx b/examples/eval-tests.dx index ad2b61259..63c16418c 100644 --- a/examples/eval-tests.dx +++ b/examples/eval-tests.dx @@ -771,6 +771,21 @@ s1 = "hello world" :p s1 > ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'] +:p codepoint 'a' +97 + +:p 'a' == 'a' +True + +:p 'a' == 'A' +False + +:p 'a' < 'b' +True + +:p 'a' > 'b' +false + :p x = 2 + 2 y = 2 + 4 diff --git a/prelude.dx b/prelude.dx index a9705dafe..1223357fa 100644 --- a/prelude.dx +++ b/prelude.dx @@ -545,7 +545,6 @@ def argmin (_:Ord o) ?=> (xs:n=>o) : n = zipped = for i. (i, xs.i) fst $ reduce zeroth compare zipped - 'Automatic differentiation -- TODO: add vector space constraints @@ -793,9 +792,12 @@ splitV : Iso a ({|} | a) = def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = reindex (buildWith $ splitV &>> iso) tab +'## Strings and Characters + Char : Type = %Char def MkChar (c:Int8) : Char = %MkChar c + String : Type = List Char CharPtr : Type = %CharPtr @@ -803,6 +805,14 @@ CharPtr : Type = %CharPtr interface Show a:Type where show : a -> String +-- TODO should be Int32 for codepoint (current is just the first codeunit) +def codepoint (c:Char) : Int8 = %codePoint c + +@instance charEq : Eq Char = MkEq \x y. codepoint x == codepoint y +@instance charOrd : Ord Char = (MkOrd charEq (\x y. codepoint x > codepoint y) + (\x y. codepoint x < codepoint y)) + + def showFloat' (x:Float) : String = (n, ptr) = %ffi showFloat (Int & CharPtr) x AsList n $ for i:(Fin n). 
@@ -811,7 +821,6 @@ def showFloat' (x:Float) : String = instance showFloat : Show Float where show = showFloat' - '## Floating point helper functions def sign (x:Float) : Float = diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index f3bc926ac..65ad2392e 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -150,6 +150,7 @@ linearizeOp op = case op of UnsafeFromOrdinal _ _ -> emitDiscrete ToOrdinal _ -> emitDiscrete IdxSetSize _ -> emitDiscrete + CodePoint _ -> emitDiscrete ThrowError _ -> emitWithZero CastOp t v -> do if tangentType vt == vt && tangentType t == t @@ -627,8 +628,9 @@ transposeOp op ct = case op of SliceOffset _ _ -> notLinear SliceCurry _ _ -> notLinear UnsafeFromOrdinal _ _ -> notLinear - ToOrdinal _ -> notLinear + ToOrdinal _ -> notLinear IdxSetSize _ -> notLinear + CodePoint _ -> notLinear ThrowError _ -> notLinear FFICall _ _ _ -> notLinear where diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 7717127b7..74ba926df 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -244,6 +244,8 @@ toImpOp (maybeDest, op) = case op of tileOffset' <- iaddI (fromScalarAtom tileOffset) extraOffset returnVal $ toScalarAtom tileOffset' ThrowError _ -> throwError () + CodePoint ~(Con (CharCon x)) -> do + returnVal x CastOp destTy x -> case (getType x, destTy) of (BaseTy _, BaseTy bt) -> returnVal =<< toScalarAtom <$> cast (fromScalarAtom x) bt _ -> error $ "Invalid cast: " ++ pprint (getType x) ++ " -> " ++ pprint destTy diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 92d3362c8..f18a4a1b0 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -332,6 +332,7 @@ data PrimOp e = | ToOrdinal e | IdxSetSize e | ThrowError e + | CodePoint e | CastOp e e -- Type, then value. See Type.hs for valid coercions. -- Extensible record and variant operations: -- Add fields to a record (on the left). Left arg contains values to add. 
@@ -1473,7 +1474,8 @@ builtinNames = M.fromList , ("vfadd", vbinOp FAdd), ("vfsub", vbinOp FSub), ("vfmul", vbinOp FMul) , ("idxSetSize" , OpExpr $ IdxSetSize ()) , ("unsafeFromOrdinal", OpExpr $ UnsafeFromOrdinal () ()) - , ("toOrdinal" , OpExpr $ ToOrdinal ()) + , ("toOrdinal" , OpExpr $ ToOrdinal ()) + , ("codePoint" , OpExpr $ CodePoint ()) , ("throwError" , OpExpr $ ThrowError ()) , ("ask" , OpExpr $ PrimEffect () $ MAsk) , ("tell" , OpExpr $ PrimEffect () $ MTell ()) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 939eb93f4..f8351d67f 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -728,6 +728,9 @@ typeCheckOp op = case op of i |: TC (IntRange (IdxRepVal 0) (IdxRepVal $ fromIntegral vectorWidth)) return $ BaseTy $ Scalar sb ThrowError ty -> ty|:TyKind $> ty + -- TODO: type check that c is a character + -- TODO: this should really be a 32 bit integer for unicode reasons: 8 bit is just 1 codeunit + CodePoint c -> return $ BaseTy $ Scalar Int8Type CastOp t@(Var _) _ -> t |: TyKind $> t CastOp destTy e -> do sourceTy <- typeCheck e From e11c7838987303242dae0aa7be2dc4383faac035 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sun, 6 Dec 2020 17:38:30 +0000 Subject: [PATCH 2/4] Clean up comments --- prelude.dx | 3 +-- src/lib/Imp.hs | 3 +-- src/lib/Type.hs | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/prelude.dx b/prelude.dx index 1223357fa..9e6526322 100644 --- a/prelude.dx +++ b/prelude.dx @@ -797,7 +797,6 @@ def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = Char : Type = %Char def MkChar (c:Int8) : Char = %MkChar c - String : Type = List Char CharPtr : Type = %CharPtr @@ -805,7 +804,7 @@ CharPtr : Type = %CharPtr interface Show a:Type where show : a -> String --- TODO should be Int32 for codepoint (current is just the first codeunit) +-- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint (c:Char) : Int8 = %codePoint c @instance charEq : Eq Char = MkEq \x y. 
codepoint x == codepoint y diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 74ba926df..9341f2720 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -244,8 +244,7 @@ toImpOp (maybeDest, op) = case op of tileOffset' <- iaddI (fromScalarAtom tileOffset) extraOffset returnVal $ toScalarAtom tileOffset' ThrowError _ -> throwError () - CodePoint ~(Con (CharCon x)) -> do - returnVal x + CodePoint ~(Con (CharCon x)) -> returnVal x CastOp destTy x -> case (getType x, destTy) of (BaseTy _, BaseTy bt) -> returnVal =<< toScalarAtom <$> cast (fromScalarAtom x) bt _ -> error $ "Invalid cast: " ++ pprint (getType x) ++ " -> " ++ pprint destTy diff --git a/src/lib/Type.hs b/src/lib/Type.hs index f8351d67f..8b8834bfc 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -729,7 +729,7 @@ typeCheckOp op = case op of return $ BaseTy $ Scalar sb ThrowError ty -> ty|:TyKind $> ty -- TODO: type check that c is a character - -- TODO: this should really be a 32 bit integer for unicode reasons: 8 bit is just 1 codeunit + -- TODO: this should really be a 32 bit integer for unicode code point: but for now is 8 bit ASCII code point CodePoint c -> return $ BaseTy $ Scalar Int8Type CastOp t@(Var _) _ -> t |: TyKind $> t CastOp destTy e -> do From 8c2f150d59ccacdd8a0dff155b6eb3b32347df3b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sun, 6 Dec 2020 17:47:33 +0000 Subject: [PATCH 3/4] Fix tests --- examples/eval-tests.dx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/eval-tests.dx b/examples/eval-tests.dx index 63c16418c..f2b0e1813 100644 --- a/examples/eval-tests.dx +++ b/examples/eval-tests.dx @@ -772,19 +772,19 @@ s1 = "hello world" > ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'] :p codepoint 'a' -97 +> 97 :p 'a' == 'a' -True +> True :p 'a' == 'A' -False +> False :p 'a' < 'b' -True +> True :p 'a' > 'b' -false +> False :p x = 2 + 2 From 9e94421f4416f415c4728ace138f0d05ae266372 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sun, 6 Dec 
2020 22:37:41 +0000 Subject: [PATCH 4/4] typecheck input to CodePoint --- src/lib/Type.hs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 8b8834bfc..25c9f4dfa 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -728,9 +728,10 @@ typeCheckOp op = case op of i |: TC (IntRange (IdxRepVal 0) (IdxRepVal $ fromIntegral vectorWidth)) return $ BaseTy $ Scalar sb ThrowError ty -> ty|:TyKind $> ty - -- TODO: type check that c is a character -- TODO: this should really be a 32 bit integer for unicode code point: but for now is 8 bit ASCII code point - CodePoint c -> return $ BaseTy $ Scalar Int8Type + CodePoint c -> do + c |: CharTy + return $ BaseTy $ Scalar Int8Type CastOp t@(Var _) _ -> t |: TyKind $> t CastOp destTy e -> do sourceTy <- typeCheck e