diff --git a/crux-mir-comp/src/Mir/Compositional/Convert.hs b/crux-mir-comp/src/Mir/Compositional/Convert.hs index 3b9cbe24a3..d6d5f2493b 100644 --- a/crux-mir-comp/src/Mir/Compositional/Convert.hs +++ b/crux-mir-comp/src/Mir/Compositional/Convert.hs @@ -315,7 +315,7 @@ termToReg sym sc varMap term shp = do where go :: forall tp'. TypeShape tp' -> SValue sym -> IO (RegValue sym tp') go shp sv = case (shp, sv) of - (UnitShape _, SAW.VUnit) -> return () + (UnitShape _, SAW.VTuple ts) | V.null ts -> return () (PrimShape _ BaseBoolRepr, SAW.VBool b) -> return b (PrimShape _ (BaseBVRepr w), SAW.VWord (W4.DBV e)) | Just Refl <- testEquality (W4.exprType e) (BaseBVRepr w) -> return e @@ -326,9 +326,9 @@ termToReg sym sc varMap term shp = do _ -> fail $ "termToReg: type error: need to produce " ++ show (shapeType shp) ++ ", but simulator returned a vector containing " ++ show x buildBitVector w bits - (TupleShape _ _ flds, _) -> do - svs <- tupleToListRev (Ctx.sizeInt $ Ctx.size flds) [] sv - goTuple flds svs + (TupleShape _ _ flds, SAW.VTuple ts) -> do + svs <- traverse SAW.force ts + goTuple flds (reverse (V.toList svs)) (ArrayShape (M.TyArray _ n) _ shp, SAW.VVector thunks) -> do svs <- mapM SAW.force $ toList thunks when (length svs /= n) $ fail $ @@ -350,21 +350,7 @@ termToReg sym sc varMap term shp = do _ -> error $ "termToReg: type error: need to produce " ++ show (shapeType shp) ++ ", but simulator returned " ++ show sv - -- | Convert an `SValue` tuple (built from nested `VPair`s) into a list of - -- the inner `SValue`s, in reverse order. - tupleToListRev :: Int -> [SValue sym] -> SValue sym -> IO [SValue sym] - tupleToListRev 2 acc (SAW.VPair x y) = do - x' <- SAW.force x - y' <- SAW.force y - return $ y' : x' : acc - tupleToListRev n acc (SAW.VPair x xs) | n > 2 = do - x' <- SAW.force x - xs' <- SAW.force xs - tupleToListRev (n - 1) (x' : acc) xs' - tupleToListRev n _ _ | n < 2 = error $ "bad tuple size " ++ show n - tupleToListRev n _ v = error $ "termToReg: expected tuple of " ++ show n ++ - " elements, but got " ++ show v - + -- `SValue`s expected in reverse order. goTuple :: forall ctx. Assignment FieldShape ctx -> [SValue sym] -> diff --git a/cryptol-saw-core/saw/Cryptol.sawcore b/cryptol-saw-core/saw/Cryptol.sawcore index 3a387e57c4..69b1e9f61c 100644 --- a/cryptol-saw-core/saw/Cryptol.sawcore +++ b/cryptol-saw-core/saw/Cryptol.sawcore @@ -21,11 +21,11 @@ bvExp n x y = foldr Bool (Vec n Bool) n (bvNat n 1) (reverse n Bool y); -updFst : (a b : sort 0) -> (a -> a) -> (a * b) -> (a * b); -updFst a b f x = (f x.(1), x.(2)); +updFst : (a b : sort 0) -> (a -> a) -> #(a, b) -> #(a, b); +updFst a b f x = (f x.0, x.1); -updSnd : (a b : sort 0) -> (b -> b) -> (a * b) -> (a * b); -updSnd a b f x = (x.(1), f x.(2)); +updSnd : (a b : sort 0) -> (b -> b) -> #(a, b) -> #(a, b); +updSnd a b f x = (x.0, f x.1); -------------------------------------------------------------------------------- -- Extended natural numbers @@ -282,23 +282,25 @@ fun_cong a b c d eq_ab eq_cd = (eq_cong (sort 0) a b eq_ab (sort 0) (\ (x:sort 0) -> (x -> c))) (eq_cong (sort 0) c d eq_cd (sort 0) (\ (x:sort 0) -> (b -> x))); -pair_cong : (a : sort 0) -> (a' : sort 0) -> (b : sort 0) -> (b' : sort 0) -> - Eq (sort 0) a a' -> Eq (sort 0) b b' -> Eq (sort 0) (a * b) (a' * b'); +pair_cong : + (a : sort 0) -> (a' : sort 0) -> + (b : sort 0) -> (b' : sort 0) -> + Eq (sort 0) a a' -> Eq (sort 0) b b' -> Eq (sort 0) #(a, b) #(a', b'); pair_cong a a' b b' eq_a eq_b = trans - (sort 0) (a * b) (a' * b) (a' * b') - (eq_cong (sort 0) a a' eq_a (sort 0) (\ (x:sort 0) -> (x * b))) - (eq_cong (sort 0) b b' eq_b (sort 0) (\ (x:sort 0) -> (a' * x))); + (sort 0) #(a, b) #(a', b) #(a', b') + (eq_cong (sort 0) a a' eq_a (sort 0) (\ (x:sort 0) -> #(x, b))) + (eq_cong (sort 0) b b' eq_b (sort 0) (\ (x:sort 0) -> #(a', x))); pair_cong1 : (a : sort 0) -> (a' : sort 0) -> (b : sort 0) -> - Eq (sort 0) a a' -> Eq (sort 0) (a * b) (a' * b); + Eq (sort 0) a a' -> Eq (sort 0) #(a, b) #(a', b); pair_cong1 a a' b eq_a = - (eq_cong (sort 0) a a' eq_a (sort 0) (\ (x:sort 0) -> (x * b))); + (eq_cong (sort 0) a a' eq_a (sort 0) (\ (x:sort 0) -> #(x, b))); pair_cong2 : (a : sort 0) -> (b : sort 0) -> (b' : sort 0) -> - Eq (sort 0) b b' -> Eq (sort 0) (a * b) (a * b'); + Eq (sort 0) b b' -> Eq (sort 0) #(a, b) #(a, b'); pair_cong2 a b b' eq_b = - (eq_cong (sort 0) b b' eq_b (sort 0) (\ (x:sort 0) -> (a * x))); + (eq_cong (sort 0) b b' eq_b (sort 0) (\ (x:sort 0) -> #(a, x))); axiom unsafeAssert_same_Num : (n : Num) -> Eq (Eq Num n n) (unsafeAssert Num n n) (Refl Num n); @@ -316,105 +318,105 @@ eListSel a n = -- List comprehensions from : (a b : isort 0) -> (m n : Num) -> seq m a -> (a -> seq n b) -> - seq (tcMul m n) (a * b); + seq (tcMul m n) #(a, b); from a b m n = Num#rec - (\ (m:Num) -> seq m a -> (a -> seq n b) -> seq (tcMul m n) (a * b)) + (\ (m:Num) -> seq m a -> (a -> seq n b) -> seq (tcMul m n) #(a, b)) (\ (m:Nat) -> Num#rec (\ (n:Num) -> Vec m a -> (a -> seq n b) -> - seq (tcMul (TCNum m) n) (a * b)) + seq (tcMul (TCNum m) n) #(a, b)) -- Case 1: (TCNum m, TCNum n) (\ (n:Nat) -> \ (xs : Vec m a) -> \ (k : a -> Vec n b) -> - join m n (a * b) - (map a (Vec n (a * b)) + join m n #(a, b) + (map a (Vec n #(a, b)) (\ (x : a) -> - map b (a * b) (\ (y : b) -> (x, y)) n (k x)) + map b #(a, b) (\ (y : b) -> (x, y)) n (k x)) m xs)) -- Case 2: n = (TCNum m, TCInf) (natCase (\ (m':Nat) -> (Vec m' a -> (a -> Stream b) -> - seq (if0Nat Num m' (TCNum 0) TCInf) (a * b))) + seq (if0Nat Num m' (TCNum 0) TCInf) #(a, b))) (\ (xs : Vec 0 a) -> - \ (k : a -> Stream b) -> EmptyVec (a * b)) + \ (k : a -> Stream b) -> EmptyVec #(a, b)) (\ (m' : Nat) -> \ (xs : Vec (Succ m') a) -> \ (k : a -> Stream b) -> - (\ (x : a) -> streamMap b (a * b) (\ (y:b) -> (x, y)) (k x)) + (\ (x : a) -> streamMap b #(a, b) (\ (y:b) -> (x, y)) (k x)) (at (Succ m') a xs 0)) m) n) (Num#rec - (\ (n:Num) -> Stream a -> (a -> seq n b) -> seq (tcMul TCInf n) (a * b)) + (\ (n:Num) -> Stream a -> (a -> seq n b) -> seq (tcMul TCInf n) #(a, b)) -- Case 3: (TCInf, TCNum n) (\ (n:Nat) -> natCase (\ (n':Nat) -> (Stream a -> (a -> Vec n' b) -> - seq (if0Nat Num n' (TCNum 0) TCInf) (a * b))) + seq (if0Nat Num n' (TCNum 0) TCInf) #(a, b))) (\ (xs : Stream a) -> - \ (k : a -> Vec 0 b) -> EmptyVec (a * b)) + \ (k : a -> Vec 0 b) -> EmptyVec #(a, b)) (\ (n' : Nat) -> \ (xs : Stream a) -> \ (k : a -> Vec (Succ n') b) -> streamJoin - (a * b) n' + #(a, b) n' (streamMap - a (Vec (Succ n') (a * b)) + a (Vec (Succ n') #(a, b)) (\ (x:a) -> - map b (a * b) (\ (y:b) -> (x, y)) (Succ n') (k x)) + map b #(a, b) (\ (y:b) -> (x, y)) (Succ n') (k x)) xs)) n) -- Case 4: (TCInf, TCInf) (\ (xs : Stream a) -> \ (k : a -> Stream b) -> - (\ (x : a) -> streamMap b (a * b) (\ (y : b) -> (x, y)) (k x)) + (\ (x : a) -> streamMap b #(a, b) (\ (y : b) -> (x, y)) (k x)) (streamGet a xs 0)) n) m; -mlet : (a b : isort 0) -> (n : Num) -> a -> (a -> seq n b) -> seq n (a * b); +mlet : (a b : isort 0) -> (n : Num) -> a -> (a -> seq n b) -> seq n #(a, b); mlet a b n = Num#rec - (\ (n:Num) -> a -> (a -> seq n b) -> seq n (a * b)) + (\ (n:Num) -> a -> (a -> seq n b) -> seq n #(a, b)) (\ (n:Nat) -> \ (x:a) -> \ (f:a -> Vec n b) -> - map b (a * b) (\ (y : b) -> (x, y)) n (f x)) + map b #(a, b) (\ (y : b) -> (x, y)) n (f x)) (\ (x:a) -> \ (f:a -> Stream b) -> - streamMap b (a * b) (\ (y : b) -> (x, y)) (f x)) + streamMap b #(a, b) (\ (y : b) -> (x, y)) (f x)) n; seqZip : (a b : isort 0) -> (m n : Num) -> seq m a -> seq n b -> - seq (tcMin m n) (a * b); + seq (tcMin m n) #(a, b); seqZip a b m n = Num#rec - (\ (m:Num) -> seq m a -> seq n b -> seq (tcMin m n) (a * b)) + (\ (m:Num) -> seq m a -> seq n b -> seq (tcMin m n) #(a, b)) (\ (m : Nat) -> Num#rec - (\ (n:Num) -> Vec m a -> seq n b -> seq (tcMin (TCNum m) n) (a * b)) + (\ (n:Num) -> Vec m a -> seq n b -> seq (tcMin (TCNum m) n) #(a, b)) (\ (n:Nat) -> zip a b m n) (\ (xs:Vec m a) -> \ (ys:Stream b) -> - gen m (a * b) (\ (i : Nat) -> (at m a xs i, streamGet b ys i))) + gen m #(a, b) (\ (i : Nat) -> (at m a xs i, streamGet b ys i))) n) (Num#rec - (\ (n:Num) -> Stream a -> seq n b -> seq (tcMin TCInf n) (a * b)) + (\ (n:Num) -> Stream a -> seq n b -> seq (tcMin TCInf n) #(a, b)) (\ (n:Nat) -> \ (xs:Stream a) -> \ (ys:Vec n b) -> - gen n (a * b) (\ (i : Nat) -> (streamGet a xs i, at n b ys i))) - (streamMap2 a b (a * b) (\ (x:a) -> \ (y:b) -> (x, y))) + gen n #(a, b) (\ (i : Nat) -> (streamGet a xs i, at n b ys i))) + (streamMap2 a b #(a, b) (\ (x:a) -> \ (y:b) -> (x, y))) n) m; -zipSame : (a b : isort 0) -> (n : Nat) -> Vec n a -> Vec n b -> Vec n (a * b); -zipSame a b n x y = gen n (a*b) (\ (i : Nat) -> (at n a x i, at n b y i)); +zipSame : (a b : isort 0) -> (n : Nat) -> Vec n a -> Vec n b -> Vec n #(a, b); +zipSame a b n x y = gen n #(a, b) (\ (i : Nat) -> (at n a x i, at n b y i)); -seqZipSame : (a b : isort 0) -> (n : Num) -> seq n a -> seq n b -> seq n (a * b); +seqZipSame : (a b : isort 0) -> (n : Num) -> seq n a -> seq n b -> seq n #(a, b); seqZipSame a b n = Num#rec - (\ (n : Num) -> seq n a -> seq n b -> seq n (a * b)) + (\ (n : Num) -> seq n a -> seq n b -> seq n #(a, b)) (\ (n : Nat) -> zipSame a b n) - (streamMap2 a b (a*b) (\ (x:a) -> \ (y:b) -> (x,y))) + (streamMap2 a b #(a, b) (\ (x:a) -> \ (y:b) -> (x,y))) n; -------------------------------------------------------------------------------- @@ -435,13 +437,27 @@ unitUnary _ = (); unitBinary : #() -> #() -> #(); unitBinary _ _ = (); -pairUnary : (a b : sort 0) -> (a -> a) -> (b -> b) -> (a * b) -> (a * b); -pairUnary a b f g xy = (f (fst a b xy), g (snd a b xy)); - -pairBinary : (a b : sort 0) -> (a -> a -> a) -> (b -> b -> b) - -> (a * b) -> (a * b) -> (a * b); -pairBinary a b f g x12 y12 = (f (fst a b x12) (fst a b y12), - g (snd a b x12) (snd a b y12)); +pairUnary : + (t : sort 0) -> + (ts : TypeList) -> + (t -> t) -> + (Tuple ts -> Tuple ts) -> + Tuple (TypeCons t ts) -> Tuple (TypeCons t ts); +pairUnary t ts f g x = + consTuple t ts + (f (headTuple t ts x)) + (g (tailTuple t ts x)); + +pairBinary : + (t : sort 0) -> + (ts : TypeList) -> + (t -> t -> t) -> + (Tuple ts -> Tuple ts -> Tuple ts) -> + Tuple (TypeCons t ts) -> Tuple (TypeCons t ts) -> Tuple (TypeCons t ts); +pairBinary t ts f g x y = + consTuple t ts + (f (headTuple t ts x) (headTuple t ts y)) + (g (tailTuple t ts x) (tailTuple t ts y)); funBinary : (a b : sort 0) -> (b -> b -> b) -> (a -> b) -> (a -> b) -> (a -> b); funBinary a b op f g x = op (f x) (g x); @@ -497,16 +513,24 @@ unitLe _ _ = True; unitLt : #() -> #() -> Bool; unitLt _ _ = False; -pairCmp : (a b : sort 0) -> (a -> a -> Bool -> Bool) -> (b -> b -> Bool -> Bool) - -> a * b -> a * b -> Bool -> Bool; -pairCmp a b f g x12 y12 k = - f (fst a b x12) (fst a b y12) (g (snd a b x12) (snd a b y12) k); +pairCmp : + (t : sort 0) -> + (ts : TypeList) -> + (t -> t -> Bool -> Bool) -> + (Tuple ts -> Tuple ts -> Bool -> Bool) -> + Tuple (TypeCons t ts) -> Tuple (TypeCons t ts) -> Bool -> Bool; +pairCmp t ts f g x12 y12 k = + f (headTuple t ts x12) (headTuple t ts y12) + (g (tailTuple t ts x12) (tailTuple t ts y12) k); pairLt : - (a b : sort 0) -> (a -> a -> Bool -> Bool) -> (b -> b -> Bool) -> - a * b -> a * b -> Bool; -pairLt a b f g x y = - f (fst a b x) (fst a b y) (g (snd a b x) (snd a b y)); + (t : sort 0) -> + (ts : TypeList) -> + (t -> t -> Bool -> Bool) -> + (Tuple ts -> Tuple ts -> Bool) -> + Tuple (TypeCons t ts) -> Tuple (TypeCons t ts) -> Bool; +pairLt t ts f g x y = + f (headTuple t ts x) (headTuple t ts y) (g (tailTuple t ts x) (tailTuple t ts y)); -------------------------------------------------------------------------------- -- Dictionaries and overloading @@ -555,7 +579,7 @@ PEqSeqBool n = PEqUnit : PEq #(); PEqUnit = { eq = \ (x y : #()) -> True }; -PEqPair : (a b : sort 0) -> PEq a -> PEq b -> PEq (a * b); +PEqPair : (t : sort 0) -> (ts : TypeList) -> PEq t -> PEq (Tuple ts) -> PEq (Tuple (TypeCons t ts)); PEqPair a b pa pb = { eq = pairEq a b pa.eq pb.eq }; @@ -607,7 +631,10 @@ PCmpSeqBool n = PCmpUnit : PCmp #(); PCmpUnit = { cmpEq = PEqUnit, cmp = unitCmp, le = unitLe, lt = unitLt }; -PCmpPair : (a b : sort 0) -> PCmp a -> PCmp b -> PCmp (a * b); +PCmpPair : + (t : sort 0) -> + (ts : TypeList) -> + PCmp t -> PCmp (Tuple ts) -> PCmp (Tuple (TypeCons t ts)); PCmpPair a b pa pb = { cmpEq = PEqPair a b pa.cmpEq pb.cmpEq , cmp = pairCmp a b pa.cmp pb.cmp @@ -654,7 +681,12 @@ PSignedCmpSeqBool n = PSignedCmpUnit : PSignedCmp #(); PSignedCmpUnit = { signedCmpEq = PEqUnit, scmp = unitCmp, sle = unitLe, slt = unitLt }; -PSignedCmpPair : (a b : sort 0) -> PSignedCmp a -> PSignedCmp b -> PSignedCmp (a * b); +PSignedCmpPair : + (t : sort 0) -> + (ts : TypeList) -> + PSignedCmp t -> + PSignedCmp (Tuple ts) -> + PSignedCmp (Tuple (TypeCons t ts)); PSignedCmpPair a b pa pb = { signedCmpEq = PEqPair a b pa.signedCmpEq pb.signedCmpEq , scmp = pairCmp a b pa.scmp pb.scmp @@ -773,9 +805,12 @@ PLogicUnit = , not = unitUnary }; -PLogicPair : (a b : sort 0) -> PLogic a -> PLogic b -> PLogic (a * b); +PLogicPair : + (t : sort 0) -> + (ts : TypeList) -> + PLogic t -> PLogic (Tuple ts) -> PLogic (Tuple (TypeCons t ts)); PLogicPair a b pa pb = - { logicZero = (pa.logicZero, pb.logicZero) + { logicZero = consTuple a b pa.logicZero pb.logicZero , and = pairBinary a b pa.and pb.and , or = pairBinary a b pa.or pb.or , xor = pairBinary a b pa.xor pb.xor @@ -892,14 +927,17 @@ PRingUnit = , int = \ (i : Integer) -> () }; -PRingPair : (a b : sort 0) -> PRing a -> PRing b -> PRing (a * b); +PRingPair : + (t : sort 0) -> + (ts : TypeList) -> + PRing t -> PRing (Tuple ts) -> PRing (Tuple (TypeCons t ts)); PRingPair a b pa pb = - { ringZero = (pa.ringZero, pb.ringZero) + { ringZero = consTuple a b pa.ringZero pb.ringZero , add = pairBinary a b pa.add pb.add , sub = pairBinary a b pa.sub pb.sub , mul = pairBinary a b pa.mul pb.mul , neg = pairUnary a b pa.neg pb.neg - , int = \ (i : Integer) -> (pa.int i, pb.int i) + , int = \ (i : Integer) -> consTuple a b (pa.int i) (pb.int i) }; -- Integral class @@ -1887,35 +1925,35 @@ processSHA2_512 n x = ec_double : (p : Num) -> - IntModNum p * IntModNum p * IntModNum p -> - IntModNum p * IntModNum p * IntModNum p; + #(IntModNum p, IntModNum p, IntModNum p) -> + #(IntModNum p, IntModNum p, IntModNum p); ec_double p x = - error (IntModNum p * IntModNum p * IntModNum p) "Unimplemented: ec_double"; + error #(IntModNum p, IntModNum p, IntModNum p) "Unimplemented: ec_double"; ec_add_nonzero : (p : Num) -> - IntModNum p * IntModNum p * IntModNum p -> - IntModNum p * IntModNum p * IntModNum p -> - IntModNum p * IntModNum p * IntModNum p; + #(IntModNum p, IntModNum p, IntModNum p) -> + #(IntModNum p, IntModNum p, IntModNum p) -> + #(IntModNum p, IntModNum p, IntModNum p); ec_add_nonzero p x y = - error (IntModNum p * IntModNum p * IntModNum p) "Unimplemented: ec_add_nonzero"; + error #(IntModNum p, IntModNum p, IntModNum p) "Unimplemented: ec_add_nonzero"; ec_mult : (p : Num) -> IntModNum p -> - IntModNum p * IntModNum p * IntModNum p -> - IntModNum p * IntModNum p * IntModNum p; + #(IntModNum p, IntModNum p, IntModNum p) -> + #(IntModNum p, IntModNum p, IntModNum p); ec_mult p x y = - error (IntModNum p * IntModNum p * IntModNum p) "Unimplemented: ec_mult"; + error #(IntModNum p, IntModNum p, IntModNum p) "Unimplemented: ec_mult"; ec_twin_mult : (p : Num) -> IntModNum p -> - IntModNum p * IntModNum p * IntModNum p -> - IntModNum p * IntModNum p * IntModNum p -> - IntModNum p * IntModNum p * IntModNum p; + #(IntModNum p, IntModNum p, IntModNum p) -> + #(IntModNum p, IntModNum p, IntModNum p) -> + #(IntModNum p, IntModNum p, IntModNum p); ec_twin_mult p x y z = - error (IntModNum p * IntModNum p * IntModNum p) "Unimplemented: ec_twin_mult"; + error #(IntModNum p, IntModNum p, IntModNum p) "Unimplemented: ec_twin_mult"; -------------------------------------------------------------------------------- -- Rewrite rules diff --git a/cryptol-saw-core/saw/CryptolM.sawcore b/cryptol-saw-core/saw/CryptolM.sawcore index 61df5156f2..159d1d95c1 100644 --- a/cryptol-saw-core/saw/CryptolM.sawcore +++ b/cryptol-saw-core/saw/CryptolM.sawcore @@ -122,23 +122,23 @@ eListSelM a = -- FIXME primitive fromM : (a b : sort 0) -> (m n : Num) -> mseq m a -> (a -> CompM (mseq n b)) -> - CompM (seq (tcMul m n) (a * b)); + CompM (seq (tcMul m n) #(a, b)); -- FIXME primitive mletM : (a b : sort 0) -> (n : Num) -> a -> (a -> CompM (mseq n b)) -> - CompM (mseq n (a * b)); + CompM (mseq n #(a, b)); -- FIXME primitive seqZipM : (a b : sort 0) -> (m n : Num) -> mseq m a -> mseq n b -> - CompM (mseq (tcMin m n) (a * b)); + CompM (mseq (tcMin m n) #(a, b)); {- seqZipM a b m n ms1 ms2 = seqMap - (CompM a * CompM b) (CompM (a * b)) (tcMin m n) - (\ (p : CompM a * CompM b) -> - bindM2 a b (a*b) p.(1) p.(2) (\ (x:a) (y:b) -> returnM (a*b) (x,y))) + #(CompM a, CompM b) (CompM #(a, b)) (tcMin m n) + (\ (p : #(CompM a, CompM b)) -> + bindM2 a b #(a, b) p.(0) p.(1) (\ (x:a) (y:b) -> returnM #(a, b) (x, y))) (seqZip (CompM a) (CompM b) m n ms1 ms2); -} diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs index e61eaff02e..7626b43ddd 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs @@ -922,21 +922,18 @@ importExpr sc env expr = case sel of C.TupleSel i _maybeLen -> do e' <- importExpr sc env e - let t = fastTypeOf (envC env) e - case C.tIsTuple t of - Just ts -> scTupleSelector sc e' (i+1) (length ts) - Nothing -> panic "importExpr" ["invalid tuple selector", show i, pretty t] + scTupleSelector sc e' i C.RecordSel x _ -> do e' <- importExpr sc env e let t = fastTypeOf (envC env) e case C.tNoUser t of C.TRec fm -> do i <- the (elemIndex x (map fst (C.canonicalFields fm))) - scTupleSelector sc e' (i+1) (length (C.canonicalFields fm)) + scTupleSelector sc e' i C.TNewtype nt _args -> do let fs = C.ntFields nt i <- the (elemIndex x (map fst (C.canonicalFields fs))) - scTupleSelector sc e' (i+1) (length (C.canonicalFields fs)) + scTupleSelector sc e' i _ -> panic "importExpr" ["invalid record selector", pretty x, pretty t] C.ListSel i _maybeLen -> do let t = fastTypeOf (envC env) e @@ -1633,13 +1630,10 @@ asCryptolTypeValue v = SC.VDataType _ _ _ -> Nothing - SC.VUnitType -> return (Right (C.tTuple [])) - SC.VPairType v1 v2 -> do - Right t1 <- asCryptolTypeValue v1 - Right t2 <- asCryptolTypeValue v2 - case C.tIsTuple t2 of - Just ts -> return (Right (C.tTuple (t1 : ts))) - Nothing -> return (Right (C.tTuple [t1, t2])) + SC.VTupleType vs -> + do es <- traverse asCryptolTypeValue vs + ts <- traverse asRight es + pure (Right (C.tTuple (Vector.toList ts))) SC.VPiType _nm v1 (SC.VNondependentPi v2) -> do Right t1 <- asCryptolTypeValue v1 @@ -1656,7 +1650,8 @@ asCryptolTypeValue v = SC.VRecordType{} -> Nothing SC.VRecursorType{} -> Nothing SC.VTyTerm{} -> Nothing - + where + asRight = either (const Nothing) Just scCryptolType :: SharedContext -> Term -> IO (Maybe (Either C.Kind C.Type)) scCryptolType sc t = @@ -1734,24 +1729,25 @@ exportValue ty v = case ty of exportTupleValue :: [TV.TValue] -> SC.CValue -> [V.Eval V.Value] exportTupleValue tys v = - case (tys, v) of - ([] , SC.VUnit ) -> [] - ([t] , _ ) -> [exportValue t v] - (t : ts, SC.VPair x y) -> (exportValue t (run x)) : exportTupleValue ts (run y) - _ -> error $ "exportValue: expected tuple" + case v of + SC.VTuple (Vector.toList -> xs) + | length xs == length tys -> + [ exportValue t (run x) | (t, x) <- zip tys xs ] + _ -> panic "Verifier.SAW.Cryptol.exportValue" ["expected tuple"] where run = SC.runIdentity . force exportRecordValue :: [(C.Ident, TV.TValue)] -> SC.CValue -> [(C.Ident, V.Eval V.Value)] exportRecordValue fields v = - case (fields, v) of - ([] , SC.VUnit ) -> [] - ([(n, t)] , _ ) -> [(n, exportValue t v)] - ((n, t) : ts, SC.VPair x y) -> (n, exportValue t (run x)) : exportRecordValue ts (run y) - (_, SC.VRecordValue (alistAllFields - (map (C.identText . fst) fields) -> Just ths)) -> + case v of + -- TODO: remove VTuple case when cryptol-saw-core importer switches record imports to use record types. + SC.VTuple (Vector.toList -> xs) + | length xs == length fields -> + [ (n, exportValue t (run x)) | ((n, t), x) <- zip (Map.assocs (Map.fromList fields)) xs ] + SC.VRecordValue (alistAllFields + (map (C.identText . fst) fields) -> Just ths) -> zipWith (\(n,t) x -> (n, exportValue t (run x))) fields ths - _ -> error $ "exportValue: expected record" + _ -> panic "Verifier.SAW.Cryptol.exportValue" ["expected record"] where run = SC.runIdentity . force diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs index b296dabb4d..126ee2c392 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs @@ -224,7 +224,7 @@ data MonType = MTyForall LocalName MonKind (MonType -> MonType) | MTyArrow MonType MonType | MTySeq OpenTerm MonType - | MTyPair MonType MonType + | MTyTuple [MonType] | MTyRecord [(FieldName, MonType)] | MTyBase MonKind OpenTerm -- A "base type" or type var of a given kind | MTyNum OpenTerm @@ -241,7 +241,7 @@ boolMonType = mkMonType0 $ globalOpenTerm "Prelude.Bool" monTypeIsMono :: MonType -> Bool monTypeIsMono (MTyForall _ _ _) = False monTypeIsMono (MTyArrow tp1 tp2) = monTypeIsMono tp1 && monTypeIsMono tp2 -monTypeIsMono (MTyPair tp1 tp2) = monTypeIsMono tp1 && monTypeIsMono tp2 +monTypeIsMono (MTyTuple tps) = all monTypeIsMono tps monTypeIsMono (MTyRecord tps) = all (monTypeIsMono . snd) tps monTypeIsMono (MTySeq _ tp) = monTypeIsMono tp monTypeIsMono (MTyBase _ _) = True @@ -253,7 +253,7 @@ isBaseType :: MonType -> Bool isBaseType (MTyForall _ _ _) = False isBaseType (MTyArrow _ _) = False isBaseType (MTySeq _ _) = True -isBaseType (MTyPair _ _) = True +isBaseType (MTyTuple _) = True isBaseType (MTyRecord _) = True isBaseType (MTyBase (MKType _) _) = True isBaseType (MTyBase _ _) = True @@ -273,10 +273,10 @@ monTypeKind (MTyArrow t1 t2) = do s1 <- monTypeKind t1 >>= monKindToSort s2 <- monTypeKind t2 >>= monKindToSort return $ MKType $ maxSort [s1, s2] -monTypeKind (MTyPair tp1 tp2) = - do sort1 <- monTypeKind tp1 >>= monKindToSort - sort2 <- monTypeKind tp2 >>= monKindToSort - return $ MKType $ maxSort [sort1, sort2] +monTypeKind (MTyTuple tps) = + do kinds <- traverse monTypeKind tps + sorts <- traverse monKindToSort kinds + pure $ MKType $ maxSort sorts monTypeKind (MTyRecord tps) = do sorts <- mapM (monTypeKind . snd >=> monKindToSort) tps return $ MKType $ maxSort sorts @@ -319,8 +319,8 @@ toArgType (MTyArrow t1 t2) = arrowOpenTerm "_" (toArgType t1) (toCompType t2) toArgType (MTySeq n t) = applyOpenTermMulti (globalOpenTerm "CryptolM.mseq") [n, toArgType t] -toArgType (MTyPair mtp1 mtp2) = - pairTypeOpenTerm (toArgType mtp1) (toArgType mtp2) +toArgType (MTyTuple mtps) = + tupleTypeOpenTerm (map toArgType mtps) toArgType (MTyRecord tps) = recordTypeOpenTerm $ map (\(f,tp) -> (f, toArgType tp)) tps toArgType (MTyBase _ t) = t @@ -399,8 +399,8 @@ monadifyType ctx tp@(asPi -> Just (x, tp_in, tp_out)) = MTyArrow (monadifyType ctx tp_in) (monadifyType ((x,tp,Nothing):ctx) tp_out) monadifyType _ (asTupleType -> Just []) = mkMonType0 unitTypeOpenTerm -monadifyType ctx (asPairType -> Just (tp1, tp2)) = - MTyPair (monadifyType ctx tp1) (monadifyType ctx tp2) +monadifyType ctx (asTupleType -> Just tps) = + MTyTuple (map (monadifyType ctx) tps) monadifyType ctx (asRecordType -> Just tps) = MTyRecord $ map (\(fld,tp) -> (fld, monadifyType ctx tp)) $ Map.toList tps monadifyType ctx (asDataType -> Just (eq_pn, [k_trm, tp1, tp2])) @@ -535,7 +535,7 @@ monTypeIsPure :: MonType -> Bool monTypeIsPure (MTyForall _ _ _) = False -- NOTE: this could potentially be true monTypeIsPure (MTyArrow _ _) = False monTypeIsPure (MTySeq _ _) = False -monTypeIsPure (MTyPair mtp1 mtp2) = monTypeIsPure mtp1 && monTypeIsPure mtp2 +monTypeIsPure (MTyTuple mtps) = all monTypeIsPure mtps monTypeIsPure (MTyRecord fld_mtps) = all (monTypeIsPure . snd) fld_mtps monTypeIsPure (MTyBase _ _) = True monTypeIsPure (MTyNum _) = True @@ -550,10 +550,9 @@ monTypeIsSemiPure (MTyForall _ k tp_f) = monTypeIsSemiPure (MTyArrow tp_in tp_out) = monTypeIsPure tp_in && monTypeIsSemiPure tp_out monTypeIsSemiPure (MTySeq _ _) = False -monTypeIsSemiPure (MTyPair mtp1 mtp2) = - -- NOTE: functions in pairs are not semi-pure; only pure types in pairs are - -- semi-pure - monTypeIsPure mtp1 && monTypeIsPure mtp2 +monTypeIsSemiPure (MTyTuple mtps) = all monTypeIsPure mtps + -- NOTE: functions in tuples are not semi-pure; only pure types in + -- tuples are semi-pure monTypeIsSemiPure (MTyRecord fld_mtps) = -- Same as pairs, record types are only semi-pure if they are pure all (monTypeIsPure . snd) fld_mtps @@ -836,24 +835,24 @@ monadifyTerm' (Just mtp@(MTyArrow _ _)) t = get >>= \table -> return $ monadifyLambdas (monStEnv ro_st) table (monStCtx ro_st) mtp t -} -monadifyTerm' (Just mtp@(MTyPair mtp1 mtp2)) (asPairValue -> - Just (trm1, trm2)) = - fromArgTerm mtp <$> (pairOpenTerm <$> - monadifyArgTerm (Just mtp1) trm1 <*> - monadifyArgTerm (Just mtp2) trm2) +monadifyTerm' (Just mtp@(MTyTuple mtps)) (asTupleValue -> Just trms) + | length mtps == length trms = + fromArgTerm mtp <$> + tupleOpenTerm <$> + zipWithM monadifyArgTerm (map Just mtps) trms monadifyTerm' (Just mtp@(MTyRecord fs_mtps)) (asRecordValue -> Just trm_map) | length fs_mtps == Map.size trm_map , (fs,mtps) <- unzip fs_mtps , Just trms <- mapM (\f -> Map.lookup f trm_map) fs = fromArgTerm mtp <$> recordOpenTerm <$> zip fs <$> zipWithM monadifyArgTerm (map Just mtps) trms -monadifyTerm' _ (asPairSelector -> Just (trm, False)) = +monadifyTerm' _ (asTupleSelector -> Just (trm, i)) = do mtrm <- monadifyArg Nothing trm - mtp <- case getMonType mtrm of - MTyPair t _ -> return t - _ -> fail "Monadification failed: projection on term of non-pair type" - return $ fromArgTerm mtp $ - pairLeftOpenTerm $ toArgTerm mtrm + mtp <- + case getMonType mtrm of + MTyTuple ts | i < length ts -> pure (ts !! i) + _ -> fail "Monadification failed: projection on term of non-tuple type" + pure $ fromArgTerm mtp $ projTupleOpenTerm (toInteger i) $ toArgTerm mtrm monadifyTerm' (Just mtp@(MTySeq n mtp_elem)) (asFTermF -> Just (ArrayValue _ trms)) = do trms' <- traverse (monadifyArgTerm $ Just mtp_elem) trms @@ -861,13 +860,6 @@ monadifyTerm' (Just mtp@(MTySeq n mtp_elem)) (asFTermF -> applyOpenTermMulti (globalOpenTerm "CryptolM.seqToMseq") [n, toArgType mtp_elem, flatOpenTerm $ ArrayValue (toArgType mtp_elem) trms'] -monadifyTerm' _ (asPairSelector -> Just (trm, True)) = - do mtrm <- monadifyArg Nothing trm - mtp <- case getMonType mtrm of - MTyPair _ t -> return t - _ -> fail "Monadification failed: projection on term of non-pair type" - return $ fromArgTerm mtp $ - pairRightOpenTerm $ toArgTerm mtrm monadifyTerm' _ (asRecordSelector -> Just (trm, fld)) = do mtrm <- monadifyArg Nothing trm mtp <- case getMonType mtrm of diff --git a/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs b/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs index dac3549deb..93a3afd9d2 100644 --- a/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs +++ b/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs @@ -109,8 +109,8 @@ destTupleTypedTerm sc (TypedTerm tp t) = Nothing -> fail "asTupleTypedTerm: not a tuple type" Just ctys -> do let len = length ctys - let idxs = take len [1 ..] - ts <- traverse (\i -> scTupleSelector sc t i len) idxs + let idxs = take len [0..] + ts <- traverse (scTupleSelector sc t) idxs pure $ zipWith TypedTerm (map (TypedTermSchema . C.tMono) ctys) ts -- First order types and values ------------------------------------------------ diff --git a/heapster-saw/examples/arrays.sawcore b/heapster-saw/examples/arrays.sawcore index 6b1f16867b..1c1b5e05d9 100644 --- a/heapster-saw/examples/arrays.sawcore +++ b/heapster-saw/examples/arrays.sawcore @@ -3,39 +3,48 @@ module arrays where import Prelude; +letRecM_single : + (lrt : LetRecType) -> + (B : sort 0) -> + (lrtToType lrt -> lrtToType lrt) -> + (lrtToType lrt -> CompM B) -> CompM B; +letRecM_single lrt B fn body = + letRecM + (LRT_Cons lrt LRT_Nil) B + (\ (x : lrtToType lrt) -> consTuple (lrtToType lrt) TypeNil (fn x) ()) + (\ (x : lrtToType lrt) -> body x); + -- The helper function for noErrorsContains0 -- -- noErrorsContains0H len i v = -- orM (exists x. returnM x) (noErrorsContains0H len (i+1) v) noErrorsContains0H : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); + CompM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool); noErrorsContains0H len_top i_top v_top = - letRecM - (LRT_Cons - (LRT_Fun (Vec 64 Bool) (\ (len:Vec 64 Bool) -> - LRT_Fun (Vec 64 Bool) (\ (_:Vec 64 Bool) -> - LRT_Fun (BVVec 64 len (Vec 64 Bool)) (\ (_:BVVec 64 len (Vec 64 Bool)) -> - LRT_Ret (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool))))) - LRT_Nil) - (BVVec 64 len_top (Vec 64 Bool) * Vec 64 Bool) + letRecM_single + (LRT_Fun (Vec 64 Bool) (\ (len:Vec 64 Bool) -> + LRT_Fun (Vec 64 Bool) (\ (_:Vec 64 Bool) -> + LRT_Fun (BVVec 64 len (Vec 64 Bool)) (\ (_:BVVec 64 len (Vec 64 Bool)) -> + LRT_Ret #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool))))) + #(BVVec 64 len_top (Vec 64 Bool), Vec 64 Bool) (\ (f : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> - ((\ (len:Vec 64 Bool) (i:Vec 64 Bool) (v:BVVec 64 len (Vec 64 Bool)) -> - precondHint - (CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) - (and (bvsle 64 0x0000000000000000 i) - (bvsle 64 i 0x0fffffffffffffff)) - (orM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) - (existsM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) - (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) - (returnM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool))) - (f len (bvAdd 64 i 0x0000000000000001) v))), ())) + CompM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool)) -> + (\ (len:Vec 64 Bool) (i:Vec 64 Bool) (v:BVVec 64 len (Vec 64 Bool)) -> + precondHint + (CompM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool)) + (and (bvsle 64 0x0000000000000000 i) + (bvsle 64 i 0x0fffffffffffffff)) + (orM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool) + (existsM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool) + #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool) + (returnM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool))) + (f len (bvAdd 64 i 0x0000000000000001) v)))) (\ (f : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> + CompM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool)) -> f len_top i_top v_top); -- The specification that contains0 has no errors noErrorsContains0 : (len:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); + CompM #(BVVec 64 len (Vec 64 Bool), Vec 64 Bool); noErrorsContains0 len v = noErrorsContains0H len 0x0000000000000000 v; diff --git a/heapster-saw/examples/clearbufs.sawcore b/heapster-saw/examples/clearbufs.sawcore index 794f9a00bf..d3326aba19 100644 --- a/heapster-saw/examples/clearbufs.sawcore +++ b/heapster-saw/examples/clearbufs.sawcore @@ -25,28 +25,28 @@ Mbox_def = Mbox; (m:Mbox) -> P m; Mbox__rec P f1 f2 m = Mbox#rec P f1 f2 m; -} ---unfoldMbox : Mbox -> Either #() (Mbox * Vec 64 Bool * BVVec 64 bv64_16 (Vec 64 Bool)); +--unfoldMbox : Mbox -> Either #() #(Mbox, Vec 64 Bool, BVVec 64 bv64_16 (Vec 64 Bool)); primitive -unfoldMbox : Mbox -> Either #() (Sigma (V64) (\ (len : V64) -> Mbox * BVVec 64 len (Vec 64 Bool))); +unfoldMbox : Mbox -> Either #() (Sigma (V64) (\ (len : V64) -> #(Mbox, BVVec 64 len (Vec 64 Bool)))); {-unfoldMbox m = - Mbox__rec (\ (_:Mbox) -> Either #() (Mbox * Vec 64 Bool * BVVec 64 bv64_16 (Vec 64 Bool) * #())) - (Left #() (Mbox * Vec 64 Bool * BVVec 64 bv64_16 (Vec 64 Bool) * #()) ()) - (\ (m:Mbox) (_:Either #() (Mbox * Vec 64 Bool * BVVec 64 bv64_16 (Vec 64 Bool) * #())) (len:Vec 64 Bool) (d:BVVec 64 bv64_16 (Vec 64 Bool)) -> - Right #() (Mbox * Vec 64 Bool * BVVec 64 bv64_16 (Vec 64 Bool) * #()) (m, len, d, ())) + Mbox__rec (\ (_:Mbox) -> Either #() #(Mbox, Vec 64 Bool, BVVec 64 bv64_16 (Vec 64 Bool), #())) + (Left #() #(Mbox, Vec 64 Bool, BVVec 64 bv64_16 (Vec 64 Bool), #()) ()) + (\ (m:Mbox) (_:Either #() #(Mbox, Vec 64 Bool, BVVec 64 bv64_16 (Vec 64 Bool), #())) (len:Vec 64 Bool) (d:BVVec 64 bv64_16 (Vec 64 Bool)) -> + Right #() #(Mbox, Vec 64 Bool, BVVec 64 bv64_16 (Vec 64 Bool), #()) (m, len, d, ())) m; -} primitive -foldMbox : Either #() (Sigma (V64) (\ (len : V64) -> Mbox * BVVec 64 len (Vec 64 Bool))) -> Mbox; +foldMbox : Either #() (Sigma (V64) (\ (len : V64) -> #(Mbox, BVVec 64 len (Vec 64 Bool)))) -> Mbox; ---(Mbox * Vec 64 Bool * (BVVec 64 bv64_16 (Vec 64 Bool)) * #()) -> Mbox; +--#(Mbox, Vec 64 Bool, (BVVec 64 bv64_16 (Vec 64 Bool)), #()) -> Mbox; {- foldMbox = - either #() (Mbox * Vec 64 Bool * (BVVec 64 bv64_16 (Vec 64 Bool)) * #()) Mbox + either #() #(Mbox, Vec 64 Bool, (BVVec 64 bv64_16 (Vec 64 Bool)), #()) Mbox (\ (_:#()) -> Mbox_nil) - (\ (tup : (Mbox * Vec 64 Bool * (BVVec 64 bv64_16 (Vec 64 Bool)) * #())) -> - Mbox_cons tup.1 tup.2 tup.3); + (\ (tup : #(Mbox, Vec 64 Bool, (BVVec 64 bv64_16 (Vec 64 Bool)), #())) -> + Mbox_cons tup.0 tup.1 tup.2); -} primitive diff --git a/heapster-saw/examples/global_var.sawcore b/heapster-saw/examples/global_var.sawcore index a1f63d7f19..eade3439b1 100644 --- a/heapster-saw/examples/global_var.sawcore +++ b/heapster-saw/examples/global_var.sawcore @@ -2,8 +2,8 @@ module GlobalVar where import Prelude; -acquireLockM : Vec 64 Bool -> CompM (Vec 64 Bool * Vec 64 Bool); -acquireLockM u = returnM (Vec 64 Bool * Vec 64 Bool) +acquireLockM : Vec 64 Bool -> CompM #(Vec 64 Bool, Vec 64 Bool); +acquireLockM u = returnM #(Vec 64 Bool, Vec 64 Bool) (u,u); releaseLockM : Vec 64 Bool -> Vec 64 Bool -> CompM (Vec 64 Bool); diff --git a/heapster-saw/examples/iter_linked_list.sawcore b/heapster-saw/examples/iter_linked_list.sawcore index fbb59e816b..421650dbc6 100644 --- a/heapster-saw/examples/iter_linked_list.sawcore +++ b/heapster-saw/examples/iter_linked_list.sawcore @@ -27,20 +27,20 @@ ListF__rec : (a b:sort 0) -> (P : ListF a b -> sort 0) -> (l:ListF a b) -> P l; ListF__rec a b P f1 f2 l = ListF#rec a b P f1 f2 l; -unfoldListF : (a b:sort 0) -> ListF a b -> Either b (a * ListF a b); +unfoldListF : (a b:sort 0) -> ListF a b -> Either b #(a, ListF a b); unfoldListF a b l = - ListF__rec a b (\ (_:ListF a b) -> Either b (a * ListF a b)) - (\ (x:b) -> Left b (a * ListF a b) x) - (\ (x:a) (l:ListF a b) (_:Either b (a * ListF a b)) -> - Right b (a * ListF a b) (x, l)) + ListF__rec a b (\ (_:ListF a b) -> Either b #(a, ListF a b)) + (\ (x:b) -> Left b #(a, ListF a b) x) + (\ (x:a) (l:ListF a b) (_:Either b #(a, ListF a b)) -> + Right b #(a, ListF a b) (x, l)) l; -foldListF : (a b:sort 0) -> Either b (a * ListF a b) -> ListF a b; +foldListF : (a b:sort 0) -> Either b #(a, ListF a b) -> ListF a b; foldListF a b = - either b (a * ListF a b) (ListF a b) + either b #(a, ListF a b) (ListF a b) (\ (x : b) -> NilF a b x) - (\ (tup : (a * ListF a b)) -> - ConsF a b tup.(1) tup.(2)); + (\ (tup : #(a, ListF a b)) -> + ConsF a b tup.0 tup.1); getListF : (a b:sort 0) -> ListF a b -> b; getListF a b = diff --git a/heapster-saw/examples/mbox.saw b/heapster-saw/examples/mbox.saw index de4edbb63d..b805da53d7 100644 --- a/heapster-saw/examples/mbox.saw +++ b/heapster-saw/examples/mbox.saw @@ -73,7 +73,7 @@ heapster_assume_fun env "__memcpy_chk" "(len:bv 64). arg0:byte_array, arg1:byte_array, arg2:eq(llvmword (len)) -o \ \ arg0:byte_array, arg1:byte_array" "\\ (len:Vec 64 Bool) (_ src : BVVec 64 len (Vec 8 Bool)) -> \ - \ returnM (BVVec 64 len (Vec 8 Bool) * BVVec 64 len (Vec 8 Bool)) (src, src)"; + \ returnM #(BVVec 64 len (Vec 8 Bool), BVVec 64 len (Vec 8 Bool)) (src, src)"; //------------------------------------------------------------------------------ diff --git a/heapster-saw/examples/mbox.sawcore b/heapster-saw/examples/mbox.sawcore index 0b751a6be9..86f6a30812 100644 --- a/heapster-saw/examples/mbox.sawcore +++ b/heapster-saw/examples/mbox.sawcore @@ -31,20 +31,20 @@ Mbox__rec : (P : Mbox -> sort 0) -> (m:Mbox) -> P m; Mbox__rec P f1 f2 m = Mbox#rec P f1 f2 m; -unfoldMbox : Mbox -> Either #() (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool)); +unfoldMbox : Mbox -> Either #() #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool)); unfoldMbox m = - Mbox__rec (\ (_:Mbox) -> Either #() (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool))) - (Left #() (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool)) ()) - (\ (strt:Vec 64 Bool) (len:Vec 64 Bool) (m:Mbox) (_:Either #() (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool))) (d:BVVec 64 bv64_128 (Vec 8 Bool)) -> - Right #() (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool)) (strt, len, m, d)) + Mbox__rec (\ (_:Mbox) -> Either #() #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool))) + (Left #() #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool)) ()) + (\ (strt:Vec 64 Bool) (len:Vec 64 Bool) (m:Mbox) (_:Either #() #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool))) (d:BVVec 64 bv64_128 (Vec 8 Bool)) -> + Right #() #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool)) (strt, len, m, d)) m; -foldMbox : Either #() (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool)) -> Mbox; +foldMbox : Either #() #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool)) -> Mbox; foldMbox = - either #() (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool)) Mbox + either #() #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool)) Mbox (\ (_:#()) -> Mbox_nil) - (\ (tup : (Vec 64 Bool * Vec 64 Bool * Mbox * BVVec 64 bv64_128 (Vec 8 Bool))) -> - Mbox_cons tup.1 tup.2 tup.3 tup.(2).(2).(2)); + (\ (tup : #(Vec 64 Bool, Vec 64 Bool, Mbox, BVVec 64 bv64_128 (Vec 8 Bool))) -> + Mbox_cons tup.0 tup.1 tup.2 tup.3); {- getMbox : (a : sort 0) -> Mbox a -> a; diff --git a/heapster-saw/examples/memcpy.saw b/heapster-saw/examples/memcpy.saw index 92196ff710..3e0f90c933 100644 --- a/heapster-saw/examples/memcpy.saw +++ b/heapster-saw/examples/memcpy.saw @@ -10,7 +10,7 @@ heapster_assume_fun env "llvm.memcpy.p0i8.p0i8.i64" \ arg0:[l1]memblock(W,0,len,sh), arg1:[l2]memblock(rw,0,len,eqsh(len,b)), \ \ arg2:eq(llvmword(len)) -o \ \ arg0:[l1]memblock(W,0,len,eqsh(len,b)), arg1:[l2]memblock(rw,0,len,eqsh(len,b))" - "\\ (X:sort 0) (len:Vec 64 Bool) (x:X) (_:#()) -> returnM (#() * #()) ((),())"; + "\\ (X:sort 0) (len:Vec 64 Bool) (x:X) (_:#()) -> returnM #(#(), #()) ((),())"; heapster_typecheck_fun env "copy_int" diff --git a/heapster-saw/examples/rust_data.saw b/heapster-saw/examples/rust_data.saw index cdc8b4a48d..105e1e6a48 100644 --- a/heapster-saw/examples/rust_data.saw +++ b/heapster-saw/examples/rust_data.saw @@ -105,7 +105,7 @@ heapster_define_opaque_llvmshape env "Vec" 64 // Opaque type for HashMap heapster_define_opaque_llvmshape env "HashMap" 64 "T:llvmshape 64, U:llvmshape 64" "56" - "\\ (T:sort 0) (U:sort 0) -> List (T * U)"; + "\\ (T:sort 0) (U:sort 0) -> List #(T, U)"; // BinTree type heapster_define_rust_type env @@ -234,7 +234,7 @@ heapster_assume_fun_rename env exchange_malloc_sym "exchange_malloc" heapster_assume_fun env "llvm.uadd.with.overflow.i64" "(). arg0:int64<>, arg1:int64<> -o ret:struct(int64<>,int1<>)" "\\ (x y:Vec 64 Bool) -> \ - \ returnM (Vec 64 Bool * Vec 1 Bool) (bvAdd 64 x y, single Bool (bvCarry 64 x y))"; + \ returnM #(Vec 64 Bool, Vec 1 Bool) (bvAdd 64 x y, single Bool (bvCarry 64 x y))"; // llvm.expect.i1 heapster_assume_fun env "llvm.expect.i1" @@ -248,7 +248,7 @@ heapster_assume_fun env "llvm.memcpy.p0i8.p0i8.i64" \ arg0:[l1]memblock(W,0,len,sh), arg1:[l2]memblock(rw,0,len,eqsh(len,b)), \ \ arg2:eq(llvmword(len)) -o \ \ arg0:[l1]memblock(W,0,len,eqsh(len,b)), arg1:[l2]memblock(rw,0,len,eqsh(len,b))" - "\\ (X:sort 0) (len:Vec 64 Bool) (x:X) (_:#()) -> returnM (#() * #()) ((),())"; + "\\ (X:sort 0) (len:Vec 64 Bool) (x:X) (_:#()) -> returnM #(#(), #()) ((),())"; // Box>::clone box_list20_u64_clone_sym <- heapster_find_symbol_with_type env diff --git a/heapster-saw/examples/rust_data.sawcore b/heapster-saw/examples/rust_data.sawcore index 9d39cde030..76d8a52b59 100644 --- a/heapster-saw/examples/rust_data.sawcore +++ b/heapster-saw/examples/rust_data.sawcore @@ -3,17 +3,17 @@ module rust_data where import Prelude; -unfoldListPermH : (a:sort 0) -> List a -> Either #() (#() * a * List a); +unfoldListPermH : (a:sort 0) -> List a -> Either #() #(#(), a, List a); unfoldListPermH a l = - List__rec a (\ (_:List a) -> Either #() (#() * a * List a)) - (Left #() (#() * a * List a) ()) - (\ (x:a) (l:List a) (_:Either #() (#() * a * List a)) -> - Right #() (#() * a * List a) ((), x, l)) + List__rec a (\ (_:List a) -> Either #() #(#(), a, List a)) + (Left #() #(#(), a, List a) ()) + (\ (x:a) (l:List a) (_:Either #() #(#(), a, List a)) -> + Right #() #(#(), a, List a) ((), x, l)) l; -foldListPermH : (a:sort 0) -> Either #() (#() * a * List a) -> List a; +foldListPermH : (a:sort 0) -> Either #() #(#(), a, List a) -> List a; foldListPermH a = - either #() (#() * a * List a) (List a) + either #() #(#(), a, List a) (List a) (\ (_ : #()) -> Nil a) - (\ (tup : (#() * a * List a)) -> - Cons a tup.(2).(1) tup.(2).(2)); + (\ (tup : #(#(), a, List a)) -> + Cons a tup.1 tup.2); diff --git a/heapster-saw/examples/rust_lifetimes.saw b/heapster-saw/examples/rust_lifetimes.saw index 572d372e05..0edd81c7c4 100644 --- a/heapster-saw/examples/rust_lifetimes.saw +++ b/heapster-saw/examples/rust_lifetimes.saw @@ -26,7 +26,7 @@ heapster_assume_fun env "llvm.uadd.with.overflow.i64" "(). arg0:int64<>, arg1:int64<> -o \ \ ret:struct(int64<>,int1<>)" "\\ (x y:Vec 64 Bool) -> \ - \ returnM (Vec 64 Bool * Vec 1 Bool) \ + \ returnM #(Vec 64 Bool, Vec 1 Bool) \ \ (bvAdd 64 x y, gen 1 Bool (\\ (_:Nat) -> bvCarry 64 x y))"; // llvm.expect.i1 diff --git a/heapster-saw/examples/rust_lifetimes.sawcore b/heapster-saw/examples/rust_lifetimes.sawcore index ea6aa76bf6..cee95af331 100644 --- a/heapster-saw/examples/rust_lifetimes.sawcore +++ b/heapster-saw/examples/rust_lifetimes.sawcore @@ -3,17 +3,17 @@ module rust_lifetimes where import Prelude; -unfoldListPermH : (a:sort 0) -> List a -> Either #() (#() * a * List a); +unfoldListPermH : (a:sort 0) -> List a -> Either #() #(#(), a, List a); unfoldListPermH a l = - List__rec a (\ (_:List a) -> Either #() (#() * a * List a)) - (Left #() (#() * a * List a) ()) - (\ (x:a) (l:List a) (_:Either #() (#() * a * List a)) -> - Right #() (#() * a * List a) ((), x, l)) + List__rec a (\ (_:List a) -> Either #() #(#(), a, List a)) + (Left #() #(#(), a, List a) ()) + (\ (x:a) (l:List a) (_:Either #() #(#(), a, List a)) -> + Right #() #(#(), a, List a) ((), x, l)) l; -foldListPermH : (a:sort 0) -> Either #() (#() * a * List a) -> List a; +foldListPermH : (a:sort 0) -> Either #() #(#(), a, List a) -> List a; foldListPermH a = - either #() (#() * a * List a) (List a) + either #() #(#(), a, List a) (List a) (\ (_ : #()) -> Nil a) - (\ (tup : (#() * a * List a)) -> - Cons a tup.(2).(1) tup.(2).(2)); + (\ (tup : #(#(), a, List a)) -> + Cons a tup.1 tup.2); diff --git a/heapster-saw/examples/sha512.saw b/heapster-saw/examples/sha512.saw index f5b470e4a5..15ec18f8a5 100644 --- a/heapster-saw/examples/sha512.saw +++ b/heapster-saw/examples/sha512.saw @@ -10,7 +10,7 @@ heapster_define_perm env "int8" " " "llvmptr 8" "exists x:bv 8.eq(llvmword(x))"; heapster_assume_fun env "CRYPTO_load_u64_be" "(). arg0:ptr((R,0) |-> int64<>) -o \ \ arg0:ptr((R,0) |-> int64<>), ret:int64<>" - "\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool * Vec 64 Bool) (x, x)"; + "\\ (x:Vec 64 Bool) -> returnM #(Vec 64 Bool, Vec 64 Bool) (x, x)"; heapster_typecheck_fun env "return_state" "(). arg0:array(W,0,<8,*8,fieldsh(int64<>)) -o \ diff --git a/heapster-saw/examples/string_set.sawcore b/heapster-saw/examples/string_set.sawcore index ebb12ebe76..e22d5c4520 100644 --- a/heapster-saw/examples/string_set.sawcore +++ b/heapster-saw/examples/string_set.sawcore @@ -8,10 +8,10 @@ listInsertM a l s = returnM (List a) (Cons a s l); listRemoveM : (a : sort 0) -> (a -> a -> Bool) -> List a -> a -> - CompM (List a * a); + CompM #(List a, a); listRemoveM a test_eq l s = returnM - (List a * a) + #(List a, a) (List__rec a (\ (_:List a) -> List a) (Nil a) @@ -30,10 +30,10 @@ stringListInsertM : List String -> String -> CompM (List String); stringListInsertM l s = returnM (List String) (Cons String s l); -stringListRemoveM : List String -> String -> CompM (stringList * String); +stringListRemoveM : List String -> String -> CompM #(stringList, String); stringListRemoveM l s = returnM - (stringList * String) + #(stringList, String) (List__rec String (\ (_:List String) -> List String) (Nil String) diff --git a/heapster-saw/src/Verifier/SAW/Heapster/SAWTranslation.hs b/heapster-saw/src/Verifier/SAW/Heapster/SAWTranslation.hs index 10eeaeabea..e50c8d1e41 100644 --- a/heapster-saw/src/Verifier/SAW/Heapster/SAWTranslation.hs +++ b/heapster-saw/src/Verifier/SAW/Heapster/SAWTranslation.hs @@ -133,32 +133,29 @@ typeTransType1 (TypeTrans [] _) = unitTypeOpenTerm typeTransType1 (TypeTrans [tp] _) = tp typeTransType1 _ = error ("typeTransType1" ++ nlPrettyCallStack callStack) --- | Build the tuple type @T1 * (T2 * ... * (Tn-1 * Tn))@ of @n@ types, with the --- special case that 0 types maps to the unit type @#()@ (and 1 type just maps --- to itself). Note that this is different from 'tupleTypeOpenTerm', which --- always ends with unit, i.e., which returns @T1*(T2*...*(Tn-1*(Tn*#())))@. +-- | Build the tuple type @#(T1, T2, ... Tn-1, Tn)@ of @n@ types, with +-- the special case that 0 types maps to the unit type @#()@ (and 1 +-- type just maps to itself). tupleOfTypes :: [OpenTerm] -> OpenTerm tupleOfTypes [] = unitTypeOpenTerm tupleOfTypes [tp] = tp -tupleOfTypes (tp:tps) = pairTypeOpenTerm tp $ tupleOfTypes tps +tupleOfTypes tps = tupleTypeOpenTerm tps --- | Build the tuple @(t1,(t2,(...,(tn-1,tn))))@ of @n@ terms, with the --- special case that 0 types maps to the unit value @()@ (and 1 value just maps --- to itself). Note that this is different from 'tupleOpenTerm', which --- always ends with unit, i.e., which returns @t1*(t2*...*(tn-1*(tn*())))@. +-- | Build the tuple @(t1, t2, ..., tn-1, tn)@ of @n@ terms, with the +-- special case that 0 types maps to the unit value @()@ (and 1 value +-- just maps to itself). tupleOfTerms :: [OpenTerm] -> OpenTerm tupleOfTerms [] = unitOpenTerm tupleOfTerms [t] = t -tupleOfTerms (t:ts) = pairOpenTerm t $ tupleOfTerms ts +tupleOfTerms ts = tupleOpenTerm ts -- | Project the @i@th element from a term of type @'tupleOfTypes' tps@. Note -- that this requires knowing the length of @tps@. projTupleOfTypes :: [OpenTerm] -> Integer -> OpenTerm -> OpenTerm -projTupleOfTypes [] _ _ = error "projTupleOfTypes: projection of empty tuple!" -projTupleOfTypes [_] 0 tup = tup -projTupleOfTypes (_:_) 0 tup = pairLeftOpenTerm tup -projTupleOfTypes (_:tps) i tup = - projTupleOfTypes tps (i-1) $ pairRightOpenTerm tup +projTupleOfTypes tps i tup + | i == 0 && length tps == 1 = tup + | i < toInteger (length tps) = projTupleOpenTerm i tup + | otherwise = error "projTupleOfTypes: invalid tuple index" -- | Map the 'typeTransTypes' field of a 'TypeTrans' to a single type, where a -- single type is mapped to itself, an empty list of types is mapped to @unit@, @@ -4585,14 +4582,15 @@ typedBlockLetRecEntries = . filter (anyF typedEntryHasMultiInDegree) . (^. typedBlockEntries)) --- | Fold a function over each 'TypedEntry' in a 'TypedBlockMap' that +-- | Map a monadic function over each 'TypedEntry' in a 'TypedBlockMap' that -- corresponds to a letrec-bound variable -foldBlockMapLetRec :: +mapBlockMapLetRec :: (forall args ghosts. - TypedEntry TransPhase ext blocks tops rets args ghosts -> b -> b) -> - b -> TypedBlockMap TransPhase ext blocks tops rets -> b -foldBlockMapLetRec f r = - foldr (\(SomeTypedEntry entry) -> f entry) r . typedBlockLetRecEntries + TypedEntry TransPhase ext blocks tops rets args ghosts -> + TypeTransM ctx b) -> + TypedBlockMap TransPhase ext blocks tops rets -> TypeTransM ctx [b] +mapBlockMapLetRec f = + mapM (\(SomeTypedEntry entry) -> f entry) . typedBlockLetRecEntries -- | Construct a @LetRecType@ inductive description -- @@ -4627,13 +4625,10 @@ translateEntryLRT entry@(TypedEntry {..}) = -- entrypoints in a 'TypedBlockMap' translateBlockMapLRTs :: TypedBlockMap TransPhase ext blocks tops rets -> TypeTransM ctx OpenTerm -translateBlockMapLRTs = - foldBlockMapLetRec - (\entry rest_m -> - do entryType <- translateEntryLRT entry - rest <- rest_m - return $ ctorOpenTerm "Prelude.LRT_Cons" [entryType, rest]) - (return $ ctorOpenTerm "Prelude.LRT_Nil" []) +translateBlockMapLRTs blk_map = + foldr (\entryType rest -> ctorOpenTerm "Prelude.LRT_Cons" [entryType, rest]) + (ctorOpenTerm "Prelude.LRT_Nil" []) <$> + mapBlockMapLetRec translateEntryLRT blk_map -- | Lambda-abstract over all the entrypoints in a 'TypedBlockMap' that -- correspond to letrec-bound functions, putting the lambda-bound variables into @@ -4689,11 +4684,8 @@ translateBlockMapBodies :: PermCheckExtC ext => TypedBlockMapTrans ext blocks tops rets -> TypedBlockMap TransPhase ext blocks tops rets -> TypeTransM ctx OpenTerm -translateBlockMapBodies mapTrans = - foldBlockMapLetRec - (\entry restM -> - pairOpenTerm <$> translateEntryBody mapTrans entry <*> restM) - (return unitOpenTerm) +translateBlockMapBodies mapTrans blk_map = + tupleOpenTerm <$> mapBlockMapLetRec (translateEntryBody mapTrans) blk_map -- | Translate a typed CFG to a SAW term diff --git a/saw-core-aig/src/Verifier/SAW/Simulator/BitBlast.hs b/saw-core-aig/src/Verifier/SAW/Simulator/BitBlast.hs index 7850d9d1ae..a92f9882fa 100644 --- a/saw-core-aig/src/Verifier/SAW/Simulator/BitBlast.hs +++ b/saw-core-aig/src/Verifier/SAW/Simulator/BitBlast.hs @@ -129,11 +129,8 @@ flattenBValue (VWord lv) = return lv flattenBValue (VExtra (BStream _ _)) = error "Verifier.SAW.Simulator.BitBlast.flattenBValue: BStream" flattenBValue (VVector vv) = AIG.concat <$> traverse (flattenBValue <=< force) (V.toList vv) -flattenBValue VUnit = return $ AIG.concat [] -flattenBValue (VPair x y) = do - vx <- flattenBValue =<< force x - vy <- flattenBValue =<< force y - return $ AIG.concat [vx, vy] +flattenBValue (VTuple xs) = + AIG.concat <$> mapM (flattenBValue <=< force) (V.toList xs) flattenBValue (VRecordValue elems) = do AIG.concat <$> mapM (flattenBValue <=< force . snd) elems flattenBValue _ = error $ unwords ["Verifier.SAW.Simulator.BitBlast.flattenBValue: unsupported value"] diff --git a/saw-core-coq/coq/handwritten/CryptolToCoq/SAWCoreScaffolding.v b/saw-core-coq/coq/handwritten/CryptolToCoq/SAWCoreScaffolding.v index 36503eafd7..827bc2bf62 100644 --- a/saw-core-coq/coq/handwritten/CryptolToCoq/SAWCoreScaffolding.v +++ b/saw-core-coq/coq/handwritten/CryptolToCoq/SAWCoreScaffolding.v @@ -42,6 +42,57 @@ Instance Inhabited_Intro (a:Type) (b:a -> Type) (Hb: forall x, Inhabited (b x)) Global Hint Extern 5 (Inhabited ?A) => (apply (@MkInhabited A); intuition (eauto with typeclass_instances inh)) : typeclass_instances. +Definition TypeList : Type := list Type. +Definition TypeNil : TypeList := nil. +Definition TypeCons : Type -> TypeList -> TypeList := cons. + +Definition TypeList__rec : + forall (p : TypeList -> Type) + (f1 : p TypeNil) (f2 : forall t ts, p ts -> p (TypeCons t ts)) + (ts : TypeList), p ts := @list_rect Type. + +(* The append of a list of tuple types to a type *) +Fixpoint TupleApp (a : Type) (ts : list Type) : Type := + match ts with + | nil => a + | cons t ts' => TupleApp (prod a t) ts' + end. + +(* A tuple of types, satisfying + Tuple [] = () + Tuple [a] = a + Tuple [a1, ..., an] = (... ((a1, a2), a3), ..., an) + which lines up with the Coq tuple notation *) +Definition Tuple (ts : list Type) : Type := + match ts with + | nil => unit + | cons t ts' => TupleApp t ts' + end. + +Fixpoint headTuple {a : Type} {ts : TypeList} : Tuple (cons a ts) -> a := + match ts with + | nil => @id a + | cons t ts' => fun tup => fst (@headTuple (a * t) ts' tup) + end. + +Fixpoint mapTupleApp {a b : Type} {ts : list Type} (f : a -> b) : TupleApp a ts -> TupleApp b ts := + match ts with + | nil => f + | cons t ts' => @mapTupleApp (a * t) (b * t) ts' (fun x => (f (fst x), snd x)) + end. + +Definition tailTuple {a : Type} {ts : list Type} : Tuple (cons a ts) -> Tuple ts := + match ts with + | nil => (fun _ : a => tt) + | cons t ts' => @mapTupleApp (a * t) t ts' snd + end. + +Definition consTuple {a : Type} {ts : list Type} (x : a) : Tuple ts -> Tuple (cons a ts) := + match ts with + | nil => (fun _ : unit => x) + | cons t ts' => @mapTupleApp t (a * t) ts' (pair x) + end. + Definition String := String.string. Instance Inhabited_String : Inhabited String := diff --git a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs index d8e61537de..05bfd47f59 100644 --- a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs +++ b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs @@ -213,6 +213,17 @@ sawCorePreludeSpecialTreatmentMap configuration = -- sawLet [ ("sawLet", mapsTo sawCoreScaffoldingModule "sawLet_def") ] + -- Tuples + ++ + [ ("TypeList", mapsTo sawCoreScaffoldingModule "TypeList") + , ("TypeNil", mapsTo sawCoreScaffoldingModule "TypeNil") + , ("TypeCons", mapsTo sawCoreScaffoldingModule "TypeCons") + , ("TypeList__rec", mapsTo sawCoreScaffoldingModule "TypeList__rec") + , ("Tuple", mapsTo sawCoreScaffoldingModule "Tuple") + , ("headTuple", mapsTo sawCoreScaffoldingModule "headTuple") + , ("tailTuple", mapsTo sawCoreScaffoldingModule "tailTuple") + , ("consTuple", mapsTo sawCoreScaffoldingModule "consTuple")] + -- Unsafe SAW features ++ [ ("error", mapsTo sawDefinitionsModule "error") diff --git a/saw-core-coq/src/Verifier/SAW/Translation/Coq/Term.hs b/saw-core-coq/src/Verifier/SAW/Translation/Coq/Term.hs index e39a3ceb96..42c5388905 100644 --- a/saw-core-coq/src/Verifier/SAW/Translation/Coq/Term.hs +++ b/saw-core-coq/src/Verifier/SAW/Translation/Coq/Term.hs @@ -218,6 +218,21 @@ translateIdentToIdent i = translateSort :: Sort -> Coq.Sort translateSort s = if s == propSort then Coq.Prop else Coq.Type +translateTuple :: [Coq.Term] -> Coq.Term +translateTuple [] = Coq.Var "tt" +translateTuple (x : xs) = Coq.App (Coq.Var "pair") [x, translateTuple xs] + +translateTupleType :: [Coq.Term] -> Coq.Term +translateTupleType [] = Coq.Ascription (Coq.Var "unit") (Coq.Sort Coq.Type) + -- We need to explicitly tell Coq that we want unit to be a Type, since + -- all SAW core sorts are translated to Types +translateTupleType (x : xs) = Coq.App (Coq.Var "prod") [x, translateTupleType xs] + +translateTupleSelector :: Int -> Coq.Term -> Coq.Term +translateTupleSelector i x + | i == 0 = Coq.App (Coq.Var "SAWCoreScaffolding.fst") [x] + | otherwise = translateTupleSelector (i - 1) (Coq.App (Coq.Var "SAWCoreScaffolding.snd") [x]) + flatTermFToExpr :: TermTranslationMonad m => FlatTermF Term -> @@ -225,17 +240,8 @@ flatTermFToExpr :: flatTermFToExpr tf = -- traceFTermF "flatTermFToExpr" tf $ case tf of Primitive pn -> translateIdent (primName pn) - UnitValue -> pure (Coq.Var "tt") - UnitType -> - -- We need to explicitly tell Coq that we want unit to be a Type, since - -- all SAW core sorts are translated to Types - pure (Coq.Ascription (Coq.Var "unit") (Coq.Sort Coq.Type)) - PairValue x y -> Coq.App (Coq.Var "pair") <$> traverse translateTerm [x, y] - PairType x y -> Coq.App (Coq.Var "prod") <$> traverse translateTerm [x, y] - PairLeft t -> - Coq.App <$> pure (Coq.Var "SAWCoreScaffolding.fst") <*> traverse translateTerm [t] - PairRight t -> - Coq.App <$> pure (Coq.Var "SAWCoreScaffolding.snd") <*> traverse translateTerm [t] + TupleValue xs -> translateTuple <$> traverse translateTerm (Vector.toList xs) + TupleSelector x i -> translateTupleSelector i <$> translateTerm x -- TODO: maybe have more customizable translation of data types DataTypeApp n is as -> translateIdentWithArgs (primName n) (is ++ as) CtorApp n is as -> translateIdentWithArgs (primName n) (is ++ as) @@ -644,10 +650,8 @@ defaultTermForType typ = do defaultT <- defaultTermForType typ' return $ Coq.App seqConst [ nT, typ'T, defaultT ] - (asPairType -> Just (x,y)) -> do - x' <- defaultTermForType x - y' <- defaultTermForType y - return $ Coq.App (Coq.Var "pair") [x',y'] + (asTupleType -> Just xs) -> + translateTuple <$> traverse defaultTermForType xs (asPiList -> (bs,body)) | not (null bs) diff --git a/saw-core-sbv/src/Verifier/SAW/Simulator/SBV.hs b/saw-core-sbv/src/Verifier/SAW/Simulator/SBV.hs index db209362ca..31700d0432 100644 --- a/saw-core-sbv/src/Verifier/SAW/Simulator/SBV.hs +++ b/saw-core-sbv/src/Verifier/SAW/Simulator/SBV.hs @@ -262,10 +262,8 @@ flattenSValue nm v = do Just w -> return ([w], "") Nothing -> case v of - VUnit -> return ([], "") - VPair x y -> do (xs, sx) <- flattenSValue nm =<< force x - (ys, sy) <- flattenSValue nm =<< force y - return (xs ++ ys, sx ++ sy) + VTuple (V.toList -> ts) -> do (xss, ss) <- unzip <$> traverse (force >=> flattenSValue nm) ts + pure (concat xss, concat ss) VRecordValue elems -> do (xss, sxs) <- unzip <$> mapM (flattenSValue nm <=< force . snd) elems @@ -642,13 +640,10 @@ parseUninterpreted cws nm ty = | i <- [0 .. n-1] ] return (VVector (V.fromList (map ready xs))) - VUnitType - -> return VUnit - - (VPairType ty1 ty2) - -> do x1 <- parseUninterpreted cws (nm ++ ".L") ty1 - x2 <- parseUninterpreted cws (nm ++ ".R") ty2 - return (VPair (ready x1) (ready x2)) + VTupleType tys + -> do let mkElem i ty' = parseUninterpreted cws (nm ++ "." ++ show i) ty' + xs <- V.imapM mkElem tys + pure (VTuple (fmap ready xs)) (VRecordType elem_tps) -> (VRecordValue <$> @@ -898,16 +893,11 @@ sbvSetOutput checkSz (FOTVec n t) (VVector xv) i = do Just ws -> do svCgOutputArr ("out_"++show i) ws return $! i+1 Nothing -> foldM (\i' x -> sbvSetOutput checkSz t x i') i xs -sbvSetOutput _checkSz (FOTTuple []) VUnit i = - return i -sbvSetOutput checkSz (FOTTuple [t]) v i = sbvSetOutput checkSz t v i -sbvSetOutput checkSz (FOTTuple (t:ts)) (VPair l r) i = do - l' <- liftIO $ force l - r' <- liftIO $ force r - sbvSetOutput checkSz t l' i >>= sbvSetOutput checkSz (FOTTuple ts) r' - -sbvSetOutput _checkSz (FOTRec fs) VUnit i | Map.null fs = do - return i +sbvSetOutput checkSz (FOTTuple ts) (VTuple xs) i = + do unless (length ts == V.length xs) $ + fail "sbvCodeGen: vector length mismatch when setting output values" + vs <- liftIO $ traverse force xs + foldM (\i' (t, v) -> sbvSetOutput checkSz t v i') i (zip ts (V.toList vs)) sbvSetOutput _checkSz (FOTRec fs) (VRecordValue []) i | Map.null fs = return i diff --git a/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs b/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs index f106645b9e..e9804cc008 100644 --- a/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs +++ b/saw-core-what4/src/Verifier/SAW/Simulator/What4.hs @@ -878,13 +878,12 @@ parseUninterpreted sym ref app ty = -> (VArray . SArray) <$> mkUninterpreted sym ref app (BaseArrayRepr (Ctx.Empty Ctx.:> idx_repr) elm_repr) - VUnitType - -> return VUnit - - VPairType ty1 ty2 - -> do x1 <- parseUninterpreted sym ref (suffixUnintApp "_L" app) ty1 - x2 <- parseUninterpreted sym ref (suffixUnintApp "_R" app) ty2 - return (VPair (ready x1) (ready x2)) + VTupleType tys + -> do let mkElem i ty' = + do let app' = suffixUnintApp ("_" ++ show i) app + parseUninterpreted sym ref app' ty' + xs <- V.imapM mkElem tys + pure (VTuple (fmap ready xs)) VRecordType elem_tps -> (VRecordValue <$> @@ -943,10 +942,7 @@ applyUnintApp :: IO (UnintApp (SymExpr sym)) applyUnintApp sym app0 v = case v of - VUnit -> return app0 - VPair x y -> do app1 <- applyUnintApp sym app0 =<< force x - app2 <- applyUnintApp sym app1 =<< force y - return app2 + VTuple xv -> foldM (applyUnintApp sym) app0 =<< traverse force xv VRecordValue elems -> foldM (applyUnintApp sym) app0 =<< traverse (force . snd) elems VVector xv -> foldM (applyUnintApp sym) app0 =<< traverse force xv VBool sb -> return (extendUnintApp app0 sb BaseBoolRepr) @@ -1029,14 +1025,8 @@ vAsFirstOrderType v = -> FOTVec n <$> vAsFirstOrderType v2 VArrayType iv ev -> FOTArray <$> vAsFirstOrderType iv <*> vAsFirstOrderType ev - VUnitType - -> return (FOTTuple []) - VPairType v1 v2 - -> do t1 <- vAsFirstOrderType v1 - t2 <- vAsFirstOrderType v2 - case t2 of - FOTTuple ts -> return (FOTTuple (t1 : ts)) - _ -> return (FOTTuple [t1, t2]) + VTupleType tvs + -> FOTTuple <$> traverse vAsFirstOrderType (V.toList tvs) VRecordType tps -> (FOTRec <$> Map.fromList <$> mapM (\(f,tp) -> (f,) <$> vAsFirstOrderType tp) tps) @@ -1354,15 +1344,13 @@ parseUninterpretedSAW sym st sc ref trm app ty = -> (VArray . SArray) <$> mkUninterpretedSAW sym st sc ref trm app (BaseArrayRepr (Ctx.Empty Ctx.:> idx_repr) elm_repr) - VUnitType - -> return VUnit - - VPairType ty1 ty2 - -> do let trm1 = ArgTermPairLeft trm - let trm2 = ArgTermPairRight trm - x1 <- parseUninterpretedSAW sym st sc ref trm1 (suffixUnintApp "_L" app) ty1 - x2 <- parseUninterpretedSAW sym st sc ref trm2 (suffixUnintApp "_R" app) ty2 - return (VPair (ready x1) (ready x2)) + VTupleType tys + -> do let mkElem i ty' = + do let trm' = ArgTermTupleProj trm i + let app' = suffixUnintApp ("_" ++ show i) app + parseUninterpretedSAW sym st sc ref trm' app' ty' + xs <- V.imapM mkElem tys + pure (VTuple (fmap ready xs)) _ -> fail $ "could not create uninterpreted symbol of type " ++ show ty @@ -1392,15 +1380,13 @@ data ArgTerm | ArgTermToIntMod Natural ArgTerm -- ^ toIntMod n x | ArgTermFromIntMod Natural ArgTerm -- ^ fromIntMod n x | ArgTermVector Term [ArgTerm] -- ^ element type, elements - | ArgTermUnit - | ArgTermPair ArgTerm ArgTerm + | ArgTermTuple [ArgTerm] | ArgTermRecord [(FieldName, ArgTerm)] | ArgTermConst Term | ArgTermApply ArgTerm ArgTerm | ArgTermAt Natural Term ArgTerm Natural -- ^ length, element type, list, index - | ArgTermPairLeft ArgTerm - | ArgTermPairRight ArgTerm + | ArgTermTupleProj ArgTerm Int | ArgTermBVToNat Natural ArgTerm -- | Reassemble a saw-core term from an 'ArgTerm' and a list of parts. @@ -1437,14 +1423,10 @@ reconstructArgTerm atrm sc ts = do (xs, ts1) <- parseList ats ts0 x <- scVectorReduced sc ty xs return (x, ts1) - ArgTermUnit -> - do x <- scUnitValue sc - return (x, ts0) - ArgTermPair at1 at2 -> - do (x1, ts1) <- parse at1 ts0 - (x2, ts2) <- parse at2 ts1 - x <- scPairValue sc x1 x2 - return (x, ts2) + ArgTermTuple ats -> + do (xs, ts1) <- parseList ats ts0 + x <- scTupleReduced sc xs + pure (x, ts1) ArgTermRecord flds -> do let (tags, ats) = unzip flds (xs, ts1) <- parseList ats ts0 @@ -1463,14 +1445,10 @@ reconstructArgTerm atrm sc ts = i' <- scNat sc i x <- scAt sc n' ty x1 i' return (x, ts1) - ArgTermPairLeft at1 -> - do (x1, ts1) <- parse at1 ts0 - x <- scPairLeft sc x1 - return (x, ts1) - ArgTermPairRight at1 -> + ArgTermTupleProj at1 i -> do (x1, ts1) <- parse at1 ts0 - x <- scPairRight sc x1 - return (x, ts1) + x <- scTupleSelector sc x1 i + pure (x, ts1) ArgTermBVToNat w at1 -> do (x1, ts1) <- parse at1 ts0 x <- scBvToNat sc w x1 @@ -1495,7 +1473,6 @@ mkArgTerm sc ty val = (_, VWord ZBV) -> return ArgTermBVZero -- 0-width bitvector is a constant (_, VWord (DBV _)) -> return ArgTermVar (_, VArray{}) -> return ArgTermVar - (VUnitType, VUnit) -> return ArgTermUnit (VIntModType n, VIntMod _ _) -> pure (ArgTermToIntMod n ArgTermVar) (VVecType _ ety, VVector vv) -> @@ -1504,10 +1481,10 @@ mkArgTerm sc ty val = ety' <- termOfTValue sc ety return (ArgTermVector ety' xs) - (VPairType ty1 ty2, VPair v1 v2) -> - do x1 <- mkArgTerm sc ty1 =<< force v1 - x2 <- mkArgTerm sc ty2 =<< force v2 - return (ArgTermPair x1 x2) + (VTupleType tys, VTuple ts) | V.length tys == V.length ts -> + do vs <- traverse force ts + xs <- sequence (V.zipWith (mkArgTerm sc) tys vs) + pure (ArgTermTuple (V.toList xs)) (VRecordType tys, VRecordValue flds) | map fst tys == map fst flds -> do let tags = map fst tys @@ -1542,15 +1519,13 @@ termOfTValue sc val = case val of VBoolType -> scBoolType sc VIntType -> scIntegerType sc - VUnitType -> scUnitType sc VVecType n a -> do n' <- scNat sc n a' <- termOfTValue sc a scVecType sc n' a' - VPairType a b - -> do a' <- termOfTValue sc a - b' <- termOfTValue sc b - scPairType sc a' b' + VTupleType vs -> + do vs' <- traverse (termOfTValue sc) vs + scTupleType sc (V.toList vs') VRecordType flds -> do flds' <- traverse (traverse (termOfTValue sc)) flds scRecordType sc flds' @@ -1559,8 +1534,10 @@ termOfTValue sc val = termOfSValue :: SharedContext -> SValue sym -> IO Term termOfSValue sc val = case val of - VUnit -> scUnitValue sc - VNat n - -> scNat sc n + VNat n -> scNat sc n + VTuple ts -> + do vs <- traverse force ts + vs' <- traverse (termOfSValue sc) vs + scTuple sc (V.toList vs') TValue tv -> termOfTValue sc tv _ -> fail $ "termOfSValue: " ++ show val diff --git a/saw-core/prelude/Prelude.sawcore b/saw-core/prelude/Prelude.sawcore index 8eeaaf033b..3d334a69c0 100644 --- a/saw-core/prelude/Prelude.sawcore +++ b/saw-core/prelude/Prelude.sawcore @@ -22,57 +22,43 @@ sawLet : (a b : sort 0) -> a -> (a -> b) -> b; sawLet _ _ x f = f x; --- FIXME: below are some defined data-types that could be used in place of --- the SAW primitive types - -------------------------------------------------------------------------------- --- The Unit type - -data UnitType : sort 0 where { - Unit : UnitType; - } +-- Tuple types --- The recursor for the Unit type at sort 0 --- UnitType__rec : (p : UnitType -> sort 0) -> p Unit -> (u : UnitType) -> p u; --- UnitType__rec p f1 u = UnitType#rec p f1 u; -UnitType__rec (p : UnitType -> sort 0) (f1 : p Unit) (u : UnitType) : p u - = UnitType#rec p f1 u; +-- TypeNil, TypeCons and Tuple are used to represent the #(a, b, c) +-- syntax for tuple types, so it is important that they be defined +-- before any uses of tuple types in this file. --------------------------------------------------------------------------------- --- Pair types - -data PairType (a b : sort 0) : sort 0 where { - PairValue : a -> b -> PairType a b; -} +data TypeList : sort 1 where { + TypeNil : TypeList; + TypeCons : sort 0 -> TypeList -> TypeList; + } -pair_example : (a b : sort 0) -> a -> b -> PairType a b; -pair_example a b x y = PairValue a b x y; +TypeList__rec + (p : TypeList -> sort 1) + (f1 : p TypeNil) + (f2 : (t : sort 0) -> (ts : TypeList) -> p ts -> p (TypeCons t ts)) + (ts : TypeList) + : p ts + = TypeList#rec p f1 f2 ts; --- The recursor for primitive pair types at sort 1 -Pair__rec - (a b : sort 0) - (p : PairType a b -> sort 0) - (f : (x:a) -> (y:b) -> p (PairValue a b x y)) - (pair : PairType a b) - : p pair - = PairType#rec a b p f pair; +primitive Tuple : TypeList -> sort 0; -Pair_fst : (a b : sort 0) -> PairType a b -> a; -Pair_fst a b = Pair__rec a b (\ (p:PairType a b) -> a) - (\ (x:a) -> \ (y: b) -> x); +primitive headTuple : (t : sort 0) -> (ts : TypeList) -> Tuple (TypeCons t ts) -> t; +primitive tailTuple : (t : sort 0) -> (ts : TypeList) -> Tuple (TypeCons t ts) -> Tuple ts; +primitive consTuple : (t : sort 0) -> (ts : TypeList) -> t -> Tuple ts -> Tuple (TypeCons t ts); -Pair_snd : (a b : sort 0) -> PairType a b -> b; -Pair_snd a b = Pair__rec a b (\ (p:PairType a b) -> b) - (\ (x:a) -> \ (y:b) -> y); +-------------------------------------------------------------------------------- +-- Pair types -fst : (a b : sort 0) -> a * b -> a; -fst a b tup = tup.(1); +fst : (a b : sort 0) -> #(a, b) -> a; +fst a b tup = tup.0; -snd : (a b : sort 0) -> a * b -> b; -snd a b tup = tup.(2); +snd : (a b : sort 0) -> #(a, b) -> b; +snd a b tup = tup.1; -uncurry (a b c : sort 0) (f : a -> b -> c) : a * b -> c - = (\ (x : a * b) -> f x.(1) x.(2)); +uncurry (a b c : sort 0) (f : a -> b -> c) : #(a, b) -> c + = (\ (x : #(a, b)) -> f x.0 x.1); -------------------------------------------------------------------------------- -- String values @@ -321,11 +307,19 @@ implies__eq a b = Refl Bool (implies a b); -unitEq : UnitType -> UnitType -> Bool; +unitEq : Tuple TypeNil -> Tuple TypeNil -> Bool; unitEq _ _ = True; -pairEq : (a b : sort 0) -> (a -> a -> Bool) -> (b -> b -> Bool) -> a * b -> a * b -> Bool; -pairEq a b f g x y = and ( f x.(1) y.(1) ) ( g x.(2) y.(2) ); +pairEq : + (t : sort 0) -> + (ts : TypeList) -> + (t -> t -> Bool) -> + (Tuple ts -> Tuple ts -> Bool) -> + Tuple (TypeCons t ts) -> Tuple (TypeCons t ts) -> Bool; +pairEq t ts f g x y = + and + (f (headTuple t ts x) (headTuple t ts y)) + (g (tailTuple t ts x) (tailTuple t ts y)); -- @@ -878,13 +872,13 @@ expNat b e = Nat_cases Nat 1 (\ (e':Nat) -> \ (exp_b_e:Nat) -> mulNat b exp_b_e) e; -- | Natural division and modulus -primitive divModNat : Nat -> Nat -> Nat * Nat; +primitive divModNat : Nat -> Nat -> #(Nat, Nat); divNat : Nat -> Nat -> Nat; -divNat x y = (divModNat x y).(1); +divNat x y = (divModNat x y).0; modNat : Nat -> Nat -> Nat; -modNat x y = (divModNat x y).(2); +modNat x y = (divModNat x y).1; -- There are implicit constructors from integer literals. @@ -972,7 +966,7 @@ single = replicate 1; axiom at_single : (a : sort 0) -> (x : a) -> (i : Nat) -> Eq a (at 1 a (single a x) i) x; -- Zip together two lists (truncating the longer of the two). -primitive zip : (a b : sort 0) -> (m n : Nat) -> Vec m a -> Vec n b -> Vec (minNat m n) (a * b); +primitive zip : (a b : sort 0) -> (m n : Nat) -> Vec m a -> Vec n b -> Vec (minNat m n) #(a, b); primitive foldr : (a b : sort 0) -> (n : Nat) -> (a -> b -> b) -> b -> Vec n a -> b; primitive foldl : (a b : sort 0) -> (n : Nat) -> (b -> a -> b) -> b -> Vec n a -> b; @@ -1149,7 +1143,7 @@ bvCarry n x y = bvult n (bvAdd n x y) x; bvSCarry : (n : Nat) -> Vec (Succ n) Bool -> Vec (Succ n) Bool -> Bool; bvSCarry n x y = and (boolEq (msb n x) (msb n y)) (xor (msb n x) (msb n (bvAdd (Succ n) x y))); -bvAddWithCarry : (n : Nat) -> Vec n Bool -> Vec n Bool -> Bool * Vec n Bool; +bvAddWithCarry : (n : Nat) -> Vec n Bool -> Vec n Bool -> #(Bool, Vec n Bool); bvAddWithCarry n x y = (bvCarry n x y, bvAdd n x y); axiom bvAddZeroL : (n : Nat) -> (x : Vec n Bool) -> Eq (Vec n Bool) (bvAdd n (bvNat n 0) x) x; @@ -1497,20 +1491,20 @@ List__rec : (l : List a) -> P l; List__rec a P f1 f2 l = List#rec a P f1 f2 l; -unfoldList : (a:sort 0) -> List a -> Either #() (a * List a); +unfoldList : (a:sort 0) -> List a -> Either #() #(a, List a); unfoldList a l = - List__rec a (\ (_:List a) -> Either #() (a * List a)) - (Left #() (a * List a) ()) - (\ (x:a) (l:List a) (_:Either #() (a * List a)) -> - Right #() (a * List a) (x, l)) + List__rec a (\ (_:List a) -> Either #() #(a, List a)) + (Left #() #(a, List a) ()) + (\ (x:a) (l:List a) (_:Either #() #(a, List a)) -> + Right #() #(a, List a) (x, l)) l; -foldList : (a:sort 0) -> Either #() (a * List a) -> List a; +foldList : (a:sort 0) -> Either #() #(a, List a) -> List a; foldList a = - either #() (a * List a) (List a) + either #() #(a, List a) (List a) (\ (_ : #()) -> Nil a) - (\ (tup : (a * List a)) -> - Cons a tup.(1) tup.(2)); + (\ (tup : #(a, List a)) -> + Cons a tup.0 tup.1); -- A list of types, i.e. `List (sort 0)` if `List` was universe polymorphic data ListSort : sort 1 @@ -1550,29 +1544,27 @@ data W64List : sort 0 where { unfoldedW64List : sort 0; unfoldedW64List = - Either #() - (Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()) * W64List * #()); + Either #() #(Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()), W64List); unfoldW64List : W64List -> unfoldedW64List; unfoldW64List l = W64List#rec (\ (_:W64List) -> unfoldedW64List) - (Left #() (Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()) * W64List * #()) ()) + (Left #() #(Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()), W64List) ()) (\ (bv:Vec 64 Bool) (l':W64List) (_:unfoldedW64List) -> - Right #() (Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()) * W64List * #()) + Right #() #(Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()), W64List) (exists (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()) bv (), - l', ())) + l')) l; foldW64List : unfoldedW64List -> W64List; foldW64List = - either #() (Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()) * W64List * #()) + either #() #(Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()), W64List) W64List (\ (_:#()) -> W64Nil) - (\ (bv_l:(Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()) - * W64List * #())) -> + (\ (bv_l : #(Sigma (Vec 64 Bool) (\ (_:Vec 64 Bool) -> #()), W64List)) -> W64Cons (Sigma_proj1 (Vec 64 Bool) - (\ (_:Vec 64 Bool) -> #()) bv_l.(1)) - bv_l.(2).(1)); + (\ (_:Vec 64 Bool) -> #()) bv_l.0) + bv_l.1); -------------------------------------------------------------------------------- @@ -1962,7 +1954,7 @@ UnfoldedIRT As Ds D = IRTDesc__rec As (\ (_:IRTDesc As) -> IRTSubsts As -> sort Either (recl Ds) (recr Ds)) (\ (_:IRTDesc As) (recl : IRTSubsts As -> sort 0) (_:IRTDesc As) (recr : IRTSubsts As -> sort 0) (Ds:IRTSubsts As) -> - recl Ds * recr Ds) + #(recl Ds, recr Ds)) (\ (i:Nat) (_ : listSortGet As i -> IRTDesc As) (recf : listSortGet As i -> IRTSubsts As -> sort 0) (Ds:IRTSubsts As) -> Sigma (listSortGet As i) (\ (a:listSortGet As i) -> recf a Ds)) @@ -2018,7 +2010,7 @@ foldIRT As Ds D = IRTDesc__rec As (\ (D:IRTDesc As) -> (Ds:IRTSubsts As) -> Unfo (\ (xr:UnfoldedIRT As Ds Dr) -> IRT_Right As Ds Dl Dr (recr Ds xr)) x) (\ (Dl:IRTDesc As) (recl : (Ds:IRTSubsts As) -> UnfoldedIRT As Ds Dl -> IRT As Ds Dl) (Dr:IRTDesc As) (recr : (Ds:IRTSubsts As) -> UnfoldedIRT As Ds Dr -> IRT As Ds Dr) - (Ds:IRTSubsts As) (x:UnfoldedIRT As Ds Dl * UnfoldedIRT As Ds Dr) -> + (Ds:IRTSubsts As) (x : #(UnfoldedIRT As Ds Dl, UnfoldedIRT As Ds Dr)) -> uncurry (UnfoldedIRT As Ds Dl) (UnfoldedIRT As Ds Dr) (IRT As Ds (IRT_prod As Dl Dr)) (\ (xl:UnfoldedIRT As Ds Dl) (xr:UnfoldedIRT As Ds Dr) -> IRT_pair As Ds Dl Dr (recl Ds xl) (recr Ds xr)) x) @@ -2093,15 +2085,24 @@ composeM : (a b c: sort 0) -> (a -> CompM b) -> (b -> CompM c) -> a -> CompM c; composeM a b c f g x = bindM b c (f x) g; -- Tuple a type onto the input and output types of a monadic function -tupleCompMFunBoth : (a b c: sort 0) -> (a -> CompM b) -> (c * a -> CompM (c * b)); -tupleCompMFunBoth a b c f = - \ (x:c * a) -> - bindM b (c * b) (f x.(2)) (\ (y:b) -> returnM (c*b) (x.(1), y)); - --- Tuple a valu onto the output of a monadic function -tupleCompMFunOut : (a b c: sort 0) -> c -> (a -> CompM b) -> (a -> CompM (c * b)); -tupleCompMFunOut a b c x f = - \ (y:a) -> bindM b (c*b) (f y) (\ (z:b) -> returnM (c*b) (x,z)); +tupleCompMFunBoth : + (a b : TypeList) -> + (c : sort 0) -> + (Tuple a -> CompM (Tuple b)) -> + Tuple (TypeCons c a) -> CompM (Tuple (TypeCons c b)); +tupleCompMFunBoth a b c f x = + bindM (Tuple b) (Tuple (TypeCons c b)) (f (tailTuple c a x)) + (\ (y : Tuple b) -> returnM (Tuple (TypeCons c b)) (consTuple c b (headTuple c a x) y)); + +-- Tuple a value onto the output of a monadic function +tupleCompMFunOut : + (a : sort 0) -> + (b : TypeList) -> + (c : sort 0) -> + c -> (a -> CompM (Tuple b)) -> (a -> CompM (Tuple (TypeCons c b))); +tupleCompMFunOut a b c x f y = + bindM (Tuple b) (Tuple (TypeCons c b)) (f y) + (\ (z : Tuple b) -> returnM (Tuple (TypeCons c b)) (consTuple c b x z)); -- Map a monadic function across a vector mapM : (a :sort 0) -> (b : isort 0) -> (a -> CompM b) -> (n : Nat) -> Vec n a -> CompM (Vec n b); @@ -2266,16 +2267,21 @@ lrtPi lrts b = (\ (lrt:LetRecType) (_:LetRecTypes) (rest:sort 0) -> lrtToType lrt -> rest) lrts; --- Build the product type (lrtToType lrt1, ..., lrtToType lrtn) from the +-- Build the type list [lrtToType lrt1, ..., lrtToType lrtn] from the -- LetRecTypes list [lrt1, ..., lrtn] -lrtTupleType : LetRecTypes -> sort 0; -lrtTupleType lrts = +lrtTypeList : LetRecTypes -> TypeList; +lrtTypeList lrts = LetRecTypes#rec - (\ (lrts:LetRecTypes) -> sort 0) - #() - (\ (lrt:LetRecType) (_:LetRecTypes) (rest:sort 0) -> #(lrtToType lrt, rest)) + (\ (lrts:LetRecTypes) -> TypeList) + TypeNil + (\ (lrt:LetRecType) (_:LetRecTypes) (rest:TypeList) -> TypeCons (lrtToType lrt) rest) lrts; +-- Build the product type (lrtToType lrt1, ..., lrtToType lrtn) from the +-- LetRecTypes list [lrt1, ..., lrtn] +lrtTupleType : LetRecTypes -> sort 0; +lrtTupleType lrts = Tuple (lrtTypeList lrts); + -- NOTE: the following are needed to define letRecM instead of making it a -- primitive, which we are keeping commented here in case that is needed {- @@ -2357,7 +2363,7 @@ letRecM1 : (a b c : sort 0) -> ((a -> CompM b) -> (a -> CompM b)) -> letRecM1 a b c fn body = letRecM (LRT_Cons (LRT_Fun a (\ (_:a) -> LRT_Ret b)) LRT_Nil) c - (\ (f:a -> CompM b) -> (fn f, ())) + (\ (f:a -> CompM b) -> consTuple (a -> CompM b) TypeNil (fn f) ()) (\ (f:a -> CompM b) -> body f); -- A single-argument fixed-point function @@ -2367,7 +2373,7 @@ fixM : (a:sort 0) -> (b:a -> sort 0) -> fixM a b f x = letRecM (LRT_Cons (LRT_Fun a (\ (y:a) -> LRT_Ret (b y))) LRT_Nil) (b x) - (\ (g: (y:a) -> CompM (b y)) -> (f g, ())) + (\ (g: (y:a) -> CompM (b y)) -> consTuple ((y:a) -> CompM (b y)) TypeNil (f g) ()) (\ (g: (y:a) -> CompM (b y)) -> g x); @@ -2462,7 +2468,10 @@ multiFixM : (lrts:LetRecTypes) -> lrtPi lrts (lrtTupleType lrts) -> multiArgFixM : (lrt:LetRecType) -> (lrtToType lrt -> lrtToType lrt) -> lrtToType lrt; multiArgFixM lrt F = - (multiFixM (LRT_Cons lrt LRT_Nil) (\ (f:lrtToType lrt) -> (F f, ()))).(1); + (multiFixM + (LRT_Cons lrt LRT_Nil) + (\ (f:lrtToType lrt) -> consTuple (lrtToType lrt) TypeNil (F f) ()) + ).0; -- Test computations @@ -2524,7 +2533,7 @@ test_fun6 x = (Vec 64 Bool) (\ (f1:(Vec 64 Bool -> CompM (Vec 64 Bool))) (f2:(Vec 64 Bool -> CompM (Vec 64 Bool))) -> - (f2, (f1, ()))) + (f2, f1)) (\ (f1:(Vec 64 Bool -> CompM (Vec 64 Bool))) (f2:(Vec 64 Bool -> CompM (Vec 64 Bool))) -> f1 x); diff --git a/saw-core/src/Verifier/SAW/Conversion.hs b/saw-core/src/Verifier/SAW/Conversion.hs index 2d6e2615d4..11f1ac3641 100644 --- a/saw-core/src/Verifier/SAW/Conversion.hs +++ b/saw-core/src/Verifier/SAW/Conversion.hs @@ -391,14 +391,13 @@ pureApp mx y = do mkTermF (App x y) mkTuple :: [TermBuilder Term] -> TermBuilder Term -mkTuple [] = mkTermF (FTermF UnitValue) -mkTuple (t : ts) = mkTermF . FTermF =<< (PairValue <$> t <*> mkTuple ts) +mkTuple ts = mkTermF . FTermF . TupleValue . V.fromList =<< sequence ts +-- | Zero-indexed tuple field selection. mkTupleSelector :: Int -> Term -> TermBuilder Term mkTupleSelector i t - | i == 1 = mkTermF (FTermF (PairLeft t)) - | i > 1 = mkTermF (FTermF (PairRight t)) >>= mkTupleSelector (i - 1) - | otherwise = panic "Verifier.SAW.Conversion.mkTupleSelector" ["non-positive index:", show i] + | i < 0 = panic "Verifier.SAW.Conversion.mkTupleSelector" ["non-positive index:", show i] + | otherwise = mkTermF (FTermF (TupleSelector t i)) mkCtor :: PrimName Term -> [TermBuilder Term] -> [TermBuilder Term] -> TermBuilder Term mkCtor i paramsB argsB = diff --git a/saw-core/src/Verifier/SAW/ExternalFormat.hs b/saw-core/src/Verifier/SAW/ExternalFormat.hs index 851fbf2b57..63ea7734d9 100644 --- a/saw-core/src/Verifier/SAW/ExternalFormat.hs +++ b/saw-core/src/Verifier/SAW/ExternalFormat.hs @@ -145,12 +145,8 @@ scWriteExternal t0 = Primitive ec -> do stashPrimName ec pure $ unwords ["Primitive", show (primVarIndex ec), show (primType ec)] - UnitValue -> pure $ unwords ["Unit"] - UnitType -> pure $ unwords ["UnitT"] - PairValue x y -> pure $ unwords ["Pair", show x, show y] - PairType x y -> pure $ unwords ["PairT", show x, show y] - PairLeft e -> pure $ unwords ["ProjL", show e] - PairRight e -> pure $ unwords ["ProjR", show e] + TupleValue xs -> pure $ unwords ("Tuple" : map show (V.toList xs)) + TupleSelector x i -> pure $ unwords ["TupleSelector", show x, show i] CtorApp i ps es -> do stashPrimName i pure $ unwords ("Ctor" : show (primVarIndex i) : show (primType i) : @@ -302,12 +298,9 @@ scReadExternal sc input = ["Constant",i,t,e] -> Constant <$> readEC i t <*> (Just <$> readIdx e) ["ConstantOpaque",i,t] -> Constant <$> readEC i t <*> pure Nothing ["Primitive", i, t] -> FTermF <$> (Primitive <$> readPrimName i t) - ["Unit"] -> pure $ FTermF UnitValue - ["UnitT"] -> pure $ FTermF UnitType - ["Pair", x, y] -> FTermF <$> (PairValue <$> readIdx x <*> readIdx y) - ["PairT", x, y] -> FTermF <$> (PairType <$> readIdx x <*> readIdx y) - ["ProjL", x] -> FTermF <$> (PairLeft <$> readIdx x) - ["ProjR", x] -> FTermF <$> (PairRight <$> readIdx x) + ("Tuple" : xs) -> FTermF <$> (TupleValue <$> (V.fromList <$> traverse readIdx xs)) + ["TupleSelector", x, i] + -> FTermF <$> (TupleSelector <$> readIdx x <*> pure (read i)) ("Ctor" : i : t : (separateArgs -> Just (ps, es))) -> FTermF <$> (CtorApp <$> readPrimName i t <*> traverse readIdx ps <*> traverse readIdx es) ("Data" : i : t : (separateArgs -> Just (ps, es))) -> diff --git a/saw-core/src/Verifier/SAW/Grammar.y b/saw-core/src/Verifier/SAW/Grammar.y index a37fee8d01..1eb5a34c36 100644 --- a/saw-core/src/Verifier/SAW/Grammar.y +++ b/saw-core/src/Verifier/SAW/Grammar.y @@ -157,18 +157,12 @@ Term : LTerm { $1 } -- Term with uses of pi and lambda, but no type ascriptions LTerm :: { Term } -LTerm : ProdTerm { $1 } +LTerm : AppTerm { $1 } | PiArg '->' LTerm { Pi (pos $2) $1 $3 } | '\\' VarCtx '->' LTerm { Lambda (pos $1) $2 $4 } PiArg :: { [(TermVar, Term)] } -PiArg : ProdTerm { mkPiArg $1 } - --- Term formed from infix product type operator (right-associative) -ProdTerm :: { Term } -ProdTerm - : AppTerm { $1 } - | AppTerm '*' ProdTerm { PairType (pos $1) $1 $3 } +PiArg : AppTerm { mkPiArg $1 } -- Term formed from applications of atomic expressions AppTerm :: { Term } @@ -187,13 +181,12 @@ AtomTerm | 'isort' nat { Sort (pos $1) (mkSort (tokNat (val $2))) True } | AtomTerm '.' Ident { RecordProj $1 (val $3) } | AtomTerm '.' IdentRec {% parseRecursorProj $1 $3 } - | AtomTerm '.' nat {% parseTupleSelector $1 (fmap tokNat $3) } + | AtomTerm '.' nat { mkTupleSelector $1 (tokNat (val $3)) } | '(' sepBy(Term, ',') ')' { mkTupleValue (pos $1) $2 } | '#' '(' sepBy(Term, ',') ')' { mkTupleType (pos $1) $3 } | '[' sepBy(Term, ',') ']' { VecLit (pos $1) $2 } | '{' sepBy(FieldValue, ',') '}' { RecordValue (pos $1) $2 } | '#' '{' sepBy(FieldType, ',') '}' { RecordType (pos $1) $3 } - | AtomTerm '.' '(' nat ')' {% mkTupleProj $1 (tokNat (val $4)) } Ident :: { PosPair Text } Ident : ident { fmap (Text.pack . tokIdent) $1 } @@ -336,14 +329,6 @@ mkPiArg (TypeConstraint (exprAsIdentList -> Just xs) _ t) = map (\x -> (x, t)) xs mkPiArg lhs = [(UnusedVar (pos lhs), lhs)] --- | Parse a tuple projection of the form @t.(1)@ or @t.(2)@ -mkTupleProj :: Term -> Natural -> Parser Term -mkTupleProj t 1 = return $ PairLeft t -mkTupleProj t 2 = return $ PairRight t -mkTupleProj t _ = - do addParseError (pos t) "Projections must be either .(1) or .(2)" - return (badTerm (pos t)) - -- | Parse a term as a dotted list of strings parseModuleName :: Term -> Maybe [Text] parseModuleName (RecordProj t fname) = (++ [fname]) <$> parseModuleName t @@ -357,12 +342,6 @@ parseRecursorProj t _ = do addParseError (pos t) "Malformed recursor projection" return (badTerm (pos t)) -parseTupleSelector :: Term -> PosPair Natural -> Parser Term -parseTupleSelector t i = - if val i >= 1 then return (mkTupleSelector t (val i)) else - do addParseError (pos t) "non-positive tuple projection index" - return (badTerm (pos t)) - -- | Create a module name given a list of strings with the top-most -- module name given first. mkPosModuleName :: [PosPair Text] -> PosPair ModuleName diff --git a/saw-core/src/Verifier/SAW/OpenTerm.hs b/saw-core/src/Verifier/SAW/OpenTerm.hs index 57c1fd7ad0..dbf8dfbd5b 100644 --- a/saw-core/src/Verifier/SAW/OpenTerm.hs +++ b/saw-core/src/Verifier/SAW/OpenTerm.hs @@ -29,8 +29,7 @@ module Verifier.SAW.OpenTerm ( trueOpenTerm, falseOpenTerm, boolOpenTerm, boolTypeOpenTerm, arrayValueOpenTerm, vectorTypeOpenTerm, bvLitOpenTerm, bvTypeOpenTerm, pairOpenTerm, pairTypeOpenTerm, pairLeftOpenTerm, pairRightOpenTerm, - tupleOpenTerm, tupleTypeOpenTerm, projTupleOpenTerm, - tupleOpenTerm', tupleTypeOpenTerm', + tupleOpenTerm, tupleTypeOpenTerm, typeListOpenTerm, projTupleOpenTerm, recordOpenTerm, recordTypeOpenTerm, projRecordOpenTerm, ctorOpenTerm, dataTypeOpenTerm, globalOpenTerm, extCnsOpenTerm, applyOpenTerm, applyOpenTermMulti, applyGlobalOpenTerm, @@ -139,11 +138,11 @@ natOpenTerm = flatOpenTerm . NatLit -- | The 'OpenTerm' for the unit value unitOpenTerm :: OpenTerm -unitOpenTerm = flatOpenTerm UnitValue +unitOpenTerm = tupleOpenTerm [] -- | The 'OpenTerm' for the unit type unitTypeOpenTerm :: OpenTerm -unitTypeOpenTerm = flatOpenTerm UnitType +unitTypeOpenTerm = tupleTypeOpenTerm [] -- | Build a SAW core string literal. stringLitOpenTerm :: Text -> OpenTerm @@ -192,43 +191,37 @@ bvTypeOpenTerm n = -- | Build an 'OpenTerm' for a pair pairOpenTerm :: OpenTerm -> OpenTerm -> OpenTerm -pairOpenTerm t1 t2 = flatOpenTerm $ PairValue t1 t2 +pairOpenTerm t1 t2 = tupleOpenTerm [t1, t2] -- | Build an 'OpenTerm' for a pair type pairTypeOpenTerm :: OpenTerm -> OpenTerm -> OpenTerm -pairTypeOpenTerm t1 t2 = flatOpenTerm $ PairType t1 t2 +pairTypeOpenTerm t1 t2 = tupleTypeOpenTerm [t1, t2] -- | Build an 'OpenTerm' for the left projection of a pair pairLeftOpenTerm :: OpenTerm -> OpenTerm -pairLeftOpenTerm t = flatOpenTerm $ PairLeft t +pairLeftOpenTerm t = projTupleOpenTerm 0 t -- | Build an 'OpenTerm' for the right projection of a pair pairRightOpenTerm :: OpenTerm -> OpenTerm -pairRightOpenTerm t = flatOpenTerm $ PairRight t +pairRightOpenTerm t = projTupleOpenTerm 1 t --- | Build a right-nested tuple as an 'OpenTerm' +-- | Build a tuple as an 'OpenTerm' tupleOpenTerm :: [OpenTerm] -> OpenTerm -tupleOpenTerm = foldr pairOpenTerm unitOpenTerm +tupleOpenTerm ts = flatOpenTerm $ TupleValue (V.fromList ts) --- | Build a right-nested tuple type as an 'OpenTerm' +-- | Build a tuple type as an 'OpenTerm' tupleTypeOpenTerm :: [OpenTerm] -> OpenTerm -tupleTypeOpenTerm = foldr pairTypeOpenTerm unitTypeOpenTerm +tupleTypeOpenTerm ts = applyGlobalOpenTerm "Prelude.Tuple" [typeListOpenTerm ts] --- | Project the @n@th element of a right-nested tuple type +typeListOpenTerm :: [OpenTerm] -> OpenTerm +typeListOpenTerm [] = + ctorOpenTerm "Prelude.TypeNil" [] +typeListOpenTerm (t : ts) = + ctorOpenTerm "Prelude.TypeCons" [t, typeListOpenTerm ts] + +-- | Project the @n@th element of a tuple type projTupleOpenTerm :: Integer -> OpenTerm -> OpenTerm -projTupleOpenTerm 0 t = pairLeftOpenTerm t -projTupleOpenTerm i t = projTupleOpenTerm (i-1) (pairRightOpenTerm t) - --- | Build a right-nested tuple as an 'OpenTerm' but without adding a final unit --- as the right-most element -tupleOpenTerm' :: [OpenTerm] -> OpenTerm -tupleOpenTerm' [] = unitOpenTerm -tupleOpenTerm' ts = foldr1 pairTypeOpenTerm ts - --- | Build a right-nested tuple type as an 'OpenTerm' -tupleTypeOpenTerm' :: [OpenTerm] -> OpenTerm -tupleTypeOpenTerm' [] = unitTypeOpenTerm -tupleTypeOpenTerm' ts = foldr1 pairTypeOpenTerm ts +projTupleOpenTerm i t = flatOpenTerm $ TupleSelector t (fromInteger i) -- FIXME: unchecked fromInteger -- | Build a record value as an 'OpenTerm' recordOpenTerm :: [(FieldName, OpenTerm)] -> OpenTerm diff --git a/saw-core/src/Verifier/SAW/Prelude/Constants.hs b/saw-core/src/Verifier/SAW/Prelude/Constants.hs index dd72f875bf..75695d212f 100644 --- a/saw-core/src/Verifier/SAW/Prelude/Constants.hs +++ b/saw-core/src/Verifier/SAW/Prelude/Constants.hs @@ -25,6 +25,12 @@ preludeZeroIdent = mkIdent preludeModuleName "Zero" preludeSuccIdent :: Ident preludeSuccIdent = mkIdent preludeModuleName "Succ" +preludeTypeNilIdent :: Ident +preludeTypeNilIdent = mkIdent preludeModuleName "TypeNil" + +preludeTypeConsIdent :: Ident +preludeTypeConsIdent = mkIdent preludeModuleName "TypeCons" + preludeIntegerIdent :: Ident preludeIntegerIdent = mkIdent preludeModuleName "Integer" diff --git a/saw-core/src/Verifier/SAW/Recognizer.hs b/saw-core/src/Verifier/SAW/Recognizer.hs index ba3d81ead7..25b80ca23e 100644 --- a/saw-core/src/Verifier/SAW/Recognizer.hs +++ b/saw-core/src/Verifier/SAW/Recognizer.hs @@ -27,9 +27,6 @@ module Verifier.SAW.Recognizer , asApp , (<@>), (@>), (<@) , asApplyAll - , asPairType - , asPairValue - , asPairSelector , asTupleType , asTupleValue , asTupleSelector @@ -76,6 +73,7 @@ import Control.Monad import Data.Map (Map) import qualified Data.Map as Map import Data.Text (Text) +import qualified Data.Vector as V import Numeric.Natural (Natural) import Verifier.SAW.Term.Functor @@ -161,63 +159,29 @@ asApplyAll = go [] Nothing -> (t, xs) Just (t', x) -> go (x : xs) t' -asPairType :: Recognizer Term (Term, Term) -asPairType t = do - ftf <- asFTermF t - case ftf of - PairType x y -> return (x, y) - _ -> Nothing - -asPairValue :: Recognizer Term (Term, Term) -asPairValue t = do - ftf <- asFTermF t - case ftf of - PairValue x y -> return (x, y) - _ -> Nothing - -asPairSelector :: Recognizer Term (Term, Bool) -asPairSelector t = do - ftf <- asFTermF t - case ftf of - PairLeft x -> return (x, False) - PairRight x -> return (x, True) - _ -> Nothing - -destTupleType :: Term -> [Term] -destTupleType t = - case unwrapTermF t of - FTermF (PairType x y) -> x : destTupleType y - _ -> [t] - -destTupleValue :: Term -> [Term] -destTupleValue t = - case unwrapTermF t of - FTermF (PairValue x y) -> x : destTupleType y - _ -> [t] +asTypeList :: Recognizer Term [Term] +asTypeList (asCtor -> Just (c, [])) + | primName c == preludeTypeNilIdent = Just [] +asTypeList (asCtor -> Just (c, [t, asTypeList -> Just ts])) + | primName c == preludeTypeConsIdent = Just (t : ts) +asTypeList _ = Nothing asTupleType :: Recognizer Term [Term] -asTupleType t = - do ftf <- asFTermF t - case ftf of - UnitType -> Just [] - PairType x y -> Just (x : destTupleType y) - _ -> Nothing +asTupleType = isGlobalDef "Prelude.Tuple" @> asTypeList asTupleValue :: Recognizer Term [Term] asTupleValue t = do ftf <- asFTermF t case ftf of - UnitValue -> Just [] - PairValue x y -> Just (x : destTupleValue y) + TupleValue xs -> Just (V.toList xs) _ -> Nothing asTupleSelector :: Recognizer Term (Term, Int) -asTupleSelector t = do - ftf <- asFTermF t - case ftf of - PairLeft x -> return (x, 1) - PairRight y -> do (x, i) <- asTupleSelector y; return (x, i+1) - _ -> Nothing +asTupleSelector t = + do ftf <- asFTermF t + case ftf of + TupleSelector x i -> Just (x, i) + _ -> Nothing asRecordType :: Recognizer Term (Map FieldName Term) asRecordType t = do diff --git a/saw-core/src/Verifier/SAW/Rewriter.hs b/saw-core/src/Verifier/SAW/Rewriter.hs index 4d267f3435..d338d7e4f7 100644 --- a/saw-core/src/Verifier/SAW/Rewriter.hs +++ b/saw-core/src/Verifier/SAW/Rewriter.hs @@ -71,6 +71,7 @@ import qualified Data.List as List import qualified Data.Map as Map import Data.Set (Set) import qualified Data.Set as Set +import qualified Data.Vector as V import Control.Monad.Trans.Writer.Strict import Numeric.Natural @@ -537,9 +538,9 @@ asBetaRedex t = asPairRedex :: R.Recognizer Term Term asPairRedex t = - do (u, b) <- R.asPairSelector t - (x, y) <- R.asPairValue u - return (if b then y else x) + do (u, i) <- R.asTupleSelector t + ts <- R.asTupleValue u + return (ts !! i) asRecordRedex :: R.Recognizer Term Term asRecordRedex t = @@ -601,7 +602,7 @@ appCollectedArgs t = step0 (unshared t) [] step1 f args = foldl (++) [] (map (\ x -> step2 f $ unshared x) args) -- step2: analyse an arg. look inside tuples, sequences (TBD), more calls to f step2 :: TermF Term -> TermF Term -> [Term] - step2 f (FTermF (PairValue x y)) = (step2 f $ unshared x) ++ (step2 f $ unshared y) + step2 f (FTermF (TupleValue xs)) = concatMap (step2 f . unshared) (V.toList xs) step2 f (s@(App g a)) = possibly_curried_args s f (unshared g) (step2 f $ unshared a) step2 _ a = [Unshared a] -- @@ -736,12 +737,8 @@ rewriteSharedTermTypeSafe sc ss t0 = FlatTermF Term -> IO (FlatTermF Term) rewriteFTermF ftf = case ftf of - UnitValue -> return ftf - UnitType -> return ftf - PairValue{} -> traverse rewriteAll ftf - PairType{} -> return ftf -- doesn't matter - PairLeft{} -> traverse rewriteAll ftf - PairRight{} -> traverse rewriteAll ftf + TupleValue{} -> traverse rewriteAll ftf + TupleSelector{} -> traverse rewriteAll ftf -- NOTE: we don't rewrite arguments of constructors, datatypes, or -- recursors because of dependent types, as we could potentially cause diff --git a/saw-core/src/Verifier/SAW/SCTypeCheck.hs b/saw-core/src/Verifier/SAW/SCTypeCheck.hs index 9d41822947..15101ebcba 100644 --- a/saw-core/src/Verifier/SAW/SCTypeCheck.hs +++ b/saw-core/src/Verifier/SAW/SCTypeCheck.hs @@ -5,6 +5,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} {- | Module : Verifier.SAW.SCTypeCheck @@ -482,18 +483,10 @@ instance TypeInfer (TermF TypedTerm) where instance TypeInfer (FlatTermF TypedTerm) where typeInfer (Primitive ec) = typeCheckWHNF $ typedVal $ primType ec - typeInfer UnitValue = liftTCM scUnitType - typeInfer UnitType = liftTCM scSort (mkSort 0) - typeInfer (PairValue (TypedTerm _ tx) (TypedTerm _ ty)) = - liftTCM scPairType tx ty - typeInfer (PairType (TypedTerm _ tx) (TypedTerm _ ty)) = - do sx <- ensureSort tx - sy <- ensureSort ty - liftTCM scSort (max sx sy) - typeInfer (PairLeft (TypedTerm _ tp)) = - ensurePairType tp >>= \(t1,_) -> return t1 - typeInfer (PairRight (TypedTerm _ tp)) = - ensurePairType tp >>= \(_,t2) -> return t2 + typeInfer (TupleValue tts) = + liftTCM scTupleType (map typedType (V.toList tts)) + typeInfer (TupleSelector (TypedTerm _ tp) i) = + ensureTupleType tp >>= \ts -> pure (ts !! i) typeInfer (DataTypeApp d params args) = -- Look up the DataType structure, check the length of the params and args, @@ -581,10 +574,31 @@ ensureRecognizer f err trm = ensureSort :: Term -> TCM Sort ensureSort tp = ensureRecognizer asSort (NotSort tp) tp --- | Ensure a 'Term' is a pair type, normalizing if necessary, and return the --- two components of that pair type -ensurePairType :: Term -> TCM (Term, Term) -ensurePairType tp = ensureRecognizer asPairType (NotSort tp) tp +-- | Ensure a 'Term' is a @TypeList@, normalizing if necessary, and +-- return the components of that @TypeList@. Note that this function +-- cannot be correctly implemented with 'ensureRecognizer', because +-- that function does not normalize deeply enough. +ensureTypeList :: Term -> TCM (Maybe [Term]) +ensureTypeList t = + do t' <- typeCheckWHNF t + case asCtor t' of + Just (c, []) + | primName c == "Prelude.TypeNil" -> pure (Just []) + Just (c, [t1, t2]) + | primName c == "Prelude.TypeCons" -> + do ts <- ensureTypeList t2 + pure ((t1 :) <$> ts) + _ -> pure Nothing + +-- | Ensure a 'Term' is a tuple type, normalizing if necessary, and +-- return the components of that tuple type. Note that this function +-- cannot be correctly implemented with 'ensureRecognizer', because +-- that function does not normalize deeply enough. +ensureTupleType :: Term -> TCM [Term] +ensureTupleType tp = + do let err = NotTupleType tp + t <- ensureRecognizer (isGlobalDef "Prelude.Tuple" @> Just) err tp + maybe (throwTCError err) pure =<< ensureTypeList t -- | Ensure a 'Term' is a record type, normalizing if necessary, and return the -- components of that record type diff --git a/saw-core/src/Verifier/SAW/SharedTerm.hs b/saw-core/src/Verifier/SAW/SharedTerm.hs index 4a8cde027a..84ce84503d 100644 --- a/saw-core/src/Verifier/SAW/SharedTerm.hs +++ b/saw-core/src/Verifier/SAW/SharedTerm.hs @@ -124,14 +124,9 @@ module Verifier.SAW.SharedTerm , scEqTrue , scBool , scBoolType - -- *** Unit, pairs, and tuples + -- *** Tuples , scUnitValue , scUnitType - , scPairValue - , scPairType - , scPairLeft - , scPairRight - , scPairValueReduced , scTuple , scTupleType , scTupleSelector @@ -777,7 +772,7 @@ scReduceNatRecursor sc rec crec n data WHNFElim = ElimApp Term | ElimProj FieldName - | ElimPair Bool + | ElimTuple Int | ElimRecursor Term (CompiledRecursor Term) [Term] -- | Test if a term is a constructor application that should be converted to a @@ -818,9 +813,11 @@ scWhnf sc t0 = go xs (convertsToNat -> Just k) = scFlatTermF sc (NatLit k) >>= go xs go xs (asApp -> Just (t, x)) = go (ElimApp x : xs) t go xs (asRecordSelector -> Just (t, n)) = go (ElimProj n : xs) t - go xs (asPairSelector -> Just (t, i)) = go (ElimPair i : xs) t + go xs (asTupleSelector -> Just (t, i)) = go (ElimTuple i : xs) t go (ElimApp x : xs) (asLambda -> Just (_, _, body)) = betaReduce xs [x] body - go (ElimPair i : xs) (asPairValue -> Just (a, b)) = go xs (if i then b else a) + go (ElimTuple i : xs) (asTupleValue -> Just ts) = case V.fromList ts V.!? i of + Just t -> go xs t + Nothing -> error "scWhnf: invalid tuple index" go (ElimProj fld : xs) (asRecordValue -> Just elems) = case Map.lookup fld elems of Just t -> go xs t Nothing -> @@ -835,13 +832,6 @@ scWhnf sc t0 = go xs (asGlobalDef -> Just c) = scRequireDef sc c >>= tryDef c xs go xs (asRecursorApp -> Just (r, crec, ixs, arg)) = go (ElimRecursor r crec ixs : xs) arg - go xs (asPairValue -> Just (a, b)) = do b' <- memo b - t' <- scPairValue sc a b' - foldM reapply t' xs - go xs (asPairType -> Just (a, b)) = do a' <- memo a - b' <- memo b - t' <- scPairType sc a' b' - foldM reapply t' xs go xs (asRecordType -> Just elems) = do elems' <- mapM (\(i,t) -> (i,) <$> memo t) (Map.assocs elems) t' <- scRecordType sc elems' @@ -868,7 +858,7 @@ scWhnf sc t0 = reapply :: Term -> WHNFElim -> IO Term reapply t (ElimApp x) = scApply sc t x reapply t (ElimProj i) = scRecordSelect sc t i - reapply t (ElimPair i) = scPairSelector sc t i + reapply t (ElimTuple i) = scTupleSelector sc t i reapply t (ElimRecursor r _crec ixs) = scFlatTermF sc (RecursorApp r ixs t) @@ -1018,26 +1008,17 @@ scTypeOf' sc env t0 = State.evalStateT (memo t0) Map.empty ftermf tf = case tf of Primitive ec -> return (primType ec) - UnitValue -> lift $ scUnitType sc - UnitType -> lift $ scSort sc (mkSort 0) - PairValue x y -> do - tx <- memo x - ty <- memo y - lift $ scPairType sc tx ty - PairType x y -> do - sx <- sort x - sy <- sort y - lift $ scSort sc (max sx sy) - PairLeft t -> do - tp <- (liftIO . scWhnf sc) =<< memo t - case asPairType tp of - Just (t1, _) -> return t1 - Nothing -> fail "scTypeOf: type error: expected pair type" - PairRight t -> do - tp <- (liftIO . scWhnf sc) =<< memo t - case asPairType tp of - Just (_, t2) -> return t2 - Nothing -> fail "scTypeOf: type error: expected pair type" + TupleValue xs -> + liftIO . scTupleType sc =<< traverse memo (V.toList xs) + TupleSelector x i -> + do tp <- (liftIO . scWhnf sc) =<< memo x + case asTupleType tp of + Nothing -> fail "scTypeOf: type error: expected pair type" + Just ts -> + case V.fromList ts V.!? i of + Nothing -> + fail $ "scTypeOf: tuple selector out of range (" ++ show i ++ " > " ++ show (length ts) ++ ")" + Just t -> pure t CtorApp c params args -> do lift $ foldM (reducePi sc) (primType c) (params ++ args) DataTypeApp dt params args -> do @@ -1354,69 +1335,41 @@ scRecordType sc elem_tps = scFlatTermF sc (RecordType elem_tps) -- | Create a unit-valued term. scUnitValue :: SharedContext -> IO Term -scUnitValue sc = scFlatTermF sc UnitValue +scUnitValue sc = scTuple sc [] -- | Create a term representing the unit type. scUnitType :: SharedContext -> IO Term -scUnitType sc = scFlatTermF sc UnitType - --- | Create a pair term from two terms. -scPairValue :: SharedContext - -> Term -- ^ The left projection - -> Term -- ^ The right projection - -> IO Term -scPairValue sc x y = scFlatTermF sc (PairValue x y) - --- | Create a term representing a pair type from two other terms, each --- representing a type. -scPairType :: SharedContext - -> Term -- ^ Left projection type - -> Term -- ^ Right projection type - -> IO Term -scPairType sc x y = scFlatTermF sc (PairType x y) +scUnitType sc = scTupleType sc [] -- | Create an n-place tuple from a list (of length n) of 'Term's. -- Note that tuples are nested pairs, associating to the right e.g. -- @(a, (b, (c, d)))@. scTuple :: SharedContext -> [Term] -> IO Term -scTuple sc [] = scUnitValue sc -scTuple _ [t] = return t -scTuple sc (t : ts) = scPairValue sc t =<< scTuple sc ts +scTuple sc ts = scFlatTermF sc (TupleValue (V.fromList ts)) + +scTypeList :: SharedContext -> [Term] -> IO Term +scTypeList sc [] = scCtorApp sc "Prelude.TypeNil" [] +scTypeList sc (t : ts) = + do ts' <- scTypeList sc ts + scCtorApp sc "Prelude.TypeCons" [t, ts'] -- | Create a term representing the type of an n-place tuple, from a list -- (of length n) of 'Term's, each representing a type. scTupleType :: SharedContext -> [Term] -> IO Term -scTupleType sc [] = scUnitType sc -scTupleType _ [t] = return t -scTupleType sc (t : ts) = scPairType sc t =<< scTupleType sc ts +scTupleType sc ts = + do ts' <- scTypeList sc ts + scGlobalApply sc "Prelude.Tuple" [ts'] --- | Create a term giving the left projection of a 'Term' representing a pair. -scPairLeft :: SharedContext -> Term -> IO Term -scPairLeft sc t = scFlatTermF sc (PairLeft t) - --- | Create a term giving the right projection of a 'Term' representing a pair. -scPairRight :: SharedContext -> Term -> IO Term -scPairRight sc t = scFlatTermF sc (PairRight t) - --- | Create a term representing either the left or right projection of the --- given 'Term', depending on the given 'Bool': left if @False@, right if @True@. -scPairSelector :: SharedContext -> Term -> Bool -> IO Term -scPairSelector sc t False = scPairLeft sc t -scPairSelector sc t True = scPairRight sc t - --- | @scTupleSelector sc t i n@ returns a term selecting the @i@th component of +-- | @scTupleSelector sc t i@ returns a term selecting the @i@th component of -- an @n@-place tuple 'Term', @t@. scTupleSelector :: - SharedContext -> Term -> - Int {- ^ 1-based index -} -> - Int {- ^ tuple size -} -> + SharedContext -> + Term {- ^ tuple -} -> + Int {- ^ 0-based index -} -> IO Term -scTupleSelector sc t i n - | n == 1 = return t - | i == 1 = scPairLeft sc t - | i > 1 = do t' <- scPairRight sc t - scTupleSelector sc t' (i - 1) (n - 1) - | otherwise = fail "scTupleSelector: non-positive index" +scTupleSelector sc t i + | i < 0 = fail "scTupleSelector: negative index" + | otherwise = scFlatTermF sc (TupleSelector t i) -- | Create a term representing the type of a non-dependent function, given a -- parameter and result type (as 'Term's). @@ -1549,20 +1502,25 @@ scGlobalApply sc i ts = do c <- scGlobalDef sc i scApplyAll sc c ts --- | An optimized variant of 'scPairValue' that will reduce pairs of --- the form @(x.L, x.R)@ to @x@. -scPairValueReduced :: SharedContext -> Term -> Term -> IO Term -scPairValueReduced sc x y = - case (unwrapTermF x, unwrapTermF y) of - (FTermF (PairLeft a), FTermF (PairRight b)) | a == b -> return a - _ -> scPairValue sc x y - -- | An optimized variant of 'scPairTuple' that will reduce tuples of --- the form @(x.1, x.2, x.3)@ to @x@. +-- the form @(x.0, x.1, x.2)@ to @x@. scTupleReduced :: SharedContext -> [Term] -> IO Term -scTupleReduced sc [] = scUnitValue sc -scTupleReduced _ [t] = return t -scTupleReduced sc (t : ts) = scPairValueReduced sc t =<< scTupleReduced sc ts +scTupleReduced sc ts = + case asTupleRedex ts of + Just t -> pure t + Nothing -> scTuple sc ts + +asTupleRedex :: [Term] -> Maybe Term +asTupleRedex [] = Nothing +asTupleRedex (t0 : ts0) = + do (x, i) <- asTupleSelector t0 + go x i ts0 + where + go x _ [] = Just x + go x i (t : ts) = + do (y, j) <- asTupleSelector t + guard (j == i + 1 && x == y) + go x j ts -- | An optimized variant of 'scVector' that will reduce vectors of -- the form @[at x 0, at x 1, at x 2, at x 3]@ to just @x@. diff --git a/saw-core/src/Verifier/SAW/Simulator.hs b/saw-core/src/Verifier/SAW/Simulator.hs index 5567802548..f2819cca85 100644 --- a/saw-core/src/Verifier/SAW/Simulator.hs +++ b/saw-core/src/Verifier/SAW/Simulator.hs @@ -47,6 +47,7 @@ import Data.IntMap (IntMap) import qualified Data.IntMap as IMap import Data.Text (Text) import qualified Data.Text as Text +import qualified Data.Vector as V import Data.Traversable import GHC.Stack @@ -142,8 +143,8 @@ evalTermF cfg lam recEval tf env = pure (VDependentPi (\x -> toTValue <$> lam t2 ((x,v) : env))) else do -- put dummy values in the environment; the term should never reference them - let val = ready VUnit - let tp = VUnitType + let val = ready (VTuple mempty) + let tp = VTupleType mempty VNondependentPi . toTValue <$> lam t2 ((val,tp):env) return $ TValue $ VPiType nm v body @@ -158,25 +159,11 @@ evalTermF cfg lam recEval tf env = do pn' <- traverse evalType pn simPrimitive cfg pn' - UnitValue -> return VUnit + TupleValue xs -> VTuple <$> traverse recEvalDelay xs - UnitType -> return $ TValue VUnitType - - PairValue x y -> do tx <- recEvalDelay x - ty <- recEvalDelay y - return $ VPair tx ty - - PairType x y -> do vx <- evalType x - vy <- evalType y - return $ TValue $ VPairType vx vy - - PairLeft x -> recEval x >>= \case - VPair l _r -> force l - _ -> simNeutral cfg env (NeutralPairLeft (NeutralBox x)) - - PairRight x -> recEval x >>= \case - VPair _l r -> force r - _ -> simNeutral cfg env (NeutralPairRight (NeutralBox x)) + TupleSelector x i -> recEval x >>= \case + VTuple ys -> force (ys V.! i) + _ -> simNeutral cfg env (NeutralTupleProj (NeutralBox x) i) CtorApp c ps ts -> do c' <- traverse evalType c ps' <- mapM recEvalDelay ps diff --git a/saw-core/src/Verifier/SAW/Simulator/Prims.hs b/saw-core/src/Verifier/SAW/Simulator/Prims.hs index e2a16c7614..f61d381927 100644 --- a/saw-core/src/Verifier/SAW/Simulator/Prims.hs +++ b/saw-core/src/Verifier/SAW/Simulator/Prims.hs @@ -1194,9 +1194,9 @@ muxValue bp tp0 b = value tp0 y <- g a value tp' x y - value VUnitType VUnit VUnit = return VUnit - value (VPairType t1 t2) (VPair x1 x2) (VPair y1 y2) = - VPair <$> thunk t1 x1 y1 <*> thunk t2 x2 y2 + value (VTupleType ts) (VTuple xs) (VTuple ys) + | V.length ts == V.length xs && V.length ts == V.length ys + = VTuple <$> V.sequence (V.zipWith3 thunk ts xs ys) value (VRecordType fs) (VRecordValue elems1) (VRecordValue elems2) = do let em1 = Map.fromList elems1 diff --git a/saw-core/src/Verifier/SAW/Simulator/TermModel.hs b/saw-core/src/Verifier/SAW/Simulator/TermModel.hs index c92fe468fe..eebd4fb6be 100644 --- a/saw-core/src/Verifier/SAW/Simulator/TermModel.hs +++ b/saw-core/src/Verifier/SAW/Simulator/TermModel.hs @@ -223,7 +223,6 @@ readBackTValue sc cfg = loop where loop tv = case tv of - VUnitType -> scUnitType sc VBoolType -> scBoolType sc VStringType -> scStringType sc VIntType -> scIntegerType sc @@ -239,10 +238,9 @@ readBackTValue sc cfg = loop do n' <- scNat sc n t' <- loop t scVecType sc n' t' - VPairType t1 t2 -> - do t1' <- loop t1 - t2' <- loop t2 - scPairType sc t1' t2' + VTupleType ts -> + do ts' <- traverse loop (V.toList ts) + scTupleType sc ts' VRecordType fs -> do fs' <- traverse (traverse loop) fs scRecordType sc fs' @@ -301,7 +299,6 @@ reflectTerm :: reflectTerm sc cfg = loop where loop tv tm = case tv of - VUnitType -> pure VUnit VBoolType -> return (VBool (Left tm)) VIntType -> return (VInt (Left tm)) VIntModType m -> return (VIntMod m (Left tm)) @@ -337,7 +334,7 @@ reflectTerm sc cfg = loop VStringType{} -> return (VExtra (VExtraTerm tv tm)) VRecordType{} -> return (VExtra (VExtraTerm tv tm)) - VPairType{} -> return (VExtra (VExtraTerm tv tm)) + VTupleType{} -> return (VExtra (VExtraTerm tv tm)) VDataType{} -> return (VExtra (VExtraTerm tv tm)) VRecursorType{} -> return (VExtra (VExtraTerm tv tm)) VTyTerm{} -> return (VExtra (VExtraTerm tv tm)) @@ -353,8 +350,6 @@ readBackValue :: IO Term readBackValue sc cfg = loop where - loop _ VUnit = scUnitValue sc - loop _ (VNat n) = scNat sc n loop _ (VBVToNat w n) = @@ -393,10 +388,9 @@ readBackValue sc cfg = loop do (ecs, tm) <- readBackFuns tv v scAbstractExtsEtaCollapse sc ecs tm - loop (VPairType t1 t2) (VPair v1 v2) = - do tm1 <- loop t1 =<< force v1 - tm2 <- loop t2 =<< force v2 - scPairValueReduced sc tm1 tm2 + loop (VTupleType ts) (VTuple vs) | V.length ts == V.length vs = + do tms <- V.sequence $ V.zipWith (\t v -> loop t =<< force v) ts vs + scTupleReduced sc (V.toList tms) loop (VVecType _n tp) (VVector vs) = do tp' <- readBackTValue sc cfg tp diff --git a/saw-core/src/Verifier/SAW/Simulator/Value.hs b/saw-core/src/Verifier/SAW/Simulator/Value.hs index cf4f67a2f2..5ef987ca87 100644 --- a/saw-core/src/Verifier/SAW/Simulator/Value.hs +++ b/saw-core/src/Verifier/SAW/Simulator/Value.hs @@ -51,8 +51,7 @@ The concrete parameters to use are computed from the name using a collection of type families (e.g., 'EvalM', 'VBool', etc.). -} data Value l = VFun !LocalName !(Thunk l -> MValue l) - | VUnit - | VPair (Thunk l) (Thunk l) -- TODO: should second component be strict? + | VTuple !(Vector (Thunk l)) | VCtorApp !(PrimName (TValue l)) ![Thunk l] ![Thunk l] | VVector !(Vector (Thunk l)) | VBool (VBool l) @@ -83,9 +82,8 @@ data TValue l | VArrayType !(TValue l) !(TValue l) | VPiType LocalName !(TValue l) !(PiBody l) | VStringType - | VUnitType - | VPairType !(TValue l) !(TValue l) | VDataType !(PrimName (TValue l)) ![Value l] ![Value l] + | VTupleType !(Vector (TValue l)) | VRecordType ![(FieldName, TValue l)] | VSort !Sort | VRecursorType @@ -105,8 +103,7 @@ data PiBody l -- is being hidden, etc.) data NeutralTerm = NeutralBox Term -- the thing blocking evaluation - | NeutralPairLeft NeutralTerm -- left pair projection - | NeutralPairRight NeutralTerm -- right pair projection + | NeutralTupleProj NeutralTerm Int -- tuple projection | NeutralRecordProj NeutralTerm FieldName -- record projection | NeutralApp NeutralTerm Term -- function application | NeutralRecursor @@ -174,8 +171,7 @@ instance Show (Extra l) => Show (Value l) where showsPrec p v = case v of VFun {} -> showString "<>" - VUnit -> showString "()" - VPair{} -> showString "<>" + VTuple xv -> showString "<<" . shows (V.length xv) . showString "-tuple>>" VCtorApp s _ps _xv -> shows (primName s) VVector xv -> showList (toList xv) VBool _ -> showString "<>" @@ -207,8 +203,7 @@ instance Show (Extra l) => Show (TValue l) where VArrayType{} -> showString "Array" VPiType _ t _ -> showParen True (shows t . showString " -> ...") - VUnitType -> showString "#()" - VPairType x y -> showParen True (shows x . showString " * " . shows y) + VTupleType ts -> showString "#" . showParen True (showCommas (map shows (V.toList ts))) VDataType s ps vs | null (ps++vs) -> shows s | otherwise -> shows s . showList (ps++vs) @@ -221,6 +216,10 @@ instance Show (Extra l) => Show (TValue l) where VRecursorType{} -> showString "RecursorType" VTyTerm _ tm -> showString "TyTerm (" . (\x -> showTerm tm ++ x) . showString ")" + where + showCommas [] = id + showCommas [x] = x + showCommas (x : xs) = x . showString "," . showCommas xs data Nil = Nil @@ -231,23 +230,10 @@ instance Show Nil where -- Basic operations on values vTuple :: VMonad l => [Thunk l] -> Value l -vTuple [] = VUnit -vTuple [_] = error "vTuple: unsupported 1-tuple" -vTuple [x, y] = VPair x y -vTuple (x : xs) = VPair x (ready (vTuple xs)) +vTuple xs = VTuple (V.fromList xs) vTupleType :: VMonad l => [TValue l] -> TValue l -vTupleType [] = VUnitType -vTupleType [t] = t -vTupleType (t : ts) = VPairType t (vTupleType ts) - -valPairLeft :: (HasCallStack, VMonad l, Show (Extra l)) => Value l -> MValue l -valPairLeft (VPair t1 _) = force t1 -valPairLeft v = panic "Verifier.SAW.Simulator.Value.valPairLeft" ["Not a pair value:", show v] - -valPairRight :: (HasCallStack, VMonad l, Show (Extra l)) => Value l -> MValue l -valPairRight (VPair _ t2) = force t2 -valPairRight v = panic "Verifier.SAW.Simulator.Value.valPairRight" ["Not a pair value:", show v] +vTupleType ts = VTupleType (V.fromList ts) vRecord :: Map FieldName (Thunk l) -> Value l vRecord m = VRecordValue (Map.assocs m) @@ -289,13 +275,8 @@ asFiniteTypeTValue v = VVecType n v1 -> do t1 <- asFiniteTypeTValue v1 return (FTVec n t1) - VUnitType -> return (FTTuple []) - VPairType v1 v2 -> do - t1 <- asFiniteTypeTValue v1 - t2 <- asFiniteTypeTValue v2 - case t2 of - FTTuple ts -> return (FTTuple (t1 : ts)) - _ -> return (FTTuple [t1, t2]) + VTupleType vs -> + FTTuple <$> traverse asFiniteTypeTValue (V.toList vs) VRecordType elem_tps -> FTRec <$> Map.fromList <$> mapM (\(fld,tp) -> (fld,) <$> asFiniteTypeTValue tp) elem_tps @@ -316,13 +297,8 @@ asFirstOrderTypeTValue v = VIntModType m -> return (FOTIntMod m) VArrayType a b -> FOTArray <$> asFirstOrderTypeTValue a <*> asFirstOrderTypeTValue b - VUnitType -> return (FOTTuple []) - VPairType v1 v2 -> do - t1 <- asFirstOrderTypeTValue v1 - t2 <- asFirstOrderTypeTValue v2 - case t2 of - FOTTuple ts -> return (FOTTuple (t1 : ts)) - _ -> return (FOTTuple [t1, t2]) + VTupleType vs -> + FOTTuple <$> traverse asFirstOrderTypeTValue (V.toList vs) VRecordType elem_tps -> FOTRec . Map.fromList <$> mapM (traverse asFirstOrderTypeTValue) elem_tps @@ -351,11 +327,9 @@ suffixTValue tv = b' <- suffixTValue b Just ("_Array" ++ a' ++ b') VPiType _ _ _ -> Nothing - VUnitType -> Just "_Unit" - VPairType a b -> - do a' <- suffixTValue a - b' <- suffixTValue b - Just ("_Pair" ++ a' ++ b') + VTupleType vs -> + do vs' <- traverse suffixTValue (V.toList vs) + Just ("_Tuple" ++ show (V.length vs) ++ concat vs') VStringType -> Nothing VDataType {} -> Nothing @@ -369,10 +343,8 @@ neutralToTerm :: NeutralTerm -> Term neutralToTerm = loop where loop (NeutralBox tm) = tm - loop (NeutralPairLeft nt) = - Unshared (FTermF (PairLeft (loop nt))) - loop (NeutralPairRight nt) = - Unshared (FTermF (PairRight (loop nt))) + loop (NeutralTupleProj nt i) = + Unshared (FTermF (TupleSelector (loop nt) i)) loop (NeutralRecordProj nt f) = Unshared (FTermF (RecordProj (loop nt) f)) loop (NeutralApp nt arg) = @@ -388,10 +360,9 @@ neutralToSharedTerm :: SharedContext -> NeutralTerm -> IO Term neutralToSharedTerm sc = loop where loop (NeutralBox tm) = pure tm - loop (NeutralPairLeft nt) = - scFlatTermF sc . PairLeft =<< loop nt - loop (NeutralPairRight nt) = - scFlatTermF sc . PairRight =<< loop nt + loop (NeutralTupleProj nt i) = + do tm <- loop nt + scFlatTermF sc (TupleSelector tm i) loop (NeutralRecordProj nt f) = do tm <- loop nt scFlatTermF sc (RecordProj tm f) diff --git a/saw-core/src/Verifier/SAW/Term/Functor.hs b/saw-core/src/Verifier/SAW/Term/Functor.hs index 59c3e3276d..a6497e94c8 100644 --- a/saw-core/src/Verifier/SAW/Term/Functor.hs +++ b/saw-core/src/Verifier/SAW/Term/Functor.hs @@ -143,14 +143,9 @@ data FlatTermF e -- | A primitive or axiom without a definition. = Primitive !(PrimName e) - -- Tuples are represented as nested pairs, grouped to the right, - -- terminated with unit at the end. - | UnitValue - | UnitType - | PairValue e e - | PairType e e - | PairLeft e - | PairRight e + | TupleValue (Vector e) + -- | A zero-indexed tuple field selection. + | TupleSelector e Int -- | An inductively-defined type, applied to parameters and type indices | DataTypeApp !(PrimName e) ![e] ![e] @@ -273,19 +268,16 @@ zipRec f (CompiledRecursor d1 ps1 m1 mty1 es1 ord1) (CompiledRecursor d2 ps2 m2 -- | Zip a binary function @f@ over a pair of 'FlatTermF's by applying @f@ -- pointwise to immediate subterms, if the two 'FlatTermF's are the same --- constructor; otherwise, return 'Nothing' if they use different constructors +-- constructor; otherwise, return 'Nothing' if they use different constructors. zipWithFlatTermF :: (x -> y -> z) -> FlatTermF x -> FlatTermF y -> Maybe (FlatTermF z) zipWithFlatTermF f = go where go (Primitive pn1) (Primitive pn2) = Primitive <$> zipPrimName f pn1 pn2 - go UnitValue UnitValue = Just UnitValue - go UnitType UnitType = Just UnitType - go (PairValue x1 x2) (PairValue y1 y2) = Just (PairValue (f x1 y1) (f x2 y2)) - go (PairType x1 x2) (PairType y1 y2) = Just (PairType (f x1 y1) (f x2 y2)) - go (PairLeft x) (PairLeft y) = Just (PairLeft (f x y)) - go (PairRight x) (PairRight y) = Just (PairLeft (f x y)) - + go (TupleValue xs) (TupleValue ys) + | V.length xs == V.length ys = Just $ TupleValue (V.zipWith f xs ys) + go (TupleSelector x i) (TupleSelector y j) + | i == j = Just (TupleSelector (f x y) i) go (CtorApp cx psx lx) (CtorApp cy psy ly) = do c <- zipPrimName f cx cy Just $ CtorApp c (zipWith f psx psy) (zipWith f lx ly) diff --git a/saw-core/src/Verifier/SAW/Term/Pretty.hs b/saw-core/src/Verifier/SAW/Term/Pretty.hs index 58cf0de96e..cf98712859 100644 --- a/saw-core/src/Verifier/SAW/Term/Pretty.hs +++ b/saw-core/src/Verifier/SAW/Term/Pretty.hs @@ -342,13 +342,22 @@ ppLetBlock defs body = ppEqn (var,d) = ppMemoVar var <+> pretty '=' <+> d --- | Pretty-print pairs as "(x, y)" -ppPair :: Prec -> SawDoc -> SawDoc -> SawDoc -ppPair prec x y = ppParensPrec prec PrecCommas (group (vcat [x <> pretty ',', y])) +-- -- | Pretty-print pairs as "(x, y)" +-- ppPair :: Prec -> SawDoc -> SawDoc -> SawDoc +-- ppPair prec x y = ppParensPrec prec PrecCommas (group (vcat [x <> pretty ',', y])) --- | Pretty-print pair types as "x * y" -ppPairType :: Prec -> SawDoc -> SawDoc -> SawDoc -ppPairType prec x y = ppParensPrec prec PrecProd (x <+> pretty '*' <+> y) +ppCommaSep :: [SawDoc] -> SawDoc +ppCommaSep [] = emptyDoc +ppCommaSep [x] = x +ppCommaSep (x : xs) = group (vcat [x <> pretty ',', ppCommaSep xs]) + +-- | Pretty-print tuples as "(x, y, z)" +ppTuple :: [SawDoc] -> SawDoc +ppTuple xs = parens (align (ppCommaSep xs)) + +-- -- | Pretty-print pair types as "x * y" +-- ppPairType :: Prec -> SawDoc -> SawDoc -> SawDoc +-- ppPairType prec x y = ppParensPrec prec PrecProd (x <+> pretty '*' <+> y) -- | Pretty-print records (if the flag is 'False') or record types (if the flag -- is 'True'), where the latter are preceded by the string @#@, either as: @@ -422,12 +431,8 @@ ppFlatTermF :: Prec -> FlatTermF Term -> PPM SawDoc ppFlatTermF prec tf = case tf of Primitive ec -> annotate PrimitiveStyle <$> ppBestName (ModuleIdentifier (primName ec)) - UnitValue -> return "(-empty-)" - UnitType -> return "#(-empty-)" - PairValue x y -> ppPair prec <$> ppTerm' PrecTerm x <*> ppTerm' PrecCommas y - PairType x y -> ppPairType prec <$> ppTerm' PrecApp x <*> ppTerm' PrecProd y - PairLeft t -> ppProj "1" <$> ppTerm' PrecArg t - PairRight t -> ppProj "2" <$> ppTerm' PrecArg t + TupleValue xs -> ppTuple <$> traverse (ppTerm' PrecCommas) (V.toList xs) + TupleSelector t i -> ppProj (Text.pack (show i)) <$> ppTerm' PrecArg t RecursorType d params motive _motiveTy -> do params_pp <- mapM (ppTerm' PrecArg) params @@ -585,8 +590,7 @@ shouldMemoizeTerm :: Term -> Bool shouldMemoizeTerm t = case unwrapTermF t of FTermF Primitive{} -> False - FTermF UnitValue -> False - FTermF UnitType -> False + FTermF (TupleValue xs) -> not (V.null xs) FTermF (CtorApp _ [] []) -> False FTermF (DataTypeApp _ [] []) -> False FTermF Sort{} -> False diff --git a/saw-core/src/Verifier/SAW/Typechecker.hs b/saw-core/src/Verifier/SAW/Typechecker.hs index c2ccb0903a..cb63ed97e4 100644 --- a/saw-core/src/Verifier/SAW/Typechecker.hs +++ b/saw-core/src/Verifier/SAW/Typechecker.hs @@ -238,23 +238,19 @@ typeInferCompleteTerm (Un.RecordType _ elems) = typeInferCompleteTerm (Un.RecordProj t prj) = (RecordProj <$> typeInferComplete t <*> return prj) >>= typeInferComplete --- Unit -typeInferCompleteTerm (Un.UnitValue _) = - typeInferComplete (UnitValue :: FlatTermF TypedTerm) -typeInferCompleteTerm (Un.UnitType _) = - typeInferComplete (UnitType :: FlatTermF TypedTerm) - --- Simple pairs -typeInferCompleteTerm (Un.PairValue _ t1 t2) = - (PairValue <$> typeInferComplete t1 <*> typeInferComplete t2) +-- Tuples +typeInferCompleteTerm (Un.TupleValue _ ts) = + (TupleValue <$> traverse typeInferComplete (V.fromList ts)) >>= typeInferComplete -typeInferCompleteTerm (Un.PairType _ t1 t2) = - (PairType <$> typeInferComplete t1 <*> typeInferComplete t2) - >>= typeInferComplete -typeInferCompleteTerm (Un.PairLeft t) = - (PairLeft <$> typeInferComplete t) >>= typeInferComplete -typeInferCompleteTerm (Un.PairRight t) = - (PairRight <$> typeInferComplete t) >>= typeInferComplete +typeInferCompleteTerm (Un.TupleType _ ts) = + do tts <- traverse typeInferComplete ts + v <- liftTCM scTupleType (map typedVal tts) + -- Ensure all arguments have type 'sort 0' + s0 <- liftTCM scSort (mkSort 0) + mapM_ (\tt -> checkSubtype tt s0) tts + pure (TypedTerm v s0) +typeInferCompleteTerm (Un.TupleProj t i) = + (TupleSelector <$> typeInferComplete t <*> pure i) >>= typeInferComplete -- Type ascriptions typeInferCompleteTerm (Un.TypeConstraint t _ tp) = diff --git a/saw-core/src/Verifier/SAW/UntypedAST.hs b/saw-core/src/Verifier/SAW/UntypedAST.hs index d889210fe2..a87dcd43b8 100644 --- a/saw-core/src/Verifier/SAW/UntypedAST.hs +++ b/saw-core/src/Verifier/SAW/UntypedAST.hs @@ -67,17 +67,14 @@ data Term | Lambda Pos TermCtx Term | Pi Pos TermCtx Term | Recursor (Maybe ModuleName) (PosPair Text) - | UnitValue Pos - | UnitType Pos -- | New-style records | RecordValue Pos [(PosPair FieldName, Term)] | RecordType Pos [(PosPair FieldName, Term)] | RecordProj Term FieldName - -- | Simple pairs - | PairValue Pos Term Term - | PairType Pos Term Term - | PairLeft Term - | PairRight Term + -- | Tuples + | TupleValue Pos [Term] + | TupleType Pos [Term] + | TupleProj Term Int -- | Identifies a type constraint on the term, i.e., a type ascription | TypeConstraint Term Pos Term | NatLit Pos Natural @@ -117,15 +114,12 @@ instance Positioned Term where App x _ -> pos x Pi p _ _ -> p Recursor _ i -> pos i - UnitValue p -> p - UnitType p -> p RecordValue p _ -> p RecordType p _ -> p RecordProj x _ -> pos x - PairValue p _ _ -> p - PairType p _ _ -> p - PairLeft x -> pos x - PairRight x -> pos x + TupleValue p _ -> p + TupleType p _ -> p + TupleProj x _ -> pos x TypeConstraint _ p _ -> p NatLit p _ -> p StringLit p _ -> p @@ -240,21 +234,14 @@ asApp = go [] -- | Build a tuple value @(x1, .., xn)@. mkTupleValue :: Pos -> [Term] -> Term -mkTupleValue p [] = UnitValue p mkTupleValue _ [x] = x -mkTupleValue p (x:xs) = PairValue (pos x) x (mkTupleValue p xs) +mkTupleValue p xs = TupleValue p xs -- | Build a tuple type @#(x1, .., xn)@. mkTupleType :: Pos -> [Term] -> Term -mkTupleType p [] = UnitType p mkTupleType _ [x] = x -mkTupleType p (x:xs) = PairType (pos x) x (mkTupleType p xs) +mkTupleType p xs = TupleType p xs --- | Build a projection @t.i@ of a tuple. NOTE: This function does not --- work to access the last component in a tuple, since it always --- generates a @PairLeft@. +-- | Build a projection @t.i@ of a tuple. mkTupleSelector :: Term -> Natural -> Term -mkTupleSelector t i - | i == 1 = PairLeft t - | i > 1 = mkTupleSelector (PairRight t) (i - 1) - | otherwise = error "mkTupleSelector: non-positive index" +mkTupleSelector t i = TupleProj t (fromIntegral i) -- FIXME: unchecked fromIntegral diff --git a/src/SAWScript/Crucible/Common/MethodSpec.hs b/src/SAWScript/Crucible/Common/MethodSpec.hs index 5d9249a824..cefc96e1f4 100644 --- a/src/SAWScript/Crucible/Common/MethodSpec.hs +++ b/src/SAWScript/Crucible/Common/MethodSpec.hs @@ -253,9 +253,9 @@ setupToTerm opts sc = typ <- lift $ scTypeOf sc et lift $ scAt sc lent typ art ixt - SetupStruct _ _ fs -> + SetupStruct _ _ _fs -> do st <- setupToTerm opts sc base - lift $ scTupleSelector sc st ind (length fs) + lift $ scTupleSelector sc st ind _ -> MaybeT $ return Nothing diff --git a/src/SAWScript/Crucible/LLVM/Builtins.hs b/src/SAWScript/Crucible/LLVM/Builtins.hs index ba1ad5e1a1..3d9777a9c9 100644 --- a/src/SAWScript/Crucible/LLVM/Builtins.hs +++ b/src/SAWScript/Crucible/LLVM/Builtins.hs @@ -414,9 +414,9 @@ llvm_compositional_extract (Some lm) nm func_name lemmas checkSat setup tactic = input_terms <- io $ traverse (scExtCns shared_context) input_parameters applied_extracted_func <- io $ scApplyAll shared_context extracted_func_const input_terms applied_extracted_func_selectors <- - io $ forM [1 .. (length output_parameters)] $ \i -> + io $ forM [0 .. (length output_parameters - 1)] $ \i -> mkTypedTerm shared_context - =<< scTupleSelector shared_context applied_extracted_func i (length output_parameters) + =<< scTupleSelector shared_context applied_extracted_func i let output_parameter_substitution = Map.fromList $ zip (map ecVarIndex output_parameters) (map ttTerm applied_extracted_func_selectors) diff --git a/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs b/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs index fb37eae978..b5374b1829 100644 --- a/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs +++ b/src/SAWScript/Crucible/LLVM/ResolveSetupValue.hs @@ -838,7 +838,7 @@ resolveSAWTerm cc tp tm = Cryptol.TVTuple tps -> do st <- sawCoreState sym let sc = saw_ctx st - tms <- mapM (\i -> scTupleSelector sc tm i (length tps)) [1 .. length tps] + tms <- mapM (scTupleSelector sc tm) [0 .. length tps - 1] vals <- zipWithM (resolveSAWTerm cc) tps tms storTy <- case toLLVMType dl tp of @@ -1062,8 +1062,7 @@ memArrayToSawCoreTerm crucible_context endianess typed_term = do inner_saw_term <- liftIO $ scTupleSelector saw_context saw_term - (field_index + 1) - (length tuple_element_cryptol_types) + field_index setBytes tuple_element_cryptol_type inner_saw_term diff --git a/src/SAWScript/Crucible/LLVM/X86.hs b/src/SAWScript/Crucible/LLVM/X86.hs index e95786b426..958cb3d84f 100644 --- a/src/SAWScript/Crucible/LLVM/X86.hs +++ b/src/SAWScript/Crucible/LLVM/X86.hs @@ -566,8 +566,8 @@ setupSimpleLoopFixpointFeature sym sc sawst cfg mvar func = arguments <- forM fixpoint_substitution_as_list $ \(MapF.Pair _ fixpoint_entry) -> toSC sym sawst $ Crucible.LLVM.Fixpoint.headerValue fixpoint_entry applied_func <- scApplyAll sc (ttTerm func) $ implicit_parameters ++ arguments - applied_func_selectors <- forM [1 .. (length fixpoint_substitution_as_list)] $ \i -> - scTupleSelector sc applied_func i (length fixpoint_substitution_as_list) + applied_func_selectors <- forM [0 .. (length fixpoint_substitution_as_list - 1)] $ \i -> + scTupleSelector sc applied_func i result_substitution <- MapF.fromList <$> zipWithM (\(MapF.Pair variable _) applied_func_selector -> MapF.Pair variable <$> bindSAWTerm sym sawst (W4.exprType variable) applied_func_selector) diff --git a/src/SAWScript/Prover/Exporter.hs b/src/SAWScript/Prover/Exporter.hs index 63d66a18bf..9e48e4ea3c 100644 --- a/src/SAWScript/Prover/Exporter.hs +++ b/src/SAWScript/Prover/Exporter.hs @@ -343,12 +343,9 @@ writeVerilogSAT path satq = getSharedContext >>= \sc -> io $ flattenSValue :: IsSymExprBuilder sym => sym -> W4Sim.SValue sym -> IO [Some (W4.SymExpr sym)] flattenSValue _ (Sim.VBool b) = return [Some b] flattenSValue _ (Sim.VWord (W4Sim.DBV w)) = return [Some w] -flattenSValue sym (Sim.VPair l r) = - do lv <- Sim.force l - rv <- Sim.force r - ls <- flattenSValue sym lv - rs <- flattenSValue sym rv - return (ls ++ rs) +flattenSValue sym (Sim.VTuple ts) = + do vs <- traverse Sim.force ts + concat <$> traverse (flattenSValue sym) vs flattenSValue sym (Sim.VVector ts) = do vs <- mapM Sim.force ts let getBool (Sim.VBool b) = Just b diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 71e79735ba..d2eb06be5c 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -438,15 +438,14 @@ liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) -- | Apply a 'TermProj' to perform a projection on a 'Term' doTermProj :: Term -> TermProj -> MRM Term -doTermProj t TermProjLeft = liftSC1 scPairLeft t -doTermProj t TermProjRight = liftSC1 scPairRight t +doTermProj t (TermProjTuple i) = liftSC1 (\sc x -> scTupleSelector sc x i) t doTermProj t (TermProjRecord fld) = liftSC2 scRecordSelect t fld -- | Apply a 'TermProj' to a type to get the output type of the projection, -- assuming that the type is already normalized doTypeProj :: Term -> TermProj -> MRM Term -doTypeProj (asPairType -> Just (tp1, _)) TermProjLeft = return tp1 -doTypeProj (asPairType -> Just (_, tp2)) TermProjRight = return tp2 +doTypeProj (asTupleType -> Just tps) (TermProjTuple i) + | i < length tps = pure (tps !! i) doTypeProj (asRecordType -> Just tp_map) (TermProjRecord fld) | Just tp <- Map.lookup fld tp_map = return tp diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index efb196f0bc..406a69d4fe 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -221,11 +221,9 @@ mrProvable bool_tm = (closedOpenTerm a) ec <- instUVar nm ec_tp liftSC4 genBVVecTerm n len a ec - -- For pairs, recurse on both sides and combine the result as a pair - (asPairType -> Just (tp1, tp2)) -> do - e1 <- instUVar nm tp1 - e2 <- instUVar nm tp2 - liftSC2 scPairValue e1 e2 + -- For tuples, recurse on all components and combine the result as a tuple + (asTupleType -> Just tps) -> + liftSC1 scTuple =<< traverse (instUVar nm) tps -- Otherwise, create a global variable with the given name and type tp' -> liftSC2 scFreshEC nm tp' >>= liftSC1 scExtCns @@ -268,6 +266,12 @@ andTermInCtx (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = t2' <- liftTermLike (length ctx2) (length ctx1) t2 TermInCtx (ctx1++ctx2) <$> liftSC2 scAnd t1' t2' +-- | Conjoin a list of 'TermInCtx's, assuming they all have Boolean type. +allTermInCtx :: [TermInCtx] -> MRM TermInCtx +allTermInCtx [] = TermInCtx [] <$> liftSC1 scBool True +allTermInCtx [x] = pure x +allTermInCtx (x : xs) = andTermInCtx x =<< allTermInCtx xs + -- | Extend the context of a 'TermInCtx' with additional universal variables -- bound "outside" the 'TermInCtx' extTermInCtx :: [(LocalName,Term)] -> TermInCtx -> TermInCtx @@ -358,15 +362,13 @@ mrProveEqH _ (asBoolType -> Just _) t1 t2 = mrProveEqH _ (asIntegerType -> Just _) t1 t2 = mrProveEqSimple (liftSC2 scIntEq) t1 t2 --- For pair types, prove both the left and right projections are equal -mrProveEqH var_map (asPairType -> Just (tpL, tpR)) t1 t2 = - do t1L <- liftSC1 scPairLeft t1 - t2L <- liftSC1 scPairLeft t2 - t1R <- liftSC1 scPairRight t1 - t2R <- liftSC1 scPairRight t2 - condL <- mrProveEqH var_map tpL t1L t2L - condR <- mrProveEqH var_map tpR t1R t2R - andTermInCtx condL condR +-- For tuple types, prove all of the projections are equal +mrProveEqH var_map (asTupleType -> Just tps) t1 t2 = + do let idxs = [0 .. length tps - 1] + ts1 <- liftSC1 (\sc t -> traverse (scTupleSelector sc t) idxs) t1 + ts2 <- liftSC1 (\sc t -> traverse (scTupleSelector sc t) idxs) t2 + conds <- sequence $ zipWith3 (mrProveEqH var_map) tps ts1 ts2 + allTermInCtx conds -- For non-bitvector vector types, prove all projections are equal by -- quantifying over a universal index variable and proving equality at that diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index f6b711a1e7..278d9e12f8 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -145,13 +145,6 @@ asLRTList (asCtor -> Just (primName -> "Prelude.LRT_Cons", [lrt, lrts])) = (tp_norm_closed :) <$> asLRTList lrts asLRTList t = throwMRFailure (MalformedLetRecTypes t) --- | Match a right-nested series of pairs. This is similar to 'asTupleValue' --- except that it expects a unit value to always be at the end. -asNestedPairs :: Recognizer Term [Term] -asNestedPairs (asPairValue -> Just (x, asNestedPairs -> Just xs)) = Just (x:xs) -asNestedPairs (asFTermF -> Just UnitValue) = Just [] -asNestedPairs _ = Nothing - -- | Bind fresh function variables for a @letRecM@ or @multiFixM@ with the given -- @LetRecTypes@ and definitions for the function bodies as a lambda mrFreshLetRecVars :: Term -> Term -> MRM [Term] @@ -169,7 +162,7 @@ mrFreshLetRecVars lrts defs_f = -- the definitions of the individual letrec-bound functions in terms of the -- new function constants defs_tm <- mrApplyAll defs_f fun_tms - defs <- case asNestedPairs defs_tm of + defs <- case asTupleValue defs_tm of Just defs -> return defs Nothing -> throwMRFailure (MalformedDefsFun defs_f) diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index cd7a10c86d..8b27a85a72 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -64,16 +64,15 @@ showMRVar :: MRVar -> String showMRVar = show . ppName . ecName . unMRVar -- | A tuple or record projection of a 'Term' -data TermProj = TermProjLeft | TermProjRight | TermProjRecord FieldName +data TermProj = TermProjTuple Int | TermProjRecord FieldName deriving (Eq, Ord, Show) -- | Recognize a 'Term' as 0 or more projections asProjAll :: Term -> (Term, [TermProj]) asProjAll (asRecordSelector -> Just ((asProjAll -> (t, projs)), fld)) = (t, TermProjRecord fld:projs) -asProjAll (asPairSelector -> Just ((asProjAll -> (t, projs)), isRight)) - | isRight = (t, TermProjRight:projs) - | not isRight = (t, TermProjLeft:projs) +asProjAll (asTupleSelector -> Just ((asProjAll -> (t, projs)), i)) = + (t, TermProjTuple i : projs) asProjAll t = (t, []) -- | Names of functions to be used in computations, which are either names bound @@ -100,10 +99,8 @@ funNameTerm :: FunName -> Term funNameTerm (LetRecName var) = Unshared $ FTermF $ ExtCns $ unMRVar var funNameTerm (EVarFunName var) = Unshared $ FTermF $ ExtCns $ unMRVar var funNameTerm (GlobalName gdef []) = globalDefTerm gdef -funNameTerm (GlobalName gdef (TermProjLeft:projs)) = - Unshared $ FTermF $ PairLeft $ funNameTerm (GlobalName gdef projs) -funNameTerm (GlobalName gdef (TermProjRight:projs)) = - Unshared $ FTermF $ PairRight $ funNameTerm (GlobalName gdef projs) +funNameTerm (GlobalName gdef (TermProjTuple i : projs)) = + Unshared $ FTermF $ TupleSelector (funNameTerm (GlobalName gdef projs)) i funNameTerm (GlobalName gdef (TermProjRecord fname:projs)) = Unshared $ FTermF $ RecordProj (funNameTerm (GlobalName gdef projs)) fname @@ -381,8 +378,7 @@ instance PrettyInCtx [Term] where prettyInCtx xs = list <$> mapM prettyInCtx xs instance PrettyInCtx TermProj where - prettyInCtx TermProjLeft = return (pretty '.' <> "1") - prettyInCtx TermProjRight = return (pretty '.' <> "2") + prettyInCtx (TermProjTuple i) = return (pretty '.' <> pretty i) prettyInCtx (TermProjRecord fld) = return (pretty '.' <> pretty fld) instance PrettyInCtx FunName where diff --git a/src/SAWScript/SBVParser.hs b/src/SAWScript/SBVParser.hs index 8f16083c79..dfcf4d49f4 100644 --- a/src/SAWScript/SBVParser.hs +++ b/src/SAWScript/SBVParser.hs @@ -263,7 +263,7 @@ scTyp sc (TRecord fields) = splitInputs :: SharedContext -> Typ -> Term -> IO [Term] splitInputs _sc TBool x = return [x] splitInputs sc (TTuple ts) x = - do xs <- mapM (\i -> scTupleSelector sc x i (length ts)) [1 .. length ts] + do xs <- mapM (scTupleSelector sc x) [0 .. length ts - 1] yss <- sequence (zipWith (splitInputs sc) ts xs) return (concat yss) splitInputs _ (TVec _ TBool) x = return [x]