Skip to content

Commit

Permalink
Merge branch 'master' into meilers_silicon_wand_issue_test
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoeilers authored Oct 16, 2023
2 parents 769b09f + f80cbfa commit c49b10a
Show file tree
Hide file tree
Showing 10 changed files with 525 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/main/scala/viper/silver/frontend/SilFrontEndConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ abstract class SilFrontendConfig(args: Seq[String], private var projectName: Str
case i => throw new IllegalArgumentException(s"Unsupported counterexample model provided. Expected 'native', 'variables' or 'mapped' but got $i")
}))

val terminationPlugin = opt[Boolean]("disableTerminationPlugin",
val disableTerminationPlugin = opt[Boolean]("disableTerminationPlugin",
descr = "Disable the termination plugin, which adds termination checks to functions, " +
"methods and loops.",
default = Some(false),
noshort = true,
hidden = true
)

val adtPlugin = opt[Boolean]("disableAdtPlugin",
val disableAdtPlugin = opt[Boolean]("disableAdtPlugin",
descr = "Disable the ADT plugin, which adds support for ADTs as a built-in type.",
default = Some(false),
noshort = true,
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/viper/silver/plugin/standard/adt/AdtPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class AdtPlugin(@unused reporter: viper.silver.reporter.Reporter,
*/
private var derivesImported: Boolean = false

private def isTerminationPluginActive: Boolean = {
config != null && !config.disableTerminationPlugin.toOption.getOrElse(false) &&
(!config.disableDefaultPlugins.toOption.getOrElse(false) ||
config.plugin.toOption.getOrElse("").split(":").contains("viper.silver.plugin.standard.termination.TerminationPlugin"))
}

def adtDerivingFunc[$: P]: P[PIdnUse] = FP(StringIn("contains").!).map { case (pos, id) => PIdnUse(id)(pos) }

override def beforeParse(input: String, isImported: Boolean): String = {
Expand All @@ -51,7 +57,7 @@ class AdtPlugin(@unused reporter: viper.silver.reporter.Reporter,
input
}

private def deactivated: Boolean = config != null && config.adtPlugin.toOption.getOrElse(false)
private def deactivated: Boolean = config != null && config.disableAdtPlugin.toOption.getOrElse(false)

private def setDerivesImported(input: String): Unit = "import[\\s]+<adt\\/derives\\.vpr>".r.findFirstIn(input) match {
case Some(_) => derivesImported = true
Expand Down Expand Up @@ -131,7 +137,7 @@ class AdtPlugin(@unused reporter: viper.silver.reporter.Reporter,
if (deactivated) {
return input
}
new AdtEncoder(input).encode()
new AdtEncoder(input).encode(isTerminationPluginActive)
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,19 @@ class AdtEncoder(val program: Program) extends AdtNameManager {
*
* @return The encoded program.
*/
def encode(): Program = {
def encode(isTerminationPluginActive: Boolean): Program = {
def generateWellFoundedness(a: Adt) = {
isTerminationPluginActive &&
program.domainsByName.contains(getWellFoundedOrderDeclarationDomainName) &&
!program.domainsByName.contains(getWellFoundedDomainName(a.name))
}

// In a first step encode all adt top level declarations and constructor calls
var newProgram: Program = StrategyBuilder.Slim[Node]({
case p@Program(domains, fields, functions, predicates, methods, extensions) =>
val remainingExtensions = extensions filter { case _: Adt => false; case _ => true }
val encodedAdtsAsDomains: Seq[Domain] = extensions collect { case a: Adt => encodeAdtAsDomain(a) }
val tmp = extensions collect { case a: Adt => encodeAdtAsDomain(a, generateWellFoundedness(a)) }
val encodedAdtsAsDomains: Seq[Domain] = tmp.flatten
Program(domains ++ encodedAdtsAsDomains, fields, functions, predicates, methods, remainingExtensions)(p.pos, p.info, p.errT)
case aca: AdtConstructorApp => encodeAdtConstructorApp(aca)
case ada: AdtDestructorApp => encodeAdtDestructorApp(ada)
Expand Down Expand Up @@ -78,7 +84,7 @@ class AdtEncoder(val program: Program) extends AdtNameManager {
* @param adt The ADT to encode
* @return An the encoded ADT as a domain
*/
private def encodeAdtAsDomain(adt: Adt): Domain = {
private def encodeAdtAsDomain(adt: Adt, generateWellFoundedness: Boolean): Seq[Domain] = {
adt match {
case Adt(name, constructors, typVars, derivingInfo) =>
val domain = Domain(name, null, null, typVars)(adt.pos, adt.info, adt.errT)
Expand All @@ -88,7 +94,15 @@ class AdtEncoder(val program: Program) extends AdtNameManager {
(constructors map generateTagAxiom(domain)) ++ Seq(generateExclusivityAxiom(domain)(constructors))
val derivingAxioms = if (derivingInfo.contains(getContainsFunctionName))
constructors filter (_.formalArgs.nonEmpty) map generateContainsAxiom(domain, derivingInfo(getContainsFunctionName)._2) else Seq.empty
domain.copy(functions = functions, axioms = axioms ++ derivingAxioms)(adt.pos, adt.info, adt.errT)
val newAdtDomain = domain.copy(functions = functions, axioms = axioms ++ derivingAxioms)(adt.pos, adt.info, adt.errT)

if (generateWellFoundedness) {
val decreasesAxioms = (constructors map generateDecreasesAxiom(domain)) :+ generateDecreasesTransitivityAxiom(domain) :+ generateBoundedAxiom(domain)
val wellFoundedDomain = Domain(getWellFoundedDomainName(domain.name), Seq(), decreasesAxioms, domain.typVars)(adt.pos, adt.info, adt.errT)
Seq(newAdtDomain, wellFoundedDomain)
}else {
Seq(newAdtDomain)
}
}
}

Expand Down Expand Up @@ -427,6 +441,171 @@ class AdtEncoder(val program: Program) extends AdtNameManager {
AnonymousDomainAxiom(forall)(ac.pos, ac.info, ac.adtName, ac.errT)
}

/**
* Generates an axiom that expresses transitivity of the decreases relation for the current type ADT:
* forall v1: ADT, v2: ADT, v3: ADT :: { decreases(v1, v2), decreases(v2, v3) }
* decreases(v1, v2) && decreases(v2, v3) ==> decreases(v1, v3)
*/
private def generateDecreasesTransitivityAxiom(domain: Domain): AnonymousDomainAxiom = {
val dt = DomainType(domain, defaultTypeVarsFromDomain(domain))
val v1 = LocalVarDecl("v1", dt)()
val v2 = LocalVarDecl("v2", dt)()
val v3 = LocalVarDecl("v3", dt)()
val decreases12 = DomainFuncApp(
getDecreasesFunctionName,
Seq(v1.localVar, v2.localVar),
Map(TypeVar("T") -> dt)
)(domain.pos, domain.info, Bool, getWellFoundedOrderDeclarationDomainName, domain.errT)
val decreases23 = DomainFuncApp(
getDecreasesFunctionName,
Seq(v2.localVar, v3.localVar),
Map(TypeVar("T") -> dt)
)(domain.pos, domain.info, Bool, getWellFoundedOrderDeclarationDomainName, domain.errT)
val decreases13 = DomainFuncApp(
getDecreasesFunctionName,
Seq(v1.localVar, v3.localVar),
Map(TypeVar("T") -> dt)
)(domain.pos, domain.info, Bool, getWellFoundedOrderDeclarationDomainName, domain.errT)
val trigger = Trigger(Seq(decreases12, decreases23))(domain.pos, domain.info, domain.errT)
val body = Implies(And(decreases12, decreases23)(domain.pos, domain.info, domain.errT), decreases13)(domain.pos, domain.info, domain.errT)
val forall = Forall(Seq(v1, v2, v3), Seq(trigger), body)()
AnonymousDomainAxiom(forall)(domain.pos, domain.info, getWellFoundedDomainName(domain.name), domain.errT)
}

/**
* Generates an axiom that expresses that all values of the current ADT are bounded:
* forall x: ADT :: { bounded(x) } bounded(x)
* This is justified by the fact that Viper ADTs are recursive types and are always finite.
*/
private def generateBoundedAxiom(domain: Domain): AnonymousDomainAxiom = {
val domainType = DomainType(domain, defaultTypeVarsFromDomain(domain))
val param = LocalVarDecl("x", domainType)()
val boundedApp = DomainFuncApp(
getBoundedFunctionName,
Seq(param.localVar),
Map(TypeVar("T") -> domainType)
)(domain.pos, domain.info, Bool, getWellFoundedOrderDeclarationDomainName, domain.errT)
val trigger = Trigger(Seq(boundedApp))(domain.pos, domain.info, domain.errT)
val forall = Forall(Seq(param), Seq(trigger), boundedApp)(domain.pos, domain.info, domain.errT)
AnonymousDomainAxiom(forall)(domain.pos, domain.info, getWellFoundedDomainName(domain.name), domain.errT)
}

/**
* Generates an axiom for the given constructor that expresses that all values of the current ADT contained
* inside an ADT value constructed using said constructor are smaller than the ADT value itself.
* E.g., for List { Nil() Cons(i: Int, l: List) }:
* forall i: Int, l: List :: { Cons(i, l) } decreases(l, Cons(i, l))
* Also takes into account values that may be contained through constructors of other ADT types (in cases of mutually
* recursive ADT definitions).
*/
private def generateDecreasesAxiom(domain: Domain)(ac: AdtConstructor): AnonymousDomainAxiom = {
assert(domain.name == ac.adtName, "AdtEncoder: An error in the ADT encoding occurred.")

val localVars = ac.formalArgs.map { lv =>
lv.typ match {
case a: AdtType => localVarTFromType(encodeAdtTypeAsDomainType(a), Some(lv.name))(ac.pos, ac.info, ac.errT)
case d => localVarTFromType(d, Some(lv.name))(ac.pos, ac.info, ac.errT)
}
}

val decreases = if (ac.formalArgs.isEmpty) {
Nil
} else {
val localVarDecl = ac.formalArgs.collect { case l: LocalVarDecl => l }

assert(localVarDecl.size == localVars.size, "AdtEncoder: An error in the ADT encoding occurred.")

/**
* Given a variable currentVar that represents an argument of the current ADT constructor, if the variable's
* type is an ADT type, recursively looks for values of the original ADT type, either in the variable itself
* or in all constructors of its type.
* E.g., if ac is a constructor for ADT1, then:
* - if the type of currentVar is ADT1, then we have already found a value of the original ADT type
* - if the type of the variable is ADT2, and ADT2 has a constructor C(T1, ADT1), then we have found a value
* of the original type inside this constructor.
* The method returns a sequence of tuples, where each tuple contains
* - a list of all variables referred to in the second argument
* - an expression containing a value of type ADT1 (either just a variable of the type, or an ADT constructor
* applied to some arguments, one of which is either itself a variable of type ADT1 or another ADT constructor
* that has an argument that is or contains a value of said type).
* - the variable of type ADT1 contained in the second term.
* So, in the first scenario above, we return Seq((Seq(currentVar), currentVar, currentVar)).
* In the second scenario, we return Seq((Seq(t: T, a: ADT1), C(t, a), a))
*/
def getNestedADTVals(visitedADTTypes: Set[AdtType], currentVar: LocalVarDecl): Seq[(Seq[LocalVarDecl], Exp, Exp)] = {
currentVar.typ match {
case at: AdtType if at == ac.typ =>
// case 1: The variable directly has the type that we are looking for.
val newName = currentVar.name + "_" + visitedADTTypes.size
val renamedCurrentVar = currentVar.copy(name = newName)(currentVar.pos, currentVar.info, currentVar.errT)
Seq((Seq(renamedCurrentVar), renamedCurrentVar.localVar, renamedCurrentVar.localVar))
case at: AdtType if !visitedADTTypes.contains(at) =>
// case 2: The variable has a different ADT type, which may have one or more constructors that contain
// a value of the type we are looking for.
val adt = program.extensions.find {
case a: Adt if a.name == at.adtName => true
case _ => false
}.get.asInstanceOf[Adt]

// Look through all constructors
adt.constructors.flatMap(ac2 => {
val argDecls = ac2.formalArgs.map { case l: LocalVarDecl => l.copy(name = l.name + "_" + visitedADTTypes.size)(l.pos, l.info, l.errT) }

// Recursively check the type of the constructor's arguments
val argVals = ac2.formalArgs.map(fa2 => getNestedADTVals(visitedADTTypes + at, fa2))
argVals.zipWithIndex.flatMap{ case (avs, i) =>
val res: Seq[(Seq[LocalVarDecl], Exp, Exp)] = avs.map(av => {
// Apply the current constructor to the arguments for every occurrence that was found.
val qvars = av._1 ++ (argDecls diff Seq(argDecls(i)))
val cApp = DomainFuncApp(
ac2.name,
argDecls.take(i).map(_.localVar) ++ Seq(av._2) ++ argDecls.drop(i + 1).map(_.localVar),
encodeAdtTypeAsDomainType(ac2.typ).typVarsMap,
)(ac.pos, ac.info, encodeAdtTypeAsDomainType(ac2.typ), ac.adtName, ac.errT)
(qvars, cApp, av._3)
})
res
}
})
case _ =>
// case 3: Different type or an ADT type we already looked at.
Seq()
}
}
var decreasesQuants: List[Exp] = Nil

val nestedADTVals = localVarDecl.map(lvd => getNestedADTVals(Set(), lvd))

// For each found nested value of our type, generate a quantified expression that states states that the contained
// value is less than the original value.
for ((avs, i) <- nestedADTVals.zipWithIndex) {
val otherVars = localVarDecl.take(i) ++ localVarDecl.drop(i + 1)
for (av <- avs) {
val (qvars, value, smallerValue) = av
val allQvars = qvars ++ otherVars
val constructorArgs = localVars.take(i) ++ Seq(value) ++ localVars.drop(i + 1)
val constructorApp = DomainFuncApp(
ac.name,
constructorArgs,
defaultTypeVarsFromDomain(domain)
)(ac.pos, ac.info, encodeAdtTypeAsDomainType(ac.typ), ac.adtName, ac.errT)
val trigger = Trigger(Seq(constructorApp))(ac.pos, ac.info, ac.errT)
val decreasesApp = DomainFuncApp(
getDecreasesFunctionName,
Seq(smallerValue, constructorApp),
Map(TypeVar("T") -> constructorApp.typ)
)(ac.pos, ac.info, Bool, getWellFoundedOrderDeclarationDomainName, ac.errT)
val forall = Forall(allQvars, Seq(trigger), decreasesApp)(ac.pos, ac.info, ac.errT)
decreasesQuants = forall :: decreasesQuants
}
}
decreasesQuants
}

val body = decreases.foldLeft[Exp](TrueLit()())((a, b) => And(a, b)())
AnonymousDomainAxiom(body)(ac.pos, ac.info, getWellFoundedDomainName(domain.name), ac.errT)
}

/**
* This method encodes the transitivity of the contains function. Namely it collects arguments types of
* all contains applications as tuples, computes its transitive closure and finally the corresponding axioms.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ trait AdtNameManager {

def getContainsFunctionName: String = "contains"

def getWellFoundedDomainName(typeName: String): String = typeName + "WellFoundedOrder"
def getWellFoundedOrderDeclarationDomainName : String = "WellFoundedOrder"
def getDecreasesFunctionName: String = "decreasing"
def getBoundedFunctionName: String = "bounded"

def getContainsTransitivityDomain: String = getName("ContainsTransitivityDomain")

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TerminationPlugin(@unused reporter: viper.silver.reporter.Reporter,
fp: FastParser) extends SilverPlugin with ParserPluginTemplate {
import fp.{FP, keyword, exp, ParserExtension}

private def deactivated: Boolean = config != null && config.terminationPlugin.toOption.getOrElse(false)
private def deactivated: Boolean = config != null && config.disableTerminationPlugin.toOption.getOrElse(false)

private var decreasesClauses: Seq[PDecreasesClause] = Seq.empty

Expand Down
105 changes: 105 additions & 0 deletions src/test/resources/adt/termination_1.vpr
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/


domain Val {}

adt List[V] {
Nil()
Cons(value: V, tail: List[V])
}

function len(l: List[Val]): Int
ensures result >= 0
decreases l
{
l.isNil ? 0 : 1 + len(l.tail)
}

function len2(l: List[Val]): Int
ensures result >= 0
decreases l
{
l.isNil ? 0 : (l.tail.isNil ? 1 : 2 + len2(l.tail.tail))
}

function lenBad(l: List[Val], v: Val): Int
ensures result >= 0
decreases l
{
//:: ExpectedOutput(termination.failed:tuple.false)
lenBad(Cons(v, Nil()), v)
}

function lenBad2(l: List[Val]): Int
ensures result >= 0
decreases l
{
//:: ExpectedOutput(termination.failed:tuple.false)
1 + lenBad2(l)
}

////////////////////////

adt IntList {
INil()
ICons(ivalue: Int, itail: IntList)
}

function ilen(l: IntList): Int
ensures result >= 0
decreases l
{
l.isINil ? 0 : 1 + ilen(l.itail)
}

function ilen2(l: IntList): Int
ensures result >= 0
decreases l
{
l.isINil ? 0 : (l.itail.isINil ? 1 : 2 + ilen2(l.itail.itail))
}

function ilenBad(l: IntList, v: Int): Int
ensures result >= 0
decreases l
{
//:: ExpectedOutput(termination.failed:tuple.false)
ilenBad(ICons(v, INil()), v)
}

////////////////////////

// non-recursive data type with two type variables
adt Pair[T, V] {
pair(fst: T, snd: V)
}

function stupidFunc(p: Pair[Int, Val]): Val
decreases p
{
//:: ExpectedOutput(termination.failed:tuple.false)
stupidFunc(p)
}

// two type variables
adt DList[V, T] {
DNil()
DCons(dvalue1: V, dvalue2: T, dtail: DList[V, T])
}

function dlen(l: DList[Int, Val]): Int
ensures result >= 0
decreases l
{
l.isDNil ? 0 : 1 + dlen(l.dtail)
}

function dlenBad(l: DList[Int, Val]): Int
ensures result >= 0
decreases l
{
//:: ExpectedOutput(termination.failed:tuple.false)
l.isDNil ? 0 : 1 + dlenBad(l)
}

Loading

0 comments on commit c49b10a

Please sign in to comment.