Skip to content

Commit

Permalink
Create ADTs separately (#3)
Browse files Browse the repository at this point in the history
Use SCtrs with ADT Name, Constructor Name, and SValue fields.
Create symbolic ADTs by using a symbolic selector value and instantiate fields recursively, selecting the instance with nested 'merge' calls (if-then-else).
  • Loading branch information
SophieBosio authored Apr 29, 2024
1 parent 758548a commit 5dba64e
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 97 deletions.
3 changes: 2 additions & 1 deletion examples/congruence.con
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,5 @@ exampleCongruent x y =*=
(xx == yy)
; False -> True .

main = exampleCongruent (C {5}) (C {5}) .


5 changes: 4 additions & 1 deletion src/Analysis/TypeInferrer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ fresh = Variable' <$> (get >>= \i -> -- Get current, fresh index (state)
return i) -- Return fresh

bind :: Eq x => x -> a -> x `MapsTo` a
bind x a look y = if x == y then a else look y
bind x a look y = if x == y -- Applying the bindings to some 'y' equal to 'x'
then a -- you should now get back 'a'
else look y -- If you call it with some other 'y',
-- then return the old binding for 'y'

newConstraint :: Type -> Type -> String -> Constraint
newConstraint t1 t2 msg =
Expand Down
18 changes: 6 additions & 12 deletions src/Environment/Environment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ module Environment.Environment where

import Core.Syntax

import Data.List (elemIndex)


type Mapping a b = a -> b
type MapsTo a b = Mapping a b -> Mapping a b
Expand Down Expand Up @@ -69,18 +71,10 @@ programEnvironment p =
case lookup d (datatypes p) of
Nothing -> error $ "Couldn't find data type with name '" ++ d ++ "'"
Just cs ->
case findSelector 0 c cs of
Just s -> return s
case elemIndex c (map nameOf cs) of
Just s -> return $ toInteger s
Nothing -> error $ "Constructor '" ++ c ++
"' not found in data type declaration of type '" ++ d ++ "'"
where
nameOf (Constructor x _) = x
}

matches :: C -> Constructor -> Bool
matches c (Constructor d _) = c == d

findSelector :: Integer -> C -> [Constructor] -> Maybe Integer
findSelector _ _ [] = Nothing
findSelector i c (ctr : ctrs)
| c `matches` ctr = Just i
| otherwise = findSelector (i + 1) c ctrs

46 changes: 28 additions & 18 deletions src/Validation/Formula.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,19 @@ import Environment.Environment
import Environment.ERSymbolic

import Data.SBV
import Control.Monad (zipWithM)


-- Custom symbolic variables
type RecursionDepth = Integer
data SValue =
SUnit
| SBoolean SBool
| SNumber SInteger
| SCtr String SInteger [SValue]
| SCtr D C [SValue]
| SArgs [SValue]
-- SArgs represents the fabricated argument list we create when flattening
-- function definitions into a Case-statement
-- SArgs represents the fabricated argument list we create when
-- flattening function definitions into a Case-statement
deriving Show


Expand All @@ -56,26 +58,33 @@ type Bindings = Mapping X SValue
type Formula a = ERSymbolic Type Bindings a

bind :: X -> SValue -> X `MapsTo` SValue
bind x tau look y = if x == y then tau else look y
bind x tau look y = if x == y -- Applying the bindings to some 'y' equal to 'x'
then tau -- you should now get back 'tau'
else look y -- If you call it with some other 'y',
-- then return the old binding for 'y'


-- SValue (symbolic) equality
sEqual :: SValue -> SValue -> SValue
sEqual SUnit SUnit = SBoolean sTrue
sEqual (SBoolean b) (SBoolean c) = SBoolean (b .== c)
sEqual (SNumber n) (SNumber m) = SBoolean (n .== m)
sEqual (SCtr x si xs) (SCtr y sj ys) = SBoolean $ sAnd $
fromBool (x == y)
: (si .== sj)
: map truthy (zipWith sEqual xs ys)
sEqual (SArgs xs) (SArgs ys) = SBoolean $ sAnd $ map truthy $
zipWith sEqual xs ys
sEqual _ _ = SBoolean sFalse
sEqual :: SValue -> SValue -> Formula SValue
sEqual SUnit SUnit = return $ SBoolean sTrue
sEqual (SBoolean b) (SBoolean c) = return $ SBoolean (b .== c)
sEqual (SNumber n) (SNumber m) = return $ SBoolean (n .== m)
sEqual (SCtr adt x xs) (SCtr adt' y ys) =
do eqs <- zipWithM sEqual xs ys
return $ SBoolean $ sAnd $
fromBool (adt == adt')
: fromBool (x == y )
: map truthy eqs
sEqual (SArgs xs) (SArgs ys) =
do eqs <- zipWithM sEqual xs ys
return $ SBoolean $ sAnd $ map truthy eqs
sEqual _ _ = return $ SBoolean sFalse

truthy :: SValue -> SBool
truthy (SBoolean b) = b
truthy SUnit = sTrue
truthy v = error $ "Expected a symbolic boolean value, but got " ++ show v
truthy v = error $
"Expected a symbolic boolean value, but got " ++ show v


-- SValues are 'Mergeable', meaning we can use SBV's if-then-else, called 'ite'.
Expand All @@ -87,8 +96,9 @@ merge :: SBool -> SValue -> SValue -> SValue
merge _ SUnit SUnit = SUnit
merge b (SNumber x) (SNumber y) = SNumber $ ite b x y
merge b (SBoolean x) (SBoolean y) = SBoolean $ ite b x y
merge b (SCtr x si xs) (SCtr y sj ys)
| x == y = SCtr x (ite b si sj) (mergeList b xs ys)
merge b (SCtr adt x xs) (SCtr adt' y ys)
| adt == adt'
&& x == y = SCtr adt x (mergeList b xs ys)
| otherwise = error $
"Type mismatch between data type constructors '"
++ show x ++ "' and '" ++ show y ++ "'"
Expand Down
7 changes: 4 additions & 3 deletions src/Validation/SymUnifier.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ sUnify (Value _) _ = mempty
sUnify (Variable x _) sv = substitution $ bind x sv
sUnify (List ps _) (SArgs svs) =
foldr (\(p, sv) u -> u <> sUnify p sv) mempty $ zip ps svs
sUnify (PConstructor c ps (ADT t)) (SCtr d _ svs)
| t == d = foldr (\(p, sv) u -> u <> sUnify p sv) mempty $ zip ps svs
| otherwise = substError $
sUnify (PConstructor c ps (ADT t)) (SCtr adt d svs)
| t == adt
&& c == d = foldr (\(p, sv) u -> u <> sUnify p sv) mempty $ zip ps svs
| otherwise = substError $
"Unexpected type occurred when trying to unify\n\
\concrete pattern with constructor '" ++ c ++ "' and type '" ++ show t
++ "' against symbolic value of type '" ++ d ++ "'"
Expand Down
103 changes: 41 additions & 62 deletions src/Validation/Translator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ import Data.SBV


-- Recursion depth for ADTs
type RecursionDepth = Integer

defaultRecDepth :: RecursionDepth
defaultRecDepth = 20

Expand Down Expand Up @@ -80,10 +78,9 @@ translate (Let p t1 t2 _) =
translate (Case t0 ts _) =
do sp <- translate t0
translateBranches sp ts
translate (TConstructor c ts adt) =
translate (TConstructor c ts (ADT d)) =
do sts <- mapM translate ts
sel <- symSelector adt c
return $ SCtr c sel sts
return $ SCtr d c sts
translate (Plus t0 t1 _) =
do t0' <- translate t0 >>= numeric
t1' <- translate t1 >>= numeric
Expand All @@ -103,10 +100,12 @@ translate (Gt t0 t1 _) =
translate (Equal t0 t1 _) =
do t0' <- translate t0
t1' <- translate t1
return $ t0' `sEqual` t1'
t0' `sEqual` t1'
translate (Not t0 _) =
do t0' <- translate t0 >>= boolean
return $ SBoolean $ sNot t0'
translate t@(TConstructor {}) = error
$ "Ill-typed constructor argument '" ++ show t ++ "'"
-- translate (Rec x t0 a) -- future work

translatePattern :: Pattern Type -> Formula SValue
Expand All @@ -116,22 +115,24 @@ translatePattern (Value v) = translateValue v
translatePattern (Variable x _) =
do bindings <- ask
return $ bindings x
translatePattern (PConstructor c ps adt) =
translatePattern (PConstructor c ps (ADT d)) =
do sps <- mapM translatePattern ps
sel <- symSelector adt c
return $ SCtr c sel sps
return $ SCtr d c sps
translatePattern (List ps _) =
do sps <- mapM translatePattern ps
return $ SArgs sps
translatePattern p@(PConstructor {}) = error
$ "Ill-typed constructor argument '" ++ show p ++ "'"

translateValue :: Value Type -> Formula SValue
translateValue (Unit _) = return SUnit
translateValue (Number n _) = return $ SNumber $ literal n
translateValue (Boolean b _) = return $ SBoolean $ literal b
translateValue (VConstructor c vs adt) =
translateValue (VConstructor c vs (ADT d)) =
do svs <- mapM translateValue vs
sel <- symSelector adt c
return $ SCtr c sel svs
return $ SCtr d c svs
translateValue v@(VConstructor {}) = error
$ "Ill-typed constructor argument '" ++ show v ++ "'"

translateBranches :: SValue -> [(Pattern Type, Term Type)] -> Formula SValue
translateBranches _ [] = error "Non-exhaustive patterns in case statement."
Expand All @@ -143,10 +144,11 @@ translateBranches sv ((alt, body) : rest) =
case symUnify alt sv of
NoMatch _ -> translateBranches sv rest
MatchBy bs -> do alt' <- local bs $ translatePattern alt
let cond = truthy $ sEqual alt' sv
cond <- alt' `sEqual` sv
-- let cond = truthy $ sEqual alt' sv
body' <- local bs $ translate body
next <- translateBranches sv rest
return $ merge cond body' next
return $ merge (truthy cond) body' next


-- Create symbolic input variables
Expand Down Expand Up @@ -199,72 +201,49 @@ createSymbolic depth (Variable x (TypeList ts)) =
let ps = zipWith Variable names ts
sxs <- mapM (createSymbolic depth) ps
return $ SArgs sxs
createSymbolic 0 (Variable x (ADT adt)) =
do env <- environment
ctrs <- constructors env adt
case removeRecursiveCtrs ctrs of
[] -> error $
"Fatal: Maxed out recursion depth when creating symbolic ADT '"
++ show adt ++ "'"
ctrs' -> do (si, sFields) <- symFields 0 adt ctrs'
return $ SCtr adt si sFields
createSymbolic 0 (Variable x (ADT adt)) = error $

Check failure on line 204 in src/Validation/Translator.hs

View workflow job for this annotation

GitHub Actions / Using Stack version 2.13.1

Defined but not used: ‘x’
"Maxed out recursion depth when creating symbolic ADT " ++ adt ++ "'"
createSymbolic depth (Variable x (ADT adt)) =

Check failure on line 206 in src/Validation/Translator.hs

View workflow job for this annotation

GitHub Actions / Using Stack version 2.13.1

Defined but not used: ‘x’
do env <- environment
ctrs <- constructors env adt
(si, sFields) <- symFields (depth - 1) adt ctrs
return $ SCtr adt si sFields
si <- createSelector ctrs
selectConstructor (depth - 1) adt si ctrs
createSymbolic _ p = error $
"Unexpected request to create symbolic sub-pattern '"
++ show p ++ "' of type '" ++ show (annotation p) ++ "'"
++ "\nPlease note that generating arbitrary functions is not supported."


-- * Helpers for creating symbolic ADT variables
symFields :: RecursionDepth -> D -> [Constructor] -> Formula (SInteger, [SValue])
symFields depth adt ctrs =
do si <- lift $ sInteger "selector"
createSelector :: [Constructor] -> Formula SInteger
createSelector ctrs =
do si <- lift sInteger_
let cardinality = literal $ toInteger $ length ctrs
lift $ constrain $
(si .>= 0) .&& (si .< cardinality)
types <- symSelect si adt ctrs
let names = zipWith (\tau i -> show (hash (adt ++ show tau)) ++ show i)
return si

selectConstructor :: RecursionDepth -> D -> SInteger -> [Constructor] -> Formula SValue
selectConstructor _ d _ [] = error $
"Fatal: Failed to create symbolic variable for ADT '" ++ show d ++ "'"
selectConstructor depth d _ [Constructor c types] =
do let names = zipWith (\tau i -> show (hash (d ++ show tau)) ++ show i)
types
([0..] :: [Integer])
let fields = zipWith Variable names types
sFields <- mapM (createSymbolic depth) fields
return (si, sFields)
return $ SCtr d c sFields
selectConstructor depth d si ((Constructor c types) : ctrs) =
do env <- environment
sel <- selector env d c
let names = zipWith (\tau i -> show (hash (d ++ show tau)) ++ show i)
types
([0..] :: [Integer])
let fields = zipWith Variable names types
sFields <- mapM (createSymbolic depth) fields
next <- selectConstructor depth d si ctrs
return $ merge (si .== literal sel) (SCtr d c sFields) next

symSelect :: SInteger -> D -> [Constructor] -> Formula [Type]
symSelect _ adt [ ] = error $ "Fatal: Failed to create symbolic variable for ADT '"
++ show adt ++ "'"
symSelect si adt [ctr] =
do env <- environment
i <- selector env adt (nameOf ctr)
return $ ite (si .== literal i)
(fieldsOf ctr)
(error $ "Fatal: Failed to create input variable for ADT '"
++ show adt ++ "'")
symSelect si adt (ctr : ctrs) =
do env <- environment
i <- selector env adt (nameOf ctr)
next <- symSelect si adt ctrs
return $ ite (si .== literal i)
(fieldsOf ctr)
next

nameOf :: Constructor -> C
nameOf (Constructor c _) = c

fieldsOf :: Constructor -> [Type]
fieldsOf (Constructor _ taus) = taus

removeRecursiveCtrs :: [Constructor] -> [Constructor]
removeRecursiveCtrs = filter nonRecursive
where
nonRecursive ctr = all nonAlgebraic $ fieldsOf ctr
nonAlgebraic (ADT _) = False
nonAlgebraic _ = True


-- Symbolic "unification" and unification constraint generation
unifyOrFail :: Pattern Type -> SValue -> Formula Transformation
Expand Down

0 comments on commit 5dba64e

Please sign in to comment.