diff --git a/examples/eval-tests.dx b/examples/eval-tests.dx index ad2b61259..f2b0e1813 100644 --- a/examples/eval-tests.dx +++ b/examples/eval-tests.dx @@ -771,6 +771,21 @@ s1 = "hello world" :p s1 > ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd'] +:p codepoint 'a' +> 97 + +:p 'a' == 'a' +> True + +:p 'a' == 'A' +> False + +:p 'a' < 'b' +> True + +:p 'a' > 'b' +> False + :p x = 2 + 2 y = 2 + 4 diff --git a/prelude.dx b/prelude.dx index a9705dafe..9e6526322 100644 --- a/prelude.dx +++ b/prelude.dx @@ -545,7 +545,6 @@ def argmin (_:Ord o) ?=> (xs:n=>o) : n = zipped = for i. (i, xs.i) fst $ reduce zeroth compare zipped - 'Automatic differentiation -- TODO: add vector space constraints @@ -793,6 +792,8 @@ splitV : Iso a ({|} | a) = def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v = reindex (buildWith $ splitV &>> iso) tab +'## Strings and Characters + Char : Type = %Char def MkChar (c:Int8) : Char = %MkChar c @@ -803,6 +804,14 @@ CharPtr : Type = %CharPtr interface Show a:Type where show : a -> String +-- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint +def codepoint (c:Char) : Int8 = %codePoint c + +@instance charEq : Eq Char = MkEq \x y. codepoint x == codepoint y +@instance charOrd : Ord Char = (MkOrd charEq (\x y. codepoint x > codepoint y) + (\x y. codepoint x < codepoint y)) + + def showFloat' (x:Float) : String = (n, ptr) = %ffi showFloat (Int & CharPtr) x AsList n $ for i:(Fin n). @@ -811,7 +820,6 @@ def showFloat' (x:Float) : String = instance showFloat : Show Float where show = showFloat' - '## Floating point helper functions def sign (x:Float) : Float = diff --git a/src/lib/Autodiff.hs b/src/lib/Autodiff.hs index f3bc926ac..65ad2392e 100644 --- a/src/lib/Autodiff.hs +++ b/src/lib/Autodiff.hs @@ -150,6 +150,7 @@ linearizeOp op = case op of UnsafeFromOrdinal _ _ -> emitDiscrete ToOrdinal _ -> emitDiscrete IdxSetSize _ -> emitDiscrete + CodePoint _ -> emitDiscrete ThrowError _ -> emitWithZero CastOp t v -> do if tangentType vt == vt && tangentType t == t @@ -627,8 +628,9 @@ transposeOp op ct = case op of SliceOffset _ _ -> notLinear SliceCurry _ _ -> notLinear UnsafeFromOrdinal _ _ -> notLinear - ToOrdinal _ -> notLinear + ToOrdinal _ -> notLinear IdxSetSize _ -> notLinear + CodePoint _ -> notLinear ThrowError _ -> notLinear FFICall _ _ _ -> notLinear where diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 7717127b7..9341f2720 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -244,6 +244,7 @@ toImpOp (maybeDest, op) = case op of tileOffset' <- iaddI (fromScalarAtom tileOffset) extraOffset returnVal $ toScalarAtom tileOffset' ThrowError _ -> throwError () + CodePoint ~(Con (CharCon x)) -> returnVal x CastOp destTy x -> case (getType x, destTy) of (BaseTy _, BaseTy bt) -> returnVal =<< toScalarAtom <$> cast (fromScalarAtom x) bt _ -> error $ "Invalid cast: " ++ pprint (getType x) ++ " -> " ++ pprint destTy diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index 92d3362c8..f18a4a1b0 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -332,6 +332,7 @@ data PrimOp e = | ToOrdinal e | IdxSetSize e | ThrowError e + | CodePoint e | CastOp e e -- Type, then value. See Type.hs for valid coercions. -- Extensible record and variant operations: -- Add fields to a record (on the left). Left arg contains values to add. @@ -1473,7 +1474,8 @@ builtinNames = M.fromList , ("vfadd", vbinOp FAdd), ("vfsub", vbinOp FSub), ("vfmul", vbinOp FMul) , ("idxSetSize" , OpExpr $ IdxSetSize ()) , ("unsafeFromOrdinal", OpExpr $ UnsafeFromOrdinal () ()) - , ("toOrdinal" , OpExpr $ ToOrdinal ()) + , ("toOrdinal" , OpExpr $ ToOrdinal ()) + , ("codePoint" , OpExpr $ CodePoint ()) , ("throwError" , OpExpr $ ThrowError ()) , ("ask" , OpExpr $ PrimEffect () $ MAsk) , ("tell" , OpExpr $ PrimEffect () $ MTell ()) diff --git a/src/lib/Type.hs b/src/lib/Type.hs index 939eb93f4..25c9f4dfa 100644 --- a/src/lib/Type.hs +++ b/src/lib/Type.hs @@ -728,6 +728,10 @@ typeCheckOp op = case op of i |: TC (IntRange (IdxRepVal 0) (IdxRepVal $ fromIntegral vectorWidth)) return $ BaseTy $ Scalar sb ThrowError ty -> ty|:TyKind $> ty + -- TODO: this should really be a 32 bit integer for unicode code point: but for now is 8 bit ASCII code point + CodePoint c -> do + c |: CharTy + return $ BaseTy $ Scalar Int8Type CastOp t@(Var _) _ -> t |: TyKind $> t CastOp destTy e -> do sourceTy <- typeCheck e