Skip to content

Commit

Permalink
Define ordering and equality on Char (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke authored Dec 7, 2020
2 parents 23cdd2c + 9e94421 commit 6f79a59
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 4 deletions.
15 changes: 15 additions & 0 deletions examples/eval-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,21 @@ s1 = "hello world"
:p s1
> ['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd']

:p codepoint 'a'
> 97

:p 'a' == 'a'
> True

:p 'a' == 'A'
> False

:p 'a' < 'b'
> True

:p 'a' > 'b'
> False

:p
x = 2 + 2
y = 2 + 4
Expand Down
12 changes: 10 additions & 2 deletions prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,6 @@ def argmin (_:Ord o) ?=> (xs:n=>o) : n =
zipped = for i. (i, xs.i)
fst $ reduce zeroth compare zipped


'Automatic differentiation

-- TODO: add vector space constraints
Expand Down Expand Up @@ -793,6 +792,8 @@ splitV : Iso a ({|} | a) =
def sliceFields (iso: Iso ({|} | a) (b | c)) (tab: a=>v) : b=>v =
reindex (buildWith $ splitV &>> iso) tab

'## Strings and Characters

Char : Type = %Char
def MkChar (c:Int8) : Char = %MkChar c

Expand All @@ -803,6 +804,14 @@ CharPtr : Type = %CharPtr
interface Show a:Type where
show : a -> String

-- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint
def codepoint (c:Char) : Int8 = %codePoint c

@instance charEq : Eq Char = MkEq \x y. codepoint x == codepoint y
@instance charOrd : Ord Char = (MkOrd charEq (\x y. codepoint x > codepoint y)
(\x y. codepoint x < codepoint y))


def showFloat' (x:Float) : String =
(n, ptr) = %ffi showFloat (Int & CharPtr) x
AsList n $ for i:(Fin n).
Expand All @@ -811,7 +820,6 @@ def showFloat' (x:Float) : String =
instance showFloat : Show Float where
show = showFloat'


'## Floating point helper functions

def sign (x:Float) : Float =
Expand Down
4 changes: 3 additions & 1 deletion src/lib/Autodiff.hs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ linearizeOp op = case op of
UnsafeFromOrdinal _ _ -> emitDiscrete
ToOrdinal _ -> emitDiscrete
IdxSetSize _ -> emitDiscrete
CodePoint _ -> emitDiscrete
ThrowError _ -> emitWithZero
CastOp t v -> do
if tangentType vt == vt && tangentType t == t
Expand Down Expand Up @@ -627,8 +628,9 @@ transposeOp op ct = case op of
SliceOffset _ _ -> notLinear
SliceCurry _ _ -> notLinear
UnsafeFromOrdinal _ _ -> notLinear
ToOrdinal _ -> notLinear
ToOrdinal _ -> notLinear
IdxSetSize _ -> notLinear
CodePoint _ -> notLinear
ThrowError _ -> notLinear
FFICall _ _ _ -> notLinear
where
Expand Down
1 change: 1 addition & 0 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ toImpOp (maybeDest, op) = case op of
tileOffset' <- iaddI (fromScalarAtom tileOffset) extraOffset
returnVal $ toScalarAtom tileOffset'
ThrowError _ -> throwError ()
CodePoint ~(Con (CharCon x)) -> returnVal x
CastOp destTy x -> case (getType x, destTy) of
(BaseTy _, BaseTy bt) -> returnVal =<< toScalarAtom <$> cast (fromScalarAtom x) bt
_ -> error $ "Invalid cast: " ++ pprint (getType x) ++ " -> " ++ pprint destTy
Expand Down
4 changes: 3 additions & 1 deletion src/lib/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ data PrimOp e =
| ToOrdinal e
| IdxSetSize e
| ThrowError e
| CodePoint e
| CastOp e e -- Type, then value. See Type.hs for valid coercions.
-- Extensible record and variant operations:
-- Add fields to a record (on the left). Left arg contains values to add.
Expand Down Expand Up @@ -1472,7 +1473,8 @@ builtinNames = M.fromList
, ("vfadd", vbinOp FAdd), ("vfsub", vbinOp FSub), ("vfmul", vbinOp FMul)
, ("idxSetSize" , OpExpr $ IdxSetSize ())
, ("unsafeFromOrdinal", OpExpr $ UnsafeFromOrdinal () ())
, ("toOrdinal" , OpExpr $ ToOrdinal ())
, ("toOrdinal" , OpExpr $ ToOrdinal ())
, ("codePoint" , OpExpr $ CodePoint ())
, ("throwError" , OpExpr $ ThrowError ())
, ("ask" , OpExpr $ PrimEffect () $ MAsk)
, ("tell" , OpExpr $ PrimEffect () $ MTell ())
Expand Down
4 changes: 4 additions & 0 deletions src/lib/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ typeCheckOp op = case op of
i |: TC (IntRange (IdxRepVal 0) (IdxRepVal $ fromIntegral vectorWidth))
return $ BaseTy $ Scalar sb
ThrowError ty -> ty|:TyKind $> ty
-- TODO: this should really be a 32 bit integer for unicode code point: but for now is 8 bit ASCII code point
CodePoint c -> do
c |: CharTy
return $ BaseTy $ Scalar Int8Type
CastOp t@(Var _) _ -> t |: TyKind $> t
CastOp destTy e -> do
sourceTy <- typeCheck e
Expand Down

0 comments on commit 6f79a59

Please sign in to comment.