Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify generated code for variable declarations with initialization expressions #572

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/main/scala/viper/gobra/ast/internal/PrettyPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ class DefaultPrettyPrinter extends PrettyPrinter with kiama.output.PrettyPrinter
showVar(resTarget) <> "," <+> showVar(successTarget) <+> "=" <+> showExpr(expr) <> "." <> parens(showType(typ))

case Initialization(left) => "init" <+> showVar(left)

case Allocation(left) => "allocate" <+> showVar(left)

case SingleAss(left, right) => showAssignee(left) <+> "=" <+> showExpr(right)

case FunctionCall(targets, func, args) =>
Expand Down Expand Up @@ -738,6 +741,9 @@ class ShortPrettyPrinter extends DefaultPrettyPrinter {
showVar(resTarget) <> "," <+> showVar(successTarget) <+> "=" <+> showExpr(expr)

case Initialization(left) => "init" <+> showVar(left)

case Allocation(left) => "allocate" <+> showVar(left)

case SingleAss(left, right) => showAssignee(left) <+> "=" <+> showExpr(right)

case _: FunctionCall | _: MethodCall | _: ClosureCall => super.showStmt(s)
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/viper/gobra/ast/internal/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ case class While(cond: Expr, invs: Vector[Assertion], terminationMeasure: Option

case class Initialization(left: AssignableVar)(val info: Source.Parser.Info) extends Stmt

case class Allocation(left: AssignableVar)(val info: Source.Parser.Info) extends Stmt

sealed trait Assignment extends Stmt

case class SingleAss(left: Assignee, right: Expr)(val info: Source.Parser.Info) extends Assignment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package viper.gobra.ast.internal.transform

import viper.gobra.ast.internal._
import viper.gobra.reporting.Source
import viper.gobra.reporting.Source.OverflowCheckAnnotation
import viper.gobra.reporting.Source.{AnnotatedOrigin, InhaleInsteadOfAssignmentAnnotation, OverflowCheckAnnotation}
import viper.gobra.reporting.Source.Parser.Single
import viper.gobra.util.TypeBounds.BoundedIntegerKind
import viper.gobra.util.Violation.violation
Expand Down Expand Up @@ -134,10 +134,20 @@ object OverflowChecksTransform extends InternalTransform {
case m@SafeMapLookup(_, _, IndexedExp(base, idx, _)) =>
Seqn(genOverflowChecksExprs(Vector(base, idx)) :+ m)(m.info)

case i@Inhale(a) =>
i.info.origin match {
case Some(o: AnnotatedOrigin) if o.annotation == InhaleInsteadOfAssignmentAnnotation =>
a match {
case ExprAssertion(e) => Seqn(genOverflowChecksExprs(Vector(e)) :+ i)(i.info)
case a => violation(s"unexpected assertion $a.")
}
case _ => i
}

// explicitly matches remaining statements to detect non-exhaustive pattern matching if a new statement is added
case x@(_: Inhale | _: Exhale | _: Assert | _: Assume
case x@(_: Exhale | _: Assert | _: Assume
| _: Return | _: Fold | _: Unfold | _: PredExprFold | _: PredExprUnfold | _: Outline
| _: SafeTypeAssertion | _: SafeReceive | _: Label | _: Initialization ) => x
| _: SafeTypeAssertion | _: SafeReceive | _: Label | _: Initialization | _: Allocation) => x

case _ => violation("Unexpected case reached.")
}
Expand Down
105 changes: 59 additions & 46 deletions src/main/scala/viper/gobra/frontend/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import viper.gobra.frontend.info.base.Type._
import viper.gobra.frontend.info.base.{BuiltInMemberTag, Type, SymbolTable => st}
import viper.gobra.frontend.info.implementation.resolution.MemberPath
import viper.gobra.frontend.info.{ExternalTypeInfo, TypeInfo}
import viper.gobra.reporting.Source.{AutoImplProofAnnotation, ImportPreNotEstablished, MainPreNotEstablished}
import viper.gobra.reporting.Source.{AutoImplProofAnnotation, ImportPreNotEstablished, InhaleInsteadOfAssignmentAnnotation, MainPreNotEstablished}
import viper.gobra.reporting.{DesugaredMessage, Source}
import viper.gobra.theory.Addressability
import viper.gobra.translator.Names
Expand Down Expand Up @@ -958,17 +958,29 @@ object Desugar {

val src: Meta = meta(stmt, info)

/**
* Desugars the left side of an assignment, short variable declaration, and normal variable declaration.
* If the left side is an identifier definition, a variable declaration and initialization is written, as well.
*/
def leftOfAssignmentD(idn: PIdnNode, info: TypeInfo)(t: in.Type): Writer[in.Assignee] = {
val isDef = idn match {
trait InitMode

object InitMode {
// TODO(suggestion): we could remove `leftOfAssignmentNoInit` altogether by introducing the following mode
// case object SkipAllocAndInit extends InitMode
case object OnlyAlloc extends InitMode
case object AllocAndInit extends InitMode
}

def isDef(idn: PIdnNode, info: TypeInfo): Boolean = {
idn match {
case _: PIdnDef => true
case unk: PIdnUnk if info.isDef(unk) => true
case _ => false
}
}

/**
* Desugars the left side of an assignment, short variable declaration, and normal variable declaration.
* If the left side is an identifier definition, a variable declaration and allocation/initialization are also
* written, depending on the value of `initMode`.
*/
def leftOfAssignmentD(idn: PIdnNode, info: TypeInfo, initMode: InitMode)(t: in.Type): Writer[in.Assignee] = {
idn match {
case _: PWildcard => freshDeclaredExclusiveVar(t.withAddressability(Addressability.Exclusive), idn, info)(src).map(in.Assignee.Var)

Expand All @@ -977,11 +989,15 @@ object Desugar {
case Left(v) => v
case Right(v) => violation(s"Expected an assignable variable, but got $v instead")
}
if (isDef) {
if (isDef(idn, info)) {
val v = x.asInstanceOf[in.LocalVar]
for {
_ <- declare(v)
_ <- write(in.Initialization(v)(src))
allocOrInit = initMode match {
case InitMode.OnlyAlloc => in.Allocation(v)(src)
case InitMode.AllocAndInit => in.Initialization(v)(src)
}
_ <- write(allocOrInit)
} yield in.Assignee(v)
} else unit(in.Assignee(x))
}
Expand Down Expand Up @@ -1343,7 +1359,7 @@ object Desugar {

domain = in.MapKeys(c, underlyingType(exp.typ))(src)

visited <- leftOfAssignmentD(range.enumerated, info)(in.SetT(keyType, Addressability.exclusiveVariable))
visited <- leftOfAssignmentD(range.enumerated, info, InitMode.AllocAndInit)(in.SetT(keyType, Addressability.exclusiveVariable))

perm <- freshDeclaredExclusiveVar(in.PermissionT(Addressability.exclusiveVariable), n, info)(src)

Expand Down Expand Up @@ -1425,7 +1441,7 @@ object Desugar {

domain = in.MapKeys(c, underlyingType(exp.typ))(src)

visited <- leftOfAssignmentD(range.enumerated, info)(in.SetT(keyType, Addressability.exclusiveVariable))
visited <- leftOfAssignmentD(range.enumerated, info, InitMode.AllocAndInit)(in.SetT(keyType, Addressability.exclusiveVariable))

perm <- freshDeclaredExclusiveVar(in.PermissionT(Addressability.exclusiveVariable), n, info)(src)

Expand Down Expand Up @@ -1635,18 +1651,16 @@ object Desugar {
seqn(sequence((left zip right).map{ case (l, r) =>
for {
re <- goE(r)
le <- leftOfAssignmentD(l, info)(re.typ)
} yield singleAss(le, re)(src)
le <- leftOfAssignmentD(l, info, InitMode.OnlyAlloc)(re.typ)
} yield singleAss(le, re, isInitExpr = isDef(l, info))(src)
}).map(in.Seqn(_)(src)))
} else if (right.size == 1) {
seqn(for {
re <- goE(right.head)
les <- sequence(left.map{ l =>
for {
dL <- leftOfAssignmentD(l, info)(typeD(info.typ(l), Addressability.exclusiveVariable)(src))
} yield dL
leftOfAssignmentD(l, info, InitMode.OnlyAlloc)(typeD(info.typ(l), Addressability.exclusiveVariable)(src))
})
} yield multiassD(les, re, stmt)(src))
} yield multiassD(les, re, stmt, isInitExpr = left.forall(l => isDef(l, info)))(src))
} else { violation("invalid assignment") }

case PVarDecl(typOpt, right, left, _) =>
Expand All @@ -1656,32 +1670,19 @@ object Desugar {
for {
re <- goE(r)
typ = typOpt.map(x => typeD(info.symbType(x), Addressability.exclusiveVariable)(src)).getOrElse(re.typ)
dL <- leftOfAssignmentD(l, info)(typ)
dL <- leftOfAssignmentD(l, info, InitMode.OnlyAlloc)(typ)
le <- unit(dL)
} yield singleAss(le, re)(src)
} yield singleAss(le, re, isInitExpr = isDef(l, info))(src)
}).map(in.Seqn(_)(src)))
} else if (right.size == 1) {
seqn(for {
re <- goE(right.head)
les <- sequence(left.map{l =>
for {
dL <- leftOfAssignmentD(l, info)(re.typ)
} yield dL
})
} yield multiassD(les, re, stmt)(src))
les <- sequence(left.map{leftOfAssignmentD(_, info, InitMode.OnlyAlloc)(re.typ)})
} yield multiassD(les, re, stmt, isInitExpr = left.forall(l => isDef(l, info)))(src))
} else if (right.isEmpty && typOpt.nonEmpty) {
val typ = typeD(info.symbType(typOpt.get), Addressability.exclusiveVariable)(src)
val lelems = sequence(left.map{ l =>
for {
dL <- leftOfAssignmentD(l, info)(typ)
} yield dL
})
val relems = left.map{ l => in.DfltVal(typeD(info.symbType(typOpt.get), Addressability.defaultValue)(meta(l, info)))(meta(l, info)) }
seqn(lelems.map{ lelemsV =>
in.Seqn((lelemsV zip relems).map{
case (l, r) => singleAss(l, r)(src)
})(src)
})
val lelems = sequence(left.map{ leftOfAssignmentD(_, info, InitMode.AllocAndInit)(typ) })
for {_ <- lelems} yield in.Seqn(Vector())(src)
} else { violation("invalid declaration") }

case PReturn(exps) =>
Expand Down Expand Up @@ -1916,11 +1917,11 @@ object Desugar {
} yield (acceptCond, stmt)
}

def multiassD(lefts: Vector[in.Assignee], right: in.Expr, astCtx: PNode)(src: Source.Parser.Info): in.Stmt = {
def multiassD(lefts: Vector[in.Assignee], right: in.Expr, astCtx: PNode, isInitExpr: Boolean = false)(src: Source.Parser.Info): in.Stmt = {

right match {
case in.Tuple(args) if args.size == lefts.size =>
in.Seqn(lefts.zip(args) map { case (l, r) => singleAss(l, r)(src)})(src)
in.Seqn(lefts.zip(args) map { case (l, r) => singleAss(l, r, isInitExpr)(src)})(src)

case n: in.TypeAssertion if lefts.size == 2 =>
val resTarget = freshExclusiveVar(lefts(0).op.typ.withAddressability(Addressability.exclusiveVariable), astCtx, info)(src)
Expand All @@ -1929,8 +1930,8 @@ object Desugar {
Vector(resTarget, successTarget),
Vector( // declare for the fresh variables is not necessary because they are put into a block
in.SafeTypeAssertion(resTarget, successTarget, n.exp, n.arg)(n.info),
singleAss(lefts(0), resTarget)(src),
singleAss(lefts(1), successTarget)(src)
singleAss(lefts(0), resTarget, isInitExpr)(src),
singleAss(lefts(1), successTarget, isInitExpr)(src)
)
)(src)

Expand All @@ -1945,8 +1946,8 @@ object Desugar {
Vector(resTarget, successTarget),
Vector( // declare for the fresh variables is not necessary because they are put into a block
in.SafeReceive(resTarget, successTarget, n.channel, recvChannelProxy, recvGivenPermProxy, recvGotPermProxy, closedProxy)(n.info),
singleAss(lefts(0), resTarget)(src),
singleAss(lefts(1), successTarget)(src)
singleAss(lefts(0), resTarget, isInitExpr)(src),
singleAss(lefts(1), successTarget, isInitExpr)(src)
)
)(src)

Expand All @@ -1957,8 +1958,8 @@ object Desugar {
Vector(resTarget, successTarget),
Vector(
in.SafeMapLookup(resTarget, successTarget, l)(l.info),
singleAss(lefts(0), resTarget)(src),
singleAss(lefts(1), successTarget)(src)
singleAss(lefts(0), resTarget, isInitExpr)(src),
singleAss(lefts(1), successTarget, isInitExpr)(src)
)
)(src)

Expand Down Expand Up @@ -2305,8 +2306,20 @@ object Desugar {
}
}

def singleAss(left: in.Assignee, right: in.Expr)(info: Source.Parser.Info): in.SingleAss = {
in.SingleAss(left, implicitConversion(right.typ, left.op.typ, right))(info)
def singleAss(left: in.Assignee, right: in.Expr, isInitExpr: Boolean = false)(info: Source.Parser.Info): in.Stmt = {
if (isInitExpr) {
val newInfo = info match {
case s: Source.Parser.Single => s.createAnnotatedInfo(InhaleInsteadOfAssignmentAnnotation)
case i => violation(s"l.op.info ($i) is expected to be a Single")
}
// Optimization: if we know that right is the expr passed to the declaration of left, then there is no need to
// assign to left (which implies exhaling and inhaling the footprint). Instead, we can just assume directly the
// equality between left and right.
val eq = in.ExprAssertion(in.GhostEqCmp(left.op, implicitConversion(right.typ, left.op.typ, right))(newInfo))(newInfo)
in.Inhale(eq)(newInfo)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is ultra urgent, then you could introduce another internal node, e.g. "InitialAssignment" and then replace the inhales with that. That would still be an ugly solution, but at it maintaines the separation of purposes somewhat.

} else {
in.SingleAss(left, implicitConversion(right.typ, left.op.typ, right))(info)
}
}

def arguments(symb: st.WithArguments, args: Vector[in.Expr]): Vector[in.Expr] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

package viper.gobra.reporting

import viper.gobra.reporting.Source.{AutoImplProofAnnotation, CertainSource, CertainSynthesized, ImportPreNotEstablished, MainPreNotEstablished, OverflowCheckAnnotation, ReceiverNotNilCheckAnnotation, InsufficientPermissionToRangeExpressionAnnotation, LoopInvariantNotEstablishedAnnotation}

import viper.gobra.reporting.Source.{AutoImplProofAnnotation, CertainSource, CertainSynthesized, ImportPreNotEstablished, InhaleInsteadOfAssignmentAnnotation, InsufficientPermissionToRangeExpressionAnnotation, LoopInvariantNotEstablishedAnnotation, MainPreNotEstablished, OverflowCheckAnnotation, ReceiverNotNilCheckAnnotation}
import viper.gobra.reporting.Source.Verifier./
import viper.silver
import viper.silver.ast.Not
Expand Down Expand Up @@ -185,6 +184,11 @@ class DefaultErrorBackTranslator(
case _ / LoopInvariantNotEstablishedAnnotation =>
x.reasons.foldLeft(LoopInvariantEstablishmentError(x.info): VerificationError) { case (err, reason) => err dueTo reason }

case _ / InhaleInsteadOfAssignmentAnnotation =>
x.reasons.foldLeft(AssignmentError(x.info): VerificationError){
case (err, reason) => err dueTo reason
}

case _ => x
}

Expand Down
1 change: 1 addition & 0 deletions src/main/scala/viper/gobra/reporting/Source.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ object Source {
case object ImportPreNotEstablished extends Annotation
case object MainPreNotEstablished extends Annotation
case object LoopInvariantNotEstablishedAnnotation extends Annotation
case object InhaleInsteadOfAssignmentAnnotation extends Annotation
case class NoPermissionToRangeExpressionAnnotation() extends Annotation
case class InsufficientPermissionToRangeExpressionAnnotation() extends Annotation
case class AutoImplProofAnnotation(subT: String, superT: String) extends Annotation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,13 @@ trait TypeEncoding extends Generator {
* Initialize[loc: T@] -> inhale Footprint[loc]; assume [loc == dflt(T°)] && [&loc != nil(*T)]
*/
def initialization(ctx: Context): in.Location ==> CodeWriter[vpr.Stmt] = {
case loc :: t / Exclusive if typ(ctx).isDefinedAt(t) =>
case loc :: t if typ(ctx).isDefinedAt(t) =>
val (pos, info, errT) = loc.vprMeta
for {
alloc <- ctx.allocation(loc)
_ <- write(alloc)
eq <- ctx.equal(loc, in.DfltVal(t.withAddressability(Exclusive))(loc.info))(loc)
} yield vpr.Inhale(eq)(pos, info, errT): vpr.Stmt

case loc :: t / Shared if typ(ctx).isDefinedAt(t) =>
val (pos, info, errT) = loc.vprMeta
for {
footprint <- addressFootprint(ctx)(loc, in.FullPerm(loc.info))
eq1 <- ctx.equal(loc, in.DfltVal(t.withAddressability(Exclusive))(loc.info))(loc)
eq2 <- ctx.equal(in.Ref(loc)(loc.info), in.NilLit(in.PointerT(t, Exclusive))(loc.info))(loc)
} yield vpr.Inhale(vpr.And(footprint, vpr.And(eq1, vpr.Not(eq2)(pos, info, errT))(pos, info, errT))(pos, info, errT))(pos, info, errT)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MemoryEncoding extends Encoding {
}

override def statement(ctx: Context): in.Stmt ==> CodeWriter[vpr.Stmt] = {
case in.Allocation(left) => ctx.allocation(left)
case in.Initialization(left) => ctx.initialization(left)
case ass: in.SingleAss => ctx.assignment(ass.left, ass.right)(ass)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,4 @@ func isAstNode(n node) (res bool) {
} else {
return false
}
}
}