From 057572b75fa21b12ffe8d044b5b1b2815197c1cd Mon Sep 17 00:00:00 2001 From: AndreyYurko Date: Thu, 18 Apr 2024 15:00:53 +0300 Subject: [PATCH 1/3] Weighted tree path selector impl --- .../concolic/InstructionConcolicChecker.kt | 14 +- .../coverage/ExperimentPathSelector.kt | 997 ++++++++++++++++++ .../concolic/coverage/ExperimentSelector.kt | 4 + .../concolic/coverage/WeightedGraph.kt | 271 +++++ 4 files changed, 1280 insertions(+), 6 deletions(-) create mode 100644 kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt create mode 100644 kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt create mode 100644 kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt index 74da014da..9ce29fc15 100644 --- a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt @@ -8,6 +8,7 @@ import org.vorpal.research.kex.ExecutionContext import org.vorpal.research.kex.asm.analysis.concolic.bfs.BfsPathSelectorManager import org.vorpal.research.kex.asm.analysis.concolic.cgs.ContextGuidedSelectorManager import org.vorpal.research.kex.asm.analysis.concolic.coverage.CoverageGuidedSelectorManager +import org.vorpal.research.kex.asm.analysis.concolic.coverage.ExperimentPathSelectorManager import org.vorpal.research.kex.asm.analysis.util.analyzeOrTimeout import org.vorpal.research.kex.asm.analysis.util.checkAsync import org.vorpal.research.kex.assertions.extractFinalInfo @@ -69,12 +70,13 @@ class InstructionConcolicChecker( ctx: ExecutionContext, targets: Set, strategyName: String, - ): ConcolicPathSelectorManager = when (strategyName) { - "bfs" -> BfsPathSelectorManager(ctx, targets) - "cgs" -> ContextGuidedSelectorManager(ctx, targets) - "coverage" -> CoverageGuidedSelectorManager(ctx, targets) - else -> unreachable { log.error("Unknown type of search strategy $strategyName") } - } + ): ConcolicPathSelectorManager = ExperimentPathSelectorManager(ctx, targets) +// when (strategyName) { +// "bfs" -> BfsPathSelectorManager(ctx, targets) +// "cgs" -> ContextGuidedSelectorManager(ctx, targets) +// "coverage" -> CoverageGuidedSelectorManager(ctx, targets) +// else -> unreachable { log.error("Unknown type of search strategy $strategyName") } +// } @ExperimentalTime @DelicateCoroutinesApi diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt new file mode 100644 index 000000000..f7799044b --- /dev/null +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt @@ -0,0 +1,997 @@ +package org.vorpal.research.kex.asm.analysis.concolic.coverage + +import kotlinx.collections.immutable.* +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.withContext +import kotlinx.coroutines.yield +import org.vorpal.research.kex.ExecutionContext +import org.vorpal.research.kex.asm.analysis.concolic.ConcolicPathSelector +import org.vorpal.research.kex.asm.analysis.concolic.ConcolicPathSelectorManager +import org.vorpal.research.kex.asm.analysis.symbolic.* +import org.vorpal.research.kex.asm.analysis.util.checkAsync +import org.vorpal.research.kex.asm.analysis.util.checkAsyncIncremental +import org.vorpal.research.kex.asm.manager.instantiationManager +import org.vorpal.research.kex.asserter.ExecutionFinalInfo +import org.vorpal.research.kex.asserter.extractExceptionFinalInfo +import org.vorpal.research.kex.asserter.extractSuccessFinalInfo +import org.vorpal.research.kex.compile.CompilationException +import org.vorpal.research.kex.descriptor.Descriptor +import org.vorpal.research.kex.descriptor.DescriptorBuilder +import org.vorpal.research.kex.ktype.KexPointer +import org.vorpal.research.kex.ktype.KexRtManager.rtMapped +import org.vorpal.research.kex.ktype.KexType +import org.vorpal.research.kex.ktype.kexType +import org.vorpal.research.kex.parameters.Parameters +import org.vorpal.research.kex.reanimator.UnsafeGenerator +import org.vorpal.research.kex.reanimator.codegen.klassName +import org.vorpal.research.kex.state.predicate.inverse +import org.vorpal.research.kex.state.predicate.path +import org.vorpal.research.kex.state.predicate.state +import org.vorpal.research.kex.state.term.* +import org.vorpal.research.kex.state.term.TermBuilder.Terms.arg +import org.vorpal.research.kex.state.term.TermBuilder.Terms.generate +import org.vorpal.research.kex.state.term.TermBuilder.Terms.length +import org.vorpal.research.kex.state.term.TermBuilder.Terms.staticRef +import org.vorpal.research.kex.state.term.TermBuilder.Terms.`this` +import org.vorpal.research.kex.state.transformer.isThis +import org.vorpal.research.kex.trace.symbolic.* +import org.vorpal.research.kex.trace.symbolic.protocol.ExecutionCompletedResult +import org.vorpal.research.kex.util.isSubtypeOfCached +import org.vorpal.research.kfg.ir.BasicBlock +import org.vorpal.research.kfg.ir.Method +import org.vorpal.research.kfg.ir.value.Constant +import org.vorpal.research.kfg.ir.value.Value +import org.vorpal.research.kfg.ir.value.ValueFactory +import org.vorpal.research.kfg.ir.value.instruction.* +import org.vorpal.research.kfg.type.Type +import org.vorpal.research.kfg.type.TypeFactory +import org.vorpal.research.kthelper.assert.unreachable +import org.vorpal.research.kthelper.logging.log + +//data class TraverserState( +// val symbolicState: PersistentSymbolicState, +// val valueMap: PersistentMap, +// val stackTrace: PersistentList, +// val typeInfo: PersistentMap, +// val blockPath: PersistentList, +// val nullCheckedTerms: PersistentSet, +// val boundCheckedTerms: PersistentSet>, +// val typeCheckedTerms: PersistentMap +//) { +// fun mkTerm(value: Value): Term = when (value) { +// is Constant -> term { const(value) } +// else -> valueMap.getValue(value) +// } +// +// fun copyTermInfo(from: Term, to: Term): TraverserState = this.copy( +// nullCheckedTerms = when (from) { +// in nullCheckedTerms -> nullCheckedTerms.add(to) +// else -> nullCheckedTerms +// }, +// typeCheckedTerms = when (from) { +// in typeCheckedTerms -> typeCheckedTerms.put(to, typeCheckedTerms[from]!!) +// else -> typeCheckedTerms +// } +// ) +// +// operator fun plus(state: PersistentSymbolicState): TraverserState = this.copy( +// symbolicState = this.symbolicState + state +// ) +// +// operator fun plus(clause: StateClause): TraverserState = this.copy( +// symbolicState = this.symbolicState + clause +// ) +// +// operator fun plus(clause: PathClause): TraverserState = this.copy( +// symbolicState = this.symbolicState + clause +// ) +// +// operator fun plus(basicBlock: BasicBlock): TraverserState = this.copy( +// blockPath = this.blockPath.add(basicBlock) +// ) +//} + +class ExperimentPathSelectorManager ( + override val ctx: ExecutionContext, + override val targets: Set +) : ConcolicPathSelectorManager { + + private val targetInstructions = targets.flatMapTo(mutableSetOf()) { it.body.flatten() } + private val coveredInstructions = mutableSetOf() + + val weightedGraph = WeightedGraph(targets, targetInstructions) + + fun isCovered(): Boolean { + val result = coveredInstructions.containsAll(targetInstructions) || + weightedGraph.targets.sumOf { weightedGraph.getVertex(it.body.entry.instructions.first()).score } == 0 + log.debug("Temp") + return result + } + + fun addCoverage(trace: List) { + coveredInstructions += trace + } + + override fun createPathSelectorFor(target: Method): ConcolicPathSelector = ExperimentPathSelector(this) +} + +class ExperimentPathSelector( + private val manager: ExperimentPathSelectorManager +) : ConcolicPathSelector { + + override val ctx: ExecutionContext + get() = manager.ctx + + override suspend fun isEmpty(): Boolean = manager.isCovered() + + override suspend fun addExecutionTrace( + method: Method, + checkedState: PersistentSymbolicState, + result: ExecutionCompletedResult + ) { + manager.addCoverage(result.trace) + manager.weightedGraph.addTrace(result.trace) + } + + override fun reverse(pathClause: PathClause): PathClause? { + TODO("Not yet implemented") + } + + override suspend fun hasNext(): Boolean = !isEmpty() + + override suspend fun next(): Pair { + val bestMethod = manager.targets.maxBy { manager.weightedGraph.getVertex(it.body.entry.instructions.first()).score } + val root = bestMethod.body.entry.instructions.first() + val path = manager.weightedGraph.getPath(root) + val state = processMethod(bestMethod, path) + return Pair(bestMethod, state.symbolicState) + } + + protected val Type.symbolicType: KexType get() = kexType.rtMapped + protected val org.vorpal.research.kfg.ir.Class.symbolicClass: KexType get() = kexType.rtMapped + + val types: TypeFactory + get() = ctx.types + + val values: ValueFactory + get() = ctx.values + + protected open suspend fun processMethod(method: Method, path: List): TraverserState { + val thisValue = values.getThis(method.klass) + val initialArguments = buildMap { + val values = this@ExperimentPathSelector.values + if (!method.isStatic) { + this[thisValue] = `this`(method.klass.symbolicClass) + } + for ((index, type) in method.argTypes.withIndex()) { + this[values.getArgument(index, method, type)] = arg(type.symbolicType, index) + } + } + + val initialState = when { + !method.isStatic -> { + val thisTerm = initialArguments[thisValue]!! + val thisType = method.klass.symbolicClass.getKfgType(types) + TraverserState( + symbolicState = persistentSymbolicState( + path = persistentPathConditionOf( + PathClause( + PathClauseType.NULL_CHECK, + method.body.entry.first(), + path { (thisTerm eq null) equality false } + ) + ) + ), + valueMap = initialArguments.toPersistentMap(), + stackTrace = persistentListOf(), + typeInfo = persistentMapOf(thisTerm to thisType), + blockPath = persistentListOf(), + nullCheckedTerms = persistentSetOf(thisTerm), + boundCheckedTerms = persistentSetOf(), + typeCheckedTerms = persistentMapOf(thisTerm to thisType) + ) + } + + else -> TraverserState( + symbolicState = persistentSymbolicState(), + valueMap = initialArguments.toPersistentMap(), + stackTrace = persistentListOf(), + typeInfo = persistentMapOf(), + blockPath = persistentListOf(), + nullCheckedTerms = persistentSetOf(), + boundCheckedTerms = persistentSetOf(), + typeCheckedTerms = persistentMapOf() + ) + } + + return getPersistentState(method, initialState, path) + } + + suspend fun findFirstUnreachable(method: Method, pathStates: List): Int { + if (method.checkAsync(ctx, pathStates.last().symbolicState) != null) return -1 + var startRange = 0 + var endRange = pathStates.size - 1 + while (endRange - startRange > 0) { + val pivot = (startRange + endRange) / 2 + val res = method.checkAsync(ctx, pathStates[pivot].symbolicState) + if (res == null) { + endRange = pivot + } + else { + startRange = pivot + 1 + } + } + return endRange + } + + suspend fun getPersistentState(method: Method, state: TraverserState, path: List): TraverserState { + var currentState: TraverserState = state + val instList = path.map { it.instruction } + log.debug(instList.toString()) + var pathStates = mutableListOf() + for (i in 0 until path.size-1) { + val inst = path[i].instruction + val nextInst = path.getOrNull(i+1)?.instruction + val newState = traverseInstruction(currentState, inst, nextInst) + if (newState == null) return currentState + currentState = newState + pathStates.add(currentState) + } + val lastPathClause = pathStates.indexOfFirst { it.symbolicState.path.size == currentState.symbolicState.path.size } + + val firstUnreachable = findFirstUnreachable(method, pathStates.subList(0, lastPathClause+1)) + if (firstUnreachable != -1) { + // go down until vertex with multiple possible paths + // it is needed because path clause added by this vertex is causing unreachability, the inst itself may be reachable + manager.weightedGraph.unreachables.add(path.slice(0..firstUnreachable+1)) + manager.weightedGraph.getVertex(path[firstUnreachable].instruction).invalidate() + + return pathStates[firstUnreachable-1] + } +// val concreteTypes: MutableMap = mutableMapOf() +// currentState.symbolicState.clauses.forEach { clause -> +// clause.predicate.operands.forEach { term -> +// if (term.type.javaName.contains("java.util")) { +// concreteTypes[term] = +// instantiationManager.getConcreteType(term.type, manager.ctx.cm, ctx.accessLevel, ctx.random) +// } +// term.subTerms.forEach { subTerm -> +// if (subTerm.type.javaName.contains("java.util")) { +// concreteTypes[subTerm] = +// instantiationManager.getConcreteType(subTerm.type, manager.ctx.cm, ctx.accessLevel, ctx.random) +// } +// } +// } +// } + //currentState.symbolicState.concreteTypes = concreteTypes.toPersistentMap() + val resultState = pathStates.getOrNull(lastPathClause+1) ?: currentState + return resultState + } + + suspend fun traverseInstruction(state: TraverserState, inst: Instruction, nextInstruction: Instruction?): TraverserState? { + try { + return when (inst) { + is ArrayLoadInst -> traverseArrayLoadInst(state, inst, nextInstruction) + is ArrayStoreInst -> traverseArrayStoreInst(state, inst, nextInstruction) + is BinaryInst -> traverseBinaryInst(state, inst) + is CallInst -> traverseCallInst(state, inst, nextInstruction) + is CastInst -> traverseCastInst(state, inst, nextInstruction) + is CatchInst -> traverseCatchInst(state, inst) + is CmpInst -> traverseCmpInst(state, inst) + is EnterMonitorInst -> traverseEnterMonitorInst(state, inst, nextInstruction) + is ExitMonitorInst -> traverseExitMonitorInst(state, inst) + is FieldLoadInst -> traverseFieldLoadInst(state, inst, nextInstruction) + is FieldStoreInst -> traverseFieldStoreInst(state, inst, nextInstruction) + is InstanceOfInst -> traverseInstanceOfInst(state, inst) + is InvokeDynamicInst -> traverseInvokeDynamicInst(state, inst) + is NewArrayInst -> traverseNewArrayInst(state, inst, nextInstruction) + is NewInst -> traverseNewInst(state, inst) + is PhiInst -> traversePhiInst(state, inst) + is UnaryInst -> traverseUnaryInst(state, inst, nextInstruction) + is BranchInst -> traverseBranchInst(state, inst, nextInstruction) + is JumpInst -> traverseJumpInst(state, inst) + is ReturnInst -> traverseReturnInst(state, inst) + is SwitchInst -> traverseSwitchInst(state, inst, nextInstruction) + is TableSwitchInst -> traverseTableSwitchInst(state, inst, nextInstruction) + is ThrowInst -> traverseThrowInst(state, inst, nextInstruction) + is UnreachableInst -> traverseUnreachableInst(state, inst) + is UnknownValueInst -> traverseUnknownValueInst(state, inst) + else -> unreachable("Unknown instruction ${inst.print()}") + } + } catch (e: Exception) { + log.debug(e.toString()) + return state + } + } + + fun nullCheck( + traverserState: TraverserState, + inst: Instruction, + nextInstruction: Instruction?, + term: Term + ): Pair { + if (term in traverserState.nullCheckedTerms) return Pair(true, traverserState) + if (term is ConstClassTerm) return Pair(true, traverserState) + if (term is StaticClassRefTerm) return Pair(true, traverserState) + if (term.isThis) return Pair(true, traverserState) + + val nullityClause = PathClause( + PathClauseType.NULL_CHECK, + inst, + path { (term eq null) equality true } + ) + return if (nextInstruction is CatchInst) { + Pair(false, traverserState + nullityClause) + } + else { + Pair(true, traverserState + nullityClause.inverse()) + } + } + + fun boundsCheck( + traverserState: TraverserState, + inst: Instruction, + nextInstruction: Instruction?, + index: Term, + length: Term + ): Pair { + if (index to index in traverserState.boundCheckedTerms) return Pair(true, traverserState) + val zeroClause = PathClause( + PathClauseType.BOUNDS_CHECK, + inst, + path { (index ge 0) equality false } + ) + val lengthClause = PathClause( + PathClauseType.BOUNDS_CHECK, + inst, + path { (index lt length) equality false } + ) + // TODO: think about other case + return if (nextInstruction is CatchInst) { + Pair(false, traverserState + zeroClause) + } + else { + Pair(true, traverserState + zeroClause.inverse() + lengthClause.inverse()) + } + } + + fun typeCheck( + state: TraverserState, + inst: Instruction, + nextInstruction: Instruction?, + term: Term, + type: KexType + ): Pair { + if (type !is KexPointer) return Pair(true, state) + val previouslyCheckedType = state.typeCheckedTerms[term] + val currentlyCheckedType = type.getKfgType(ctx.types) + if (previouslyCheckedType != null && currentlyCheckedType.isSubtypeOfCached(previouslyCheckedType)) { + return Pair(true, state) + } + + val typeClause = PathClause( + PathClauseType.TYPE_CHECK, + inst, + path { (term `is` type) equality false } + ) + + return if (nextInstruction is CatchInst) { + Pair(false, state + typeClause) + } + else { + Pair(true, state + typeClause.inverse()) + } + } + + fun newArrayBoundsCheck( + state: TraverserState, + inst: Instruction, + nextInstruction: Instruction?, + index: Term + ): Pair { + if (index to index in state.boundCheckedTerms) return Pair(true, state) + + val zeroClause = PathClause( + PathClauseType.BOUNDS_CHECK, + inst, + path { (index ge 0) equality false } + ) + val noExceptionConstraints = persistentSymbolicState() + zeroClause.inverse() + val zeroCheckConstraints = persistentSymbolicState() + zeroClause + + if (nextInstruction is CatchInst) { + return Pair(false, state + zeroCheckConstraints) + } + else { + val res = state + noExceptionConstraints + return Pair(true, res.copy(boundCheckedTerms = res.boundCheckedTerms.add(index to index)) + noExceptionConstraints) + } + } + + protected open suspend fun traverseArrayLoadInst( + traverserState: TraverserState, + inst: ArrayLoadInst, + nextInstruction: Instruction? + ): TraverserState? { + val arrayTerm = traverserState.mkTerm(inst.arrayRef) + val indexTerm = traverserState.mkTerm(inst.index) + val res = generate(inst.type.symbolicType) + + if (arrayTerm is NullTerm) { + return nullCheck(traverserState, inst, nextInstruction, arrayTerm).second + } + + val clause = StateClause(inst, state { res equality arrayTerm[indexTerm].load() }) + + var result = nullCheck(traverserState, inst, nextInstruction, arrayTerm) + if (!result.first) { + return result.second + } + result = boundsCheck(result.second, inst, nextInstruction, indexTerm, arrayTerm.length()) + if (!result.first) { + return result.second + } + return result.second.copy( + symbolicState = result.second.symbolicState + clause, + valueMap = result.second.valueMap.put(inst, res) + ) + } + + protected open suspend fun traverseArrayStoreInst( + traverserState: TraverserState, + inst: ArrayStoreInst, + nextInstruction: Instruction? + ): TraverserState? { + val arrayTerm = traverserState.mkTerm(inst.arrayRef) + val indexTerm = traverserState.mkTerm(inst.index) + val valueTerm = traverserState.mkTerm(inst.value) + + if (arrayTerm is NullTerm) { + return nullCheck(traverserState, inst, nextInstruction, arrayTerm).second + } + + val clause = StateClause(inst, state { arrayTerm[indexTerm].store(valueTerm) }) + + var result = nullCheck(traverserState, inst, nextInstruction, arrayTerm) + if (!result.first) { + return result.second + } + result = boundsCheck(result.second, inst, nextInstruction, indexTerm, arrayTerm.length()) + if (!result.first) { + return result.second + } + return result.second + clause + } + + protected open suspend fun traverseBinaryInst(traverserState: TraverserState, inst: BinaryInst): TraverserState { + val lhvTerm = traverserState.mkTerm(inst.lhv) + val rhvTerm = traverserState.mkTerm(inst.rhv) + val resultTerm = generate(inst.type.symbolicType) + + val clause = StateClause( + inst, + state { resultTerm equality lhvTerm.apply(resultTerm.type, inst.opcode, rhvTerm) } + ) + return traverserState.copy( + symbolicState = traverserState.symbolicState + clause, + valueMap = traverserState.valueMap.put(inst, resultTerm) + ) + } + + protected open suspend fun traverseBranchInst( + traverserState: TraverserState, + inst: BranchInst, + nextInstruction: Instruction? + ): TraverserState { + val condTerm = traverserState.mkTerm(inst.cond) + + val trueClause = PathClause( + PathClauseType.CONDITION_CHECK, + inst, + path { condTerm equality true } + ) + val falseClause = trueClause.inverse() + + if (nextInstruction in inst.trueSuccessor) { + return traverserState + trueClause + inst.parent + } + else return traverserState + falseClause + inst.parent + } + + val callResolver: SymbolicCallResolver = DefaultCallResolver(ctx) + + protected open suspend fun traverseCallInst( + traverserState: TraverserState, + inst: CallInst, + nextInstruction: Instruction? + ): TraverserState { + val callee = when { + inst.isStatic -> staticRef(inst.method.klass) + else -> traverserState.mkTerm(inst.callee) + } + val argumentTerms = inst.args.map { traverserState.mkTerm(it) } + val candidates = callResolver.resolve(traverserState, inst) + + var (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, callee) + if (!isCheckSuccess) { + return result + } + val candidate = candidates.find { !it.body.entry.isEmpty && it.body.entry.instructions[0] == nextInstruction } + result = when { + candidate == null -> { + var varState = result + val receiver = when { + inst.isNameDefined -> { + val res = generate(inst.type.symbolicType) + varState = varState.copy( + valueMap = traverserState.valueMap.put(inst, res) + ) + res + } + + else -> null + } + val callClause = StateClause( + inst, state { + val callTerm = callee.call(inst.method, argumentTerms) + receiver?.call(callTerm) ?: call(callTerm) + } + ) + varState + callClause + } + + else -> processMethodCall(result, inst, nextInstruction, candidate, callee, argumentTerms) + } + return result + } + + protected open suspend fun traverseCastInst( + traverserState: TraverserState, + inst: CastInst, + nextInstruction: Instruction? + ): TraverserState { + val operandTerm = traverserState.mkTerm(inst.operand) + val resultTerm = generate(inst.type.symbolicType) + val clause = StateClause( + inst, + state { resultTerm equality (operandTerm `as` resultTerm.type) } + ) + + var (isCheckSuccess, result) = typeCheck(traverserState, inst, nextInstruction, operandTerm, resultTerm.type) + if (!isCheckSuccess) { + return result + } + result = result.copy( + symbolicState = result.symbolicState + clause, + valueMap = result.valueMap.put(inst, resultTerm) + ).copyTermInfo(operandTerm, resultTerm) + + return result + } + + protected open suspend fun traverseCatchInst(traverserState: TraverserState, inst: CatchInst): TraverserState { + return traverserState + } + + protected open suspend fun traverseCmpInst( + traverserState: TraverserState, + inst: CmpInst + ): TraverserState { + val lhvTerm = traverserState.mkTerm(inst.lhv) + val rhvTerm = traverserState.mkTerm(inst.rhv) + val resultTerm = generate(inst.type.symbolicType) + + val clause = StateClause( + inst, + state { resultTerm equality lhvTerm.apply(inst.opcode, rhvTerm) } + ) + return traverserState.copy( + symbolicState = traverserState.symbolicState + clause, + valueMap = traverserState.valueMap.put(inst, resultTerm) + ) + } + + protected open suspend fun traverseEnterMonitorInst( + traverserState: TraverserState, + inst: EnterMonitorInst, + nextInstruction: Instruction? + ): TraverserState { + val monitorTerm = traverserState.mkTerm(inst.owner) + val clause = StateClause( + inst, + state { enterMonitor(monitorTerm) } + ) + + val (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, monitorTerm) + if (!isCheckSuccess) { + return result + } + return result + clause + } + + protected open suspend fun traverseExitMonitorInst( + traverserState: TraverserState, + inst: ExitMonitorInst + ): TraverserState { + val monitorTerm = traverserState.mkTerm(inst.owner) + val clause = StateClause( + inst, + state { exitMonitor(monitorTerm) } + ) + return traverserState + clause + } + + protected open suspend fun traverseFieldLoadInst( + traverserState: TraverserState, + inst: FieldLoadInst, + nextInstruction: Instruction? + ): TraverserState { + val field = inst.field + val objectTerm = when { + inst.isStatic -> staticRef(field.klass) + else -> traverserState.mkTerm(inst.owner) + } + + if (objectTerm is NullTerm) { + return nullCheck(traverserState, inst, nextInstruction, objectTerm).second + } + + val res = generate(inst.type.symbolicType) + val clause = StateClause( + inst, + state { res equality objectTerm.field(field.type.symbolicType, field.name).load() } + ) + + val (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, objectTerm) + if (!isCheckSuccess) return result + + val newNullChecked = when { + field.isStatic && field.isFinal -> when (field.defaultValue) { + null -> result.nullCheckedTerms.add(res) + ctx.values.nullConstant -> result.nullCheckedTerms + else -> result.nullCheckedTerms.add(res) + } + + else -> result.nullCheckedTerms + } + return result.copy( + symbolicState = result.symbolicState + clause, + valueMap = result.valueMap.put(inst, res), + nullCheckedTerms = newNullChecked + ) + } + + protected open suspend fun traverseFieldStoreInst( + traverserState: TraverserState, + inst: FieldStoreInst, + nextInstruction: Instruction? + ): TraverserState { + val objectTerm = when { + inst.isStatic -> staticRef(inst.field.klass) + else -> traverserState.mkTerm(inst.owner) + } + + if (objectTerm is NullTerm) { + return nullCheck(traverserState, inst, nextInstruction, objectTerm).second + } + + val valueTerm = traverserState.mkTerm(inst.value) + val clause = StateClause( + inst, + state { objectTerm.field(inst.field.type.symbolicType, inst.field.name).store(valueTerm) } + ) + + val (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, objectTerm) + if (!isCheckSuccess) return result + + return result.copy( + symbolicState = result.symbolicState + clause, + valueMap = result.valueMap.put(inst, valueTerm) + ) + } + + protected open suspend fun traverseInstanceOfInst( + traverserState: TraverserState, + inst: InstanceOfInst + ): TraverserState { + val operandTerm = traverserState.mkTerm(inst.operand) + val resultTerm = generate(inst.type.symbolicType) + + val clause = StateClause( + inst, + state { resultTerm equality (operandTerm `is` inst.targetType.symbolicType) } + ) + + val previouslyCheckedType = traverserState.typeCheckedTerms[operandTerm] + val currentlyCheckedType = operandTerm.type.getKfgType(ctx.types) + + return traverserState.copy( + symbolicState = traverserState.symbolicState + clause, + valueMap = traverserState.valueMap.put(inst, resultTerm), + typeCheckedTerms = when { + previouslyCheckedType != null && currentlyCheckedType.isSubtypeOfCached(previouslyCheckedType) -> + traverserState.typeCheckedTerms.put(operandTerm, inst.targetType) + + else -> traverserState.typeCheckedTerms + } + ) + } + + val invokeDynamicResolver: SymbolicInvokeDynamicResolver = DefaultCallResolver(ctx) + + protected open suspend fun traverseInvokeDynamicInst( + traverserState: TraverserState, + inst: InvokeDynamicInst + ): TraverserState? { + return when (invokeDynamicResolver.resolve(traverserState, inst)) { + null -> traverserState.copy( + valueMap = traverserState.valueMap.put(inst, generate(inst.type.kexType)) + ) + + else -> invokeDynamicResolver.resolve(traverserState, inst) + } + } + + protected open suspend fun processMethodCall( + traverserState: TraverserState, + inst: Instruction, + nextInstruction: Instruction?, + candidate: Method, + callee: Term, + argumentTerms: List + ): TraverserState { + if (candidate.body.isEmpty()) return traverserState + + var newValueMap = traverserState.valueMap.builder().let { builder -> + if (!candidate.isStatic) builder[values.getThis(candidate.klass)] = callee + for ((index, type) in candidate.argTypes.withIndex()) { + builder[values.getArgument(index, candidate, type)] = argumentTerms[index] + } + builder.build() + } + + when { + candidate.isStatic -> return traverserState.copy( + valueMap = newValueMap, + stackTrace = traverserState.stackTrace.add( + SymbolicStackTraceElement(inst.parent.method, inst, traverserState.valueMap) + ) + ) + + else -> { + var (isCheckSuccess, result) = typeCheck(traverserState, inst, nextInstruction, callee, candidate.klass.symbolicClass) + if (!isCheckSuccess) { + return result + } + result = when { + candidate.klass.asType.isSubtypeOfCached(callee.type.getKfgType(types)) -> { + val newCalleeTerm = generate(candidate.klass.symbolicClass) + val convertClause = StateClause(inst, state { + newCalleeTerm equality (callee `as` candidate.klass.symbolicClass) + }) + newValueMap = newValueMap.mapValues { (_, term) -> + when (term) { + callee -> newCalleeTerm + else -> term + } + }.toPersistentMap() + result.copy( + symbolicState = result.symbolicState + convertClause + ).copyTermInfo(callee, newCalleeTerm) + } + + else -> traverserState + }.copy( + valueMap = newValueMap, + stackTrace = result.stackTrace.add( + SymbolicStackTraceElement(inst.parent.method, inst, result.valueMap) + ) + ) + return result + } + } + } + + protected open suspend fun traverseNewArrayInst( + traverserState: TraverserState, + inst: NewArrayInst, + nextInstruction: Instruction? + ): TraverserState { + val dimensions = inst.dimensions.map { traverserState.mkTerm(it) } + val resultTerm = generate(inst.type.symbolicType) + val clause = StateClause(inst, state { resultTerm.new(dimensions) }) + + var result: TraverserState = traverserState + dimensions.forEach { dimension -> + val r = newArrayBoundsCheck(traverserState, inst, nextInstruction, dimension) + if (!r.first) { + return result + } + result = r.second + } + + return result.copy( + symbolicState = result.symbolicState + clause, + typeInfo = result.typeInfo.put(resultTerm, inst.type.rtMapped), + valueMap = result.valueMap.put(inst, resultTerm), + nullCheckedTerms = result.nullCheckedTerms.add(resultTerm), + typeCheckedTerms = result.typeCheckedTerms.put(resultTerm, inst.type) + ) + } + + protected open suspend fun traverseNewInst( + traverserState: TraverserState, + inst: NewInst + ): TraverserState { + val resultTerm = generate(inst.type.symbolicType) + val clause = StateClause( + inst, + state { resultTerm.new() } + ) + return traverserState.copy( + symbolicState = traverserState.symbolicState + clause, + typeInfo = traverserState.typeInfo.put(resultTerm, inst.type.rtMapped), + valueMap = traverserState.valueMap.put(inst, resultTerm), + nullCheckedTerms = traverserState.nullCheckedTerms.add(resultTerm), + typeCheckedTerms = traverserState.typeCheckedTerms.put(resultTerm, inst.type) + ) + } + + protected open suspend fun traversePhiInst( + traverserState: TraverserState, + inst: PhiInst + ): TraverserState { + val previousBlock = traverserState.blockPath.last { it.method == inst.parent.method } + val value = traverserState.mkTerm(inst.incomings.getValue(previousBlock)) + return traverserState.copy( + valueMap = traverserState.valueMap.put(inst, value) + ) + } + + protected open suspend fun traverseUnaryInst( + traverserState: TraverserState, + inst: UnaryInst, + nextInstruction: Instruction? + ): TraverserState { + val operandTerm = traverserState.mkTerm(inst.operand) + val resultTerm = generate(inst.type.symbolicType) + val clause = StateClause( + inst, + state { resultTerm equality operandTerm.apply(inst.opcode) } + ) + + val result: TraverserState = when (inst.opcode) { + UnaryOpcode.LENGTH -> nullCheck(traverserState, inst, nextInstruction, operandTerm).second + else -> traverserState + } + + return result.copy( + symbolicState = result.symbolicState + clause, + valueMap = result.valueMap.put(inst, resultTerm) + ) + } + + protected open suspend fun traverseJumpInst( + traverserState: TraverserState, + inst: JumpInst + ): TraverserState { + return traverserState + inst.parent + } + + protected open suspend fun traverseReturnInst( + traverserState: TraverserState, + inst: ReturnInst + ): TraverserState { + val stackTrace = traverserState.stackTrace + val stackTraceElement = stackTrace.lastOrNull() + val receiver = stackTraceElement?.instruction + val result = when { + receiver == null -> { + return traverserState + } + + inst.hasReturnValue && receiver.isNameDefined -> { + val returnTerm = traverserState.mkTerm(inst.returnValue) + traverserState.copy( + valueMap = stackTraceElement.valueMap.put(receiver, returnTerm), + stackTrace = stackTrace.removeAt(stackTrace.lastIndex) + ) + } + + else -> traverserState.copy( + valueMap = stackTraceElement.valueMap, + stackTrace = stackTrace.removeAt(stackTrace.lastIndex) + ) + } + return result + } + + protected open suspend fun traverseSwitchInst( + traverserState: TraverserState, + inst: SwitchInst, + nextInstruction: Instruction? + ): TraverserState { + val key = traverserState.mkTerm(inst.key) + + for ((value, branch) in inst.branches) { + if (nextInstruction !in branch.instructions) { + continue + } + val path = PathClause( + PathClauseType.CONDITION_CHECK, + inst, + path { (key eq traverserState.mkTerm(value)) equality true } + ) + return traverserState + path + inst.parent + } + val defaultPath = PathClause( + PathClauseType.CONDITION_CHECK, + inst, + path { key `!in` inst.operands.map { traverserState.mkTerm(it) } } + ) + return traverserState + defaultPath + inst.parent + } + + protected open suspend fun traverseTableSwitchInst( + traverserState: TraverserState, + inst: TableSwitchInst, + nextInstruction: Instruction? + ): TraverserState? { + val key = traverserState.mkTerm(inst.index) + val min = inst.range.first + for ((index, branch) in inst.branches.withIndex()) { + if (nextInstruction !in branch.instructions) { + continue + } + val path = PathClause( + PathClauseType.CONDITION_CHECK, + inst, + path { (key eq const(min + index)) equality true } + ) + return traverserState + path + inst.parent + } + val defaultPath = PathClause( + PathClauseType.CONDITION_CHECK, + inst, + path { key `!in` inst.range.map { const(it) } } + ) + return traverserState + defaultPath + inst.parent + } + + protected open suspend fun traverseThrowInst( + traverserState: TraverserState, + inst: ThrowInst, + nextInstruction: Instruction? + ): TraverserState? { + val throwableTerm = traverserState.mkTerm(inst.throwable) + val throwClause = StateClause( + inst, + state { `throw`(throwableTerm) } + ) + + var (isCheckPassed, result) = nullCheck(traverserState, inst, nextInstruction, throwableTerm) + if (!isCheckPassed) { + return result + } + return result + throwClause + } + + protected open suspend fun traverseUnreachableInst( + traverserState: TraverserState, + inst: UnreachableInst + ): TraverserState? { + return null + } + + protected open suspend fun traverseUnknownValueInst( + traverserState: TraverserState, + inst: UnknownValueInst + ): TraverserState? { + return unreachable("Unexpected visit of $inst in symbolic traverser") + } + + @Suppress("NOTHING_TO_INLINE") + protected inline fun PathClause.inverse(): PathClause = this.copy( + predicate = this.predicate.inverse(ctx.random) + ) +} \ No newline at end of file diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt new file mode 100644 index 000000000..938343f6d --- /dev/null +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt @@ -0,0 +1,4 @@ +package org.vorpal.research.kex.asm.analysis.concolic.coverage + +class ExperimentSelector { +} \ No newline at end of file diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt new file mode 100644 index 000000000..ab290c663 --- /dev/null +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt @@ -0,0 +1,271 @@ +package org.vorpal.research.kex.asm.analysis.concolic.coverage + +import org.vorpal.research.kex.asm.manager.instantiationManager +import org.vorpal.research.kex.asm.util.AccessModifier +import org.vorpal.research.kex.ktype.KexRtManager.isKexRt +import org.vorpal.research.kex.ktype.KexRtManager.rtMapped +import org.vorpal.research.kex.ktype.KexRtManager.rtUnmapped +import org.vorpal.research.kfg.ir.BasicBlock +import org.vorpal.research.kfg.ir.Method +import org.vorpal.research.kfg.ir.value.instruction.* +import org.vorpal.research.kthelper.assert.ktassert +import org.vorpal.research.kthelper.collection.mapToArray +import org.vorpal.research.kthelper.collection.queueOf +import org.vorpal.research.kthelper.logging.log +import org.vorpal.research.kthelper.tryOrNull +import kotlin.math.max + +class WeightedGraph( + val targets: Set, + val targetInstructions: Set +) { + private val MIN_COVERED_SCORE = 40 + + val nodes = mutableMapOf() + var unreachables: MutableList> = mutableListOf() + + inner class Vertex(val instruction: Instruction, val predecessors: MutableSet, var coveredScore: Int = 1) { + + val CYCLE_EDGE_SCORE = 4 + + val upEdges = mutableSetOf() + val downEdges = mutableSetOf() + val cycleEdges = mutableMapOf() + var score = 0 + var isValid = false + + override fun toString(): String = instruction.print() + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as Vertex + + return instruction == other.instruction + } + + override fun hashCode(): Int { + return instruction.hashCode() + } + + fun linkDown(other: Vertex) { + downEdges += other + other.upEdges += this + other.addPredecessors(predecessors + this) +// other.predecessors += this +// other.predecessors += predecessors + } + + fun addPredecessors(pred: Set) { + predecessors += pred + for (v in downEdges) { + v.addPredecessors(predecessors) + } + } + + fun addCycleEdge(other: Vertex) { + cycleEdges[other] = CYCLE_EDGE_SCORE + predecessors += other.predecessors + } + + fun decreaseCycleEdgeWeight(other: Vertex) { + val currentWeight = cycleEdges[other] + if (currentWeight != null && currentWeight >= 1) { + cycleEdges[other] = currentWeight - 1 + } + } + + fun invalidate() { + isValid = false + predecessors.forEach { it.isValid = false } + } + + fun recomputeScore(visited: MutableList = mutableListOf()) { + if (isValid) return + visited.add(this) + if (visited.size > 100) { + log.debug("Hahaha") + } + score = 0 + if (instruction in targetInstructions) { + score += coveredScore + } + for (vertex in downEdges) { + if (!vertex.isValid) { + vertex.recomputeScore(visited) + } + if (visited + vertex !in unreachables) { + score += vertex.score + } + } + for (s in cycleEdges.values) { + score += s + } + isValid = true + visited.removeLast() + } + } + + data class QueueEntry(val prev: Vertex?, val block: BasicBlock, val index: Int) + + fun getVertex(instruction: Instruction): Vertex { + if (instruction in nodes) { + return nodes[instruction]!! + } else { + val method = instruction.parent.method + ktassert(instruction == method.body.entry.first()) + + val queue = queueOf() + queue += QueueEntry(null, method.body.entry, 0) + queue.addAll(method.body.catchEntries.map { QueueEntry(null, it, 0) }) + val visited = mutableSetOf>() + val resolves = mutableMapOf>() + + while (queue.isNotEmpty()) { + val (prev, block, index) = queue.poll() + val current = block.instructions[index] + if (prev?.instruction to current in visited) continue + visited += prev?.instruction to current + + val vertex = nodes.getOrPut(current) { Vertex(current, prev?.predecessors?.toMutableSet() ?: mutableSetOf()) } + + if (prev?.predecessors?.contains(vertex) == true) { + prev.addCycleEdge(vertex) + } + else { + prev?.linkDown(vertex) + } + + when (current) { + is CallInst -> { + val resolvedMethods = resolves.getOrPut(current) { + when (current.opcode) { + CallOpcode.STATIC -> listOf(current.method) + CallOpcode.SPECIAL -> listOf(current.method) + CallOpcode.INTERFACE, CallOpcode.VIRTUAL -> { + val currentMethod = current.method + + val targetPackages = targets.map { it.klass.pkg }.toSet() + + val retTypeMapped = currentMethod.returnType.rtMapped + val argTypesMapped = currentMethod.argTypes.mapToArray { it.rtMapped } + val retTypeUnmapped = currentMethod.returnType.rtUnmapped + val argTypesUnmapped = currentMethod.argTypes.mapToArray { it.rtUnmapped } + instantiationManager.getAllConcreteSubtypes( + currentMethod.klass, + AccessModifier.Private + ) + .filter { klass -> targetPackages.any { it.isParent(klass.pkg) } } + .mapNotNullTo(mutableSetOf()) { + tryOrNull { + if (it.isKexRt) { + it.getMethod(currentMethod.name, retTypeMapped, *argTypesMapped) + } else { + it.getMethod(currentMethod.name, retTypeUnmapped, *argTypesUnmapped) + } + } + } + .filter { it.hasBody } + } + }.filter { it.hasBody } + } + var connectedExits = false + for (candidate in resolvedMethods) { + queue += QueueEntry(vertex, candidate.body.entry, 0) + queue.addAll(candidate.body.catchEntries.map { QueueEntry(null, it, 0) }) + + candidate.body.flatten().filterIsInstance().forEach { + connectedExits = true + val returnVertex = nodes.getOrPut(it) { Vertex(it, mutableSetOf()) } + queue += QueueEntry(returnVertex, block, index + 1) + } + } + + if (!connectedExits) { + queue += QueueEntry(vertex, block, index + 1) + } + } + + is TerminateInst -> current.successors.forEach { + queue += QueueEntry(vertex, it, 0) + } + + else -> queue += QueueEntry(vertex, block, index + 1) + } + } + + return nodes.getOrPut(instruction) { Vertex(instruction, mutableSetOf()) } + } + } + + fun addTrace(trace: List) { + var prev: Vertex? = null + for (inst in trace) { + val current = getVertex(inst) + current.coveredScore = max(0, current.coveredScore-1) + if (prev == null) { + prev = current + continue + } + if (prev.cycleEdges.contains(current)) { + prev.decreaseCycleEdgeWeight(current) + prev.invalidate() + } + else if (prev.predecessors.contains(current)) { + prev.addCycleEdge(current) + } + else if (!prev.predecessors.contains(current)) { + prev.linkDown(current) + } + + prev = current + } + prev?.invalidate() + targets.forEach { getVertex(it.body.entry.instructions.first()).recomputeScore() } +// nodes[trace[0]]?.recomputeScore() + } + + fun getPath(root: Instruction): List { + var prev = getVertex(root) +// if (!prev.isValid) { +// prev.recomputeScore() +// } + val path = mutableListOf(prev) + var coveredScore = 0 + while (prev.downEdges.size > 0 && coveredScore < MIN_COVERED_SCORE) { + var nextVertex = prev.downEdges.maxBy { if (path + it in unreachables) 0 else it.score } + val nextVertexInclCycles = prev.cycleEdges.maxByOrNull { if (path + it in unreachables) 0 else it.value } + if (nextVertexInclCycles != null && nextVertexInclCycles.value != 0) { + // explore cycle + val randomNum = (0..10).random() + if (randomNum <= nextVertexInclCycles.value && nextVertex.score <= nextVertexInclCycles.key.score) { + nextVertex = nextVertexInclCycles.key + coveredScore += 1 + // TODO: move to add trace + prev.decreaseCycleEdgeWeight(nextVertex) + prev.invalidate() + } + } + + nextVertex.coveredScore = max(0, nextVertex.coveredScore-1) + + if (nextVertex.score == 0) break + + if (nextVertex.instruction in targetInstructions) { + // TODO: move to add trace + coveredScore += nextVertex.coveredScore + } +// nextVertex.cycleEdges.values.forEach { +// coveredScore += it +// } + path.add(nextVertex) + prev = nextVertex + } + prev.invalidate() + targets.forEach { getVertex(it.body.entry.instructions.first()).recomputeScore() } + return path + } + + +} \ No newline at end of file From 7e7d392d831f27aa687c900a0b09e39564cae1d8 Mon Sep 17 00:00:00 2001 From: AndreyYurko Date: Thu, 18 Apr 2024 21:50:31 +0300 Subject: [PATCH 2/3] small change --- .../concolic/coverage/ExperimentPathSelector.kt | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt index f7799044b..76c61a542 100644 --- a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt @@ -9,21 +9,10 @@ import org.vorpal.research.kex.asm.analysis.concolic.ConcolicPathSelector import org.vorpal.research.kex.asm.analysis.concolic.ConcolicPathSelectorManager import org.vorpal.research.kex.asm.analysis.symbolic.* import org.vorpal.research.kex.asm.analysis.util.checkAsync -import org.vorpal.research.kex.asm.analysis.util.checkAsyncIncremental -import org.vorpal.research.kex.asm.manager.instantiationManager -import org.vorpal.research.kex.asserter.ExecutionFinalInfo -import org.vorpal.research.kex.asserter.extractExceptionFinalInfo -import org.vorpal.research.kex.asserter.extractSuccessFinalInfo -import org.vorpal.research.kex.compile.CompilationException -import org.vorpal.research.kex.descriptor.Descriptor -import org.vorpal.research.kex.descriptor.DescriptorBuilder import org.vorpal.research.kex.ktype.KexPointer import org.vorpal.research.kex.ktype.KexRtManager.rtMapped import org.vorpal.research.kex.ktype.KexType import org.vorpal.research.kex.ktype.kexType -import org.vorpal.research.kex.parameters.Parameters -import org.vorpal.research.kex.reanimator.UnsafeGenerator -import org.vorpal.research.kex.reanimator.codegen.klassName import org.vorpal.research.kex.state.predicate.inverse import org.vorpal.research.kex.state.predicate.path import org.vorpal.research.kex.state.predicate.state @@ -37,10 +26,7 @@ import org.vorpal.research.kex.state.transformer.isThis import org.vorpal.research.kex.trace.symbolic.* import org.vorpal.research.kex.trace.symbolic.protocol.ExecutionCompletedResult import org.vorpal.research.kex.util.isSubtypeOfCached -import org.vorpal.research.kfg.ir.BasicBlock import org.vorpal.research.kfg.ir.Method -import org.vorpal.research.kfg.ir.value.Constant -import org.vorpal.research.kfg.ir.value.Value import org.vorpal.research.kfg.ir.value.ValueFactory import org.vorpal.research.kfg.ir.value.instruction.* import org.vorpal.research.kfg.type.Type From e0103a82f71206e4057123e343238666b69eabcc Mon Sep 17 00:00:00 2001 From: AndreyYurko Date: Wed, 29 May 2024 12:56:13 +0300 Subject: [PATCH 3/3] weighted path selector finalisation --- .../concolic/InstructionConcolicChecker.kt | 16 +- .../concolic/coverage/ExperimentSelector.kt | 4 - .../concolic/coverage/WeightedGraph.kt | 271 ---------- .../concolic/weighted/WeightedGraph.kt | 473 ++++++++++++++++++ .../WeightedPathSelector.kt} | 419 +++++++--------- 5 files changed, 658 insertions(+), 525 deletions(-) delete mode 100644 kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt delete mode 100644 kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt create mode 100644 kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/weighted/WeightedGraph.kt rename kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/{coverage/ExperimentPathSelector.kt => weighted/WeightedPathSelector.kt} (70%) diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt index 9ce29fc15..6d30fd147 100644 --- a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/InstructionConcolicChecker.kt @@ -8,7 +8,7 @@ import org.vorpal.research.kex.ExecutionContext import org.vorpal.research.kex.asm.analysis.concolic.bfs.BfsPathSelectorManager import org.vorpal.research.kex.asm.analysis.concolic.cgs.ContextGuidedSelectorManager import org.vorpal.research.kex.asm.analysis.concolic.coverage.CoverageGuidedSelectorManager -import org.vorpal.research.kex.asm.analysis.concolic.coverage.ExperimentPathSelectorManager +import org.vorpal.research.kex.asm.analysis.concolic.weighted.WeightedPathSelectorManager import org.vorpal.research.kex.asm.analysis.util.analyzeOrTimeout import org.vorpal.research.kex.asm.analysis.util.checkAsync import org.vorpal.research.kex.assertions.extractFinalInfo @@ -70,13 +70,13 @@ class InstructionConcolicChecker( ctx: ExecutionContext, targets: Set, strategyName: String, - ): ConcolicPathSelectorManager = ExperimentPathSelectorManager(ctx, targets) -// when (strategyName) { -// "bfs" -> BfsPathSelectorManager(ctx, targets) -// "cgs" -> ContextGuidedSelectorManager(ctx, targets) -// "coverage" -> CoverageGuidedSelectorManager(ctx, targets) -// else -> unreachable { log.error("Unknown type of search strategy $strategyName") } -// } + ): ConcolicPathSelectorManager = when (strategyName) { + "bfs" -> BfsPathSelectorManager(ctx, targets) + "cgs" -> ContextGuidedSelectorManager(ctx, targets) + "coverage" -> CoverageGuidedSelectorManager(ctx, targets) + "weighted" -> WeightedPathSelectorManager(ctx, targets) + else -> unreachable { log.error("Unknown type of search strategy $strategyName") } + } @ExperimentalTime @DelicateCoroutinesApi diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt deleted file mode 100644 index 938343f6d..000000000 --- a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentSelector.kt +++ /dev/null @@ -1,4 +0,0 @@ -package org.vorpal.research.kex.asm.analysis.concolic.coverage - -class ExperimentSelector { -} \ No newline at end of file diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt deleted file mode 100644 index ab290c663..000000000 --- a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/WeightedGraph.kt +++ /dev/null @@ -1,271 +0,0 @@ -package org.vorpal.research.kex.asm.analysis.concolic.coverage - -import org.vorpal.research.kex.asm.manager.instantiationManager -import org.vorpal.research.kex.asm.util.AccessModifier -import org.vorpal.research.kex.ktype.KexRtManager.isKexRt -import org.vorpal.research.kex.ktype.KexRtManager.rtMapped -import org.vorpal.research.kex.ktype.KexRtManager.rtUnmapped -import org.vorpal.research.kfg.ir.BasicBlock -import org.vorpal.research.kfg.ir.Method -import org.vorpal.research.kfg.ir.value.instruction.* -import org.vorpal.research.kthelper.assert.ktassert -import org.vorpal.research.kthelper.collection.mapToArray -import org.vorpal.research.kthelper.collection.queueOf -import org.vorpal.research.kthelper.logging.log -import org.vorpal.research.kthelper.tryOrNull -import kotlin.math.max - -class WeightedGraph( - val targets: Set, - val targetInstructions: Set -) { - private val MIN_COVERED_SCORE = 40 - - val nodes = mutableMapOf() - var unreachables: MutableList> = mutableListOf() - - inner class Vertex(val instruction: Instruction, val predecessors: MutableSet, var coveredScore: Int = 1) { - - val CYCLE_EDGE_SCORE = 4 - - val upEdges = mutableSetOf() - val downEdges = mutableSetOf() - val cycleEdges = mutableMapOf() - var score = 0 - var isValid = false - - override fun toString(): String = instruction.print() - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (javaClass != other?.javaClass) return false - - other as Vertex - - return instruction == other.instruction - } - - override fun hashCode(): Int { - return instruction.hashCode() - } - - fun linkDown(other: Vertex) { - downEdges += other - other.upEdges += this - other.addPredecessors(predecessors + this) -// other.predecessors += this -// other.predecessors += predecessors - } - - fun addPredecessors(pred: Set) { - predecessors += pred - for (v in downEdges) { - v.addPredecessors(predecessors) - } - } - - fun addCycleEdge(other: Vertex) { - cycleEdges[other] = CYCLE_EDGE_SCORE - predecessors += other.predecessors - } - - fun decreaseCycleEdgeWeight(other: Vertex) { - val currentWeight = cycleEdges[other] - if (currentWeight != null && currentWeight >= 1) { - cycleEdges[other] = currentWeight - 1 - } - } - - fun invalidate() { - isValid = false - predecessors.forEach { it.isValid = false } - } - - fun recomputeScore(visited: MutableList = mutableListOf()) { - if (isValid) return - visited.add(this) - if (visited.size > 100) { - log.debug("Hahaha") - } - score = 0 - if (instruction in targetInstructions) { - score += coveredScore - } - for (vertex in downEdges) { - if (!vertex.isValid) { - vertex.recomputeScore(visited) - } - if (visited + vertex !in unreachables) { - score += vertex.score - } - } - for (s in cycleEdges.values) { - score += s - } - isValid = true - visited.removeLast() - } - } - - data class QueueEntry(val prev: Vertex?, val block: BasicBlock, val index: Int) - - fun getVertex(instruction: Instruction): Vertex { - if (instruction in nodes) { - return nodes[instruction]!! - } else { - val method = instruction.parent.method - ktassert(instruction == method.body.entry.first()) - - val queue = queueOf() - queue += QueueEntry(null, method.body.entry, 0) - queue.addAll(method.body.catchEntries.map { QueueEntry(null, it, 0) }) - val visited = mutableSetOf>() - val resolves = mutableMapOf>() - - while (queue.isNotEmpty()) { - val (prev, block, index) = queue.poll() - val current = block.instructions[index] - if (prev?.instruction to current in visited) continue - visited += prev?.instruction to current - - val vertex = nodes.getOrPut(current) { Vertex(current, prev?.predecessors?.toMutableSet() ?: mutableSetOf()) } - - if (prev?.predecessors?.contains(vertex) == true) { - prev.addCycleEdge(vertex) - } - else { - prev?.linkDown(vertex) - } - - when (current) { - is CallInst -> { - val resolvedMethods = resolves.getOrPut(current) { - when (current.opcode) { - CallOpcode.STATIC -> listOf(current.method) - CallOpcode.SPECIAL -> listOf(current.method) - CallOpcode.INTERFACE, CallOpcode.VIRTUAL -> { - val currentMethod = current.method - - val targetPackages = targets.map { it.klass.pkg }.toSet() - - val retTypeMapped = currentMethod.returnType.rtMapped - val argTypesMapped = currentMethod.argTypes.mapToArray { it.rtMapped } - val retTypeUnmapped = currentMethod.returnType.rtUnmapped - val argTypesUnmapped = currentMethod.argTypes.mapToArray { it.rtUnmapped } - instantiationManager.getAllConcreteSubtypes( - currentMethod.klass, - AccessModifier.Private - ) - .filter { klass -> targetPackages.any { it.isParent(klass.pkg) } } - .mapNotNullTo(mutableSetOf()) { - tryOrNull { - if (it.isKexRt) { - it.getMethod(currentMethod.name, retTypeMapped, *argTypesMapped) - } else { - it.getMethod(currentMethod.name, retTypeUnmapped, *argTypesUnmapped) - } - } - } - .filter { it.hasBody } - } - }.filter { it.hasBody } - } - var connectedExits = false - for (candidate in resolvedMethods) { - queue += QueueEntry(vertex, candidate.body.entry, 0) - queue.addAll(candidate.body.catchEntries.map { QueueEntry(null, it, 0) }) - - candidate.body.flatten().filterIsInstance().forEach { - connectedExits = true - val returnVertex = nodes.getOrPut(it) { Vertex(it, mutableSetOf()) } - queue += QueueEntry(returnVertex, block, index + 1) - } - } - - if (!connectedExits) { - queue += QueueEntry(vertex, block, index + 1) - } - } - - is TerminateInst -> current.successors.forEach { - queue += QueueEntry(vertex, it, 0) - } - - else -> queue += QueueEntry(vertex, block, index + 1) - } - } - - return nodes.getOrPut(instruction) { Vertex(instruction, mutableSetOf()) } - } - } - - fun addTrace(trace: List) { - var prev: Vertex? = null - for (inst in trace) { - val current = getVertex(inst) - current.coveredScore = max(0, current.coveredScore-1) - if (prev == null) { - prev = current - continue - } - if (prev.cycleEdges.contains(current)) { - prev.decreaseCycleEdgeWeight(current) - prev.invalidate() - } - else if (prev.predecessors.contains(current)) { - prev.addCycleEdge(current) - } - else if (!prev.predecessors.contains(current)) { - prev.linkDown(current) - } - - prev = current - } - prev?.invalidate() - targets.forEach { getVertex(it.body.entry.instructions.first()).recomputeScore() } -// nodes[trace[0]]?.recomputeScore() - } - - fun getPath(root: Instruction): List { - var prev = getVertex(root) -// if (!prev.isValid) { -// prev.recomputeScore() -// } - val path = mutableListOf(prev) - var coveredScore = 0 - while (prev.downEdges.size > 0 && coveredScore < MIN_COVERED_SCORE) { - var nextVertex = prev.downEdges.maxBy { if (path + it in unreachables) 0 else it.score } - val nextVertexInclCycles = prev.cycleEdges.maxByOrNull { if (path + it in unreachables) 0 else it.value } - if (nextVertexInclCycles != null && nextVertexInclCycles.value != 0) { - // explore cycle - val randomNum = (0..10).random() - if (randomNum <= nextVertexInclCycles.value && nextVertex.score <= nextVertexInclCycles.key.score) { - nextVertex = nextVertexInclCycles.key - coveredScore += 1 - // TODO: move to add trace - prev.decreaseCycleEdgeWeight(nextVertex) - prev.invalidate() - } - } - - nextVertex.coveredScore = max(0, nextVertex.coveredScore-1) - - if (nextVertex.score == 0) break - - if (nextVertex.instruction in targetInstructions) { - // TODO: move to add trace - coveredScore += nextVertex.coveredScore - } -// nextVertex.cycleEdges.values.forEach { -// coveredScore += it -// } - path.add(nextVertex) - prev = nextVertex - } - prev.invalidate() - targets.forEach { getVertex(it.body.entry.instructions.first()).recomputeScore() } - return path - } - - -} \ No newline at end of file diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/weighted/WeightedGraph.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/weighted/WeightedGraph.kt new file mode 100644 index 000000000..fca47c9f4 --- /dev/null +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/weighted/WeightedGraph.kt @@ -0,0 +1,473 @@ +package org.vorpal.research.kex.asm.analysis.concolic.weighted + +import org.vorpal.research.kex.ExecutionContext +import org.vorpal.research.kex.asm.manager.instantiationManager +import org.vorpal.research.kex.asm.util.AccessModifier +import org.vorpal.research.kex.ktype.KexRtManager.isKexRt +import org.vorpal.research.kex.ktype.KexRtManager.rtMapped +import org.vorpal.research.kex.ktype.KexRtManager.rtUnmapped +import org.vorpal.research.kfg.ir.BasicBlock +import org.vorpal.research.kfg.ir.Method +import org.vorpal.research.kfg.ir.value.instruction.* +import org.vorpal.research.kthelper.collection.mapToArray +import org.vorpal.research.kthelper.collection.queueOf +import org.vorpal.research.kthelper.logging.log +import org.vorpal.research.kthelper.tryOrNull +import kotlin.math.max + +class WeightedGraph( + val ctx: ExecutionContext, + val targets: Set, + val targetInstructions: Set +) { + private val MIN_COVERED_SCORE = 40 + private val MAX_DEPTH = 3 + // nodes size greater than this number will cause MAX_DEPTH = 1 behavior + private val MAX_NODES_SIZE = 1000 + val ISUFFICIENT_PATH_SCORE = 0.1 + + private val nodes = mutableMapOf() + + inner class Vertex(val instruction: Instruction, val predecessors: MutableSet, var coveredScore: Int = 1) { + + private val CYCLE_EDGE_SCORE = 4 + private val upEdges = mutableSetOf() + private val downEdges = mutableSetOf() + private val _cycleEdgesScores = mutableMapOf() + private val beforePathFindingCycleEdgesScores = mutableMapOf() + private var beforePathFindingScore = 0 + private var isValid = false + + val cycleEdgesScores: Map + get() = _cycleEdgesScores + var score: Double = 0.0 + private set + + + override fun toString(): String = instruction.print() + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + + other as Vertex + + return instruction == other.instruction + } + + override fun hashCode(): Int { + return instruction.hashCode() + } + + fun hasSuccessors() = downEdges.isNotEmpty() || _cycleEdgesScores.isNotEmpty() + + fun linkDown(other: Vertex) { + downEdges += other + other.upEdges += this + other.addPredecessors(predecessors + this) + invalidate() + } + + // this function used only in case if we need to expand our graph after call instruction + // so predecessors of other still contain this vertex and other predecessors + fun breakLink(other: Vertex) { + downEdges.remove(other) + } + + fun addCycleEdge(other: Vertex) { + _cycleEdgesScores[other] = CYCLE_EDGE_SCORE + addPredecessors(other.predecessors) + invalidate() + } + + private fun addPredecessors(parentPredecessors: Set) { + predecessors += parentPredecessors + for (v in downEdges) { + v.addPredecessors(predecessors) + } + } + + fun decreaseCycleEdgeWeight(other: Vertex) { + val currentWeight = _cycleEdgesScores[other] + if (currentWeight != null && currentWeight >= 1) { + _cycleEdgesScores[other] = currentWeight - 1 + } + invalidate() + } + + fun reassignCycleEdgesScore(score: Int) { + _cycleEdgesScores.forEach { (k, _) -> + _cycleEdgesScores[k] = score + } + } + + fun decreaseCoveredScore() { + coveredScore = max(0, coveredScore - 1) + invalidate() + } + + fun invalidate() { + isValid = false + predecessors.forEach { it.isValid = false } + } + + fun saveScore() { + beforePathFindingScore = coveredScore + beforePathFindingCycleEdgesScores.clear() + _cycleEdgesScores.forEach { (k, v) -> + beforePathFindingCycleEdgesScores[k] = v + } + } + + fun restoreScore() { + coveredScore = beforePathFindingScore + beforePathFindingCycleEdgesScores.forEach { (k, v) -> + _cycleEdgesScores[k] = v + } + } + + fun recomputeScore(currentMultiplierGraphVertex: MultiplierGraphVertex? = rootMultiplier.downEdges[this]) { + if (isValid) return + score = 0.0 + // if instruction not in target instruction, we are not interested in this coverage + if (instruction in targetInstructions) { + score += coveredScore + } + // recompute score for all successors + for (vertex in downEdges) { + val multiplierPathScore = currentMultiplierGraphVertex?.downEdges?.get(vertex)?.scoreMultiplier ?: 1.0 + // path is unreachable or was tried too many times + if (multiplierPathScore <= ISUFFICIENT_PATH_SCORE) { + continue + } + vertex.recomputeScore(currentMultiplierGraphVertex?.downEdges?.get(vertex)) + score += vertex.score * multiplierPathScore + } + // add all cycle edges scores + for (s in _cycleEdgesScores.values) { + score += s + } + isValid = true + } + + fun nextVertex(currentMultiplierGraphVertex: MultiplierGraphVertex?): Vertex? { + if (downEdges.isEmpty() && _cycleEdgesScores.isEmpty()) return null + + val cycleEdgesScoresWithMultiplier = _cycleEdgesScores.mapValues { + val scoreMultiplier = currentMultiplierGraphVertex?.downEdges?.get(it.key)?.scoreMultiplier ?: 1.0 + if (scoreMultiplier < ISUFFICIENT_PATH_SCORE) 0.0 else it.value * scoreMultiplier + } + val totalCycleEdgesScore = cycleEdgesScoresWithMultiplier.values.sum() + // if a cycle is available, when explore it + if (totalCycleEdgesScore > 0.1) { + val random = ctx.random.nextDouble(totalCycleEdgesScore) + var current = 0.0 + for (cycleEdge in cycleEdgesScoresWithMultiplier) { + current += cycleEdge.value + if (random < current) { + val nextVertex = cycleEdge.key + decreaseCycleEdgeWeight(nextVertex) + // jump to some previous vertexes, state computation is needed + invalidate() + recomputeScores() + return nextVertex + } + } + } + val downEdgesScoresWithMultiplier = downEdges.associateWith { + val scoreMultiplier = currentMultiplierGraphVertex?.downEdges?.get(it)?.scoreMultiplier ?: 1.0 + if (scoreMultiplier < ISUFFICIENT_PATH_SCORE) 0.0 else it.score * scoreMultiplier + } + val totalDownEdgesScore = downEdgesScoresWithMultiplier.values.sum() + if (totalDownEdgesScore < 0.1) return null + + val random = ctx.random.nextDouble(totalDownEdgesScore) + var current = 0.0 + for (edge in downEdgesScoresWithMultiplier) { + current += edge.value + if (random < current) { + return edge.key + } + } + + // everything is unreachable + return null + } + } + + fun recomputeScores() = targets.forEach { getVertex(it.body.flatten().first()).restoreScore() } + + fun reassignCyclesEdgesScores(score: Int) { + nodes.forEach { (_, v) -> + v.reassignCycleEdgesScore(score) + } + } + + data class QueueEntry(val prev: Vertex?, val block: BasicBlock, val index: Int, val depth: Int) + + // previous is not null only in case if previous is a call instruction + fun getVertex(instruction: Instruction, previous: Vertex? = null): Vertex { + if (instruction in nodes) { + return nodes[instruction]!! + } else { + val method = instruction.parent.method + //ktassert(instruction == method.body.entry.first() || method.body.catchEntries.any { instruction == it.first() }) + + (previous?.instruction as? CallInst)?.method?.body?.flatten()?.filterIsInstance()?.forEach { + previous.breakLink(getVertex(it)) + } + + val queue = queueOf() + queue.add(QueueEntry(previous, method.body.entry, 0, 0)) + queue.addAll(method.body.catchEntries.map { QueueEntry(null, it, 0, 0) }) + val visited = mutableSetOf>() + val resolves = mutableMapOf>() + + while (queue.isNotEmpty()) { + val (prev, block, index, depth) = queue.poll() + //val (prev, block, index) = queue.poll() + //if (nodes.size > MAX_NODES_SIZE) break + if (depth >= MAX_DEPTH) continue + + val current = block.instructions[index] + if (prev?.instruction to current in visited) continue + visited += prev?.instruction to current + + val vertex = nodes.getOrPut(current) { + Vertex( + current, + prev?.predecessors?.toMutableSet() ?: mutableSetOf() + ) + } + + if (prev?.predecessors?.contains(vertex) == true) { + prev.addCycleEdge(vertex) + } else { + prev?.linkDown(vertex) + } + + when (current) { + is CallInst -> { + val resolvedMethods = resolves.getOrPut(current) { + when (current.opcode) { + CallOpcode.STATIC -> listOf(current.method) + CallOpcode.SPECIAL -> listOf(current.method) + CallOpcode.INTERFACE, CallOpcode.VIRTUAL -> { + val currentMethod = current.method + + val targetPackages = targets.map { it.klass.pkg }.toMutableSet() + + val retTypeMapped = currentMethod.returnType.rtMapped + val argTypesMapped = currentMethod.argTypes.mapToArray { it.rtMapped } + val retTypeUnmapped = currentMethod.returnType.rtUnmapped + val argTypesUnmapped = currentMethod.argTypes.mapToArray { it.rtUnmapped } + instantiationManager.getAllConcreteSubtypes( + currentMethod.klass, + AccessModifier.Private + ) + .filter { klass -> targetPackages.any { it.isParent(klass.pkg) } } + .mapNotNullTo(mutableSetOf()) { + tryOrNull { + if (it.isKexRt) { + it.getMethod(currentMethod.name, retTypeMapped, *argTypesMapped) + } else { + it.getMethod( + currentMethod.name, + retTypeUnmapped, + *argTypesUnmapped + ) + } + } + } + .filter { it.hasBody } + } + }.filter { it.hasBody } + } + var connectedExits = false + for (candidate in resolvedMethods) { + // if nodes too much stop exploring + if (nodes.size < MAX_NODES_SIZE) { + queue += QueueEntry(vertex, candidate.body.entry, 0, depth + 1) + queue.addAll(candidate.body.catchEntries.map { QueueEntry(null, it, 0, depth + 1) }) + } + + candidate.body.flatten().filterIsInstance().forEach { + connectedExits = true + val returnVertex = nodes.getOrPut(it) { + Vertex(it, mutableSetOf()).also { ver -> + // TODO: maybe instead of using the same link and then deleting it on expansion + if (depth + 1 >= MAX_DEPTH || nodes.size >= MAX_NODES_SIZE) { + vertex.linkDown(ver) + } + } + } + queue += QueueEntry(returnVertex, block, index + 1, depth) + } + } + + if (!connectedExits) { + queue += QueueEntry(vertex, block, index + 1, depth) + } + } + + is TerminateInst -> current.successors.forEach { + queue += QueueEntry(vertex, it, 0, depth) + } + + else -> queue += QueueEntry(vertex, block, index + 1, depth) + } + } + + return nodes.getOrPut(instruction) { Vertex(instruction, mutableSetOf()) }.also { it.recomputeScore() } + } + + } + + private var expectedTrace: List? = null + + fun addTrace(trace: List) = try { + var prev: Vertex? = null + for ((i, inst) in trace.withIndex()) { + if (expectedTrace != null) { + val expectedInstruction = expectedTrace?.getOrNull(i)?.instruction + if (expectedInstruction == null) { + expectedTrace = null + } else { + if (inst != expectedInstruction) { + changePathScoreMultiplier(expectedTrace!!.subList(0, i)) + expectedTrace = null + } + } + } + + val current = getVertex(inst, prev) + current.decreaseCoveredScore() + + if (prev == null) { + prev = current + continue + } + + if (prev.cycleEdgesScores.contains(current)) { + prev.decreaseCycleEdgeWeight(current) + } else if (prev.predecessors.contains(current)) { + prev.addCycleEdge(current) + } else if (!prev.predecessors.contains(current)) { + prev.linkDown(current) + } + + prev = current + } + + prev?.invalidate() + targets.forEach { + val methodRootVertex = getVertex(it.body.entry.instructions.first()) + methodRootVertex.recomputeScore() + } + } catch (e: Exception) { + log.debug(e.stackTraceToString()) + } + + private fun saveScores() { + nodes.forEach { + it.value.saveScore() + } + } + + private fun restoreScores() { + nodes.forEach { + it.value.restoreScore() + } + } + + // for transforming path selector to the symbolic uncomment lines below and modify add trace + //var scoreChangeHistory: MutableList Unit>> = mutableListOf() + + fun getPath(root: Instruction): List { + saveScores() + //scoreChangeHistory = mutableListOf() + + var prev = getVertex(root) + // var curIndex = 0 + + prev.decreaseCoveredScore() + //val refToRoot = prev + //scoreChangeHistory.add(Pair(curIndex) { refToRoot.decreaseCoveredScore() }) + + var prevUnreachableGraphVertex = rootMultiplier.downEdges[prev] + val path = mutableListOf(prev) + var coveredScore = 0 + while (prev.hasSuccessors() && coveredScore < MIN_COVERED_SCORE) { + //curIndex += 1 + val nextVertex = prev.nextVertex(prevUnreachableGraphVertex) ?: break + + nextVertex.decreaseCoveredScore() + //scoreChangeHistory.add(Pair(curIndex) { nextVertex.decreaseCoveredScore() }) + + if (nextVertex.instruction in targetInstructions) { + coveredScore += nextVertex.coveredScore + } + + path.add(nextVertex) + prev = nextVertex + prevUnreachableGraphVertex = prevUnreachableGraphVertex?.downEdges?.get(nextVertex) + } + restoreScores() + prev.invalidate() + expectedTrace = path + return path + } + + data class MultiplierGraphVertex( + val vertex: Vertex?, + val downEdges: MutableMap = mutableMapOf(), + var scoreMultiplier: Double = 1.0 + ) + + private val rootMultiplier = MultiplierGraphVertex(null) + + fun changePathScoreMultiplier(path: List, scoreMultiplierChange: Double = 0.5) { + if (path.isEmpty()) return + var currentMultiplierVertex = rootMultiplier + for (vertex in path) { + val nextMultiplierVertex = currentMultiplierVertex.downEdges.getOrPut(vertex) { + MultiplierGraphVertex(vertex) + } + if (nextMultiplierVertex.scoreMultiplier <= ISUFFICIENT_PATH_SCORE) return + currentMultiplierVertex = nextMultiplierVertex + } + currentMultiplierVertex.vertex?.invalidate() + currentMultiplierVertex.scoreMultiplier *= scoreMultiplierChange + getVertex(path.first().instruction).restoreScore() + } + +// fun addUnreachable(unreachablePath: List) { +// if (unreachablePath.isEmpty()) { +// targets.forEach { +// val methodRootVertex = getVertex(it.body.entry.instructions.first()) +// methodRootVertex.recomputeScore() +// } +// return +// } +// restoreScores() +// scoreChangeHistory.filter { it.first < unreachablePath.size-1 }.forEach { +// it.second.invoke() +// } +// var currentUnreachableVertex = rootUnreachable +// for (vertex in unreachablePath) { +// val nextUnreachableGraphVertex = currentUnreachableVertex.downEdges.getOrPut(vertex) { +// UnreachableGraphVertex(vertex) +// } +// // some sub-path already unreachable +// if (nextUnreachableGraphVertex.isTerminal) return +// currentUnreachableVertex = nextUnreachableGraphVertex +// } +// currentUnreachableVertex.isTerminal = true +// currentUnreachableVertex.vertex?.invalidate() +// targets.forEach { +// val methodRootVertex = getVertex(it.body.entry.instructions.first()) +// methodRootVertex.recomputeScore() +// } +// } + +} \ No newline at end of file diff --git a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/weighted/WeightedPathSelector.kt similarity index 70% rename from kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt rename to kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/weighted/WeightedPathSelector.kt index 76c61a542..e6df4dd45 100644 --- a/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/coverage/ExperimentPathSelector.kt +++ b/kex-runner/src/main/kotlin/org/vorpal/research/kex/asm/analysis/concolic/weighted/WeightedPathSelector.kt @@ -1,9 +1,6 @@ -package org.vorpal.research.kex.asm.analysis.concolic.coverage +package org.vorpal.research.kex.asm.analysis.concolic.weighted import kotlinx.collections.immutable.* -import kotlinx.coroutines.currentCoroutineContext -import kotlinx.coroutines.withContext -import kotlinx.coroutines.yield import org.vorpal.research.kex.ExecutionContext import org.vorpal.research.kex.asm.analysis.concolic.ConcolicPathSelector import org.vorpal.research.kex.asm.analysis.concolic.ConcolicPathSelectorManager @@ -33,76 +30,44 @@ import org.vorpal.research.kfg.type.Type import org.vorpal.research.kfg.type.TypeFactory import org.vorpal.research.kthelper.assert.unreachable import org.vorpal.research.kthelper.logging.log +import kotlin.math.pow -//data class TraverserState( -// val symbolicState: PersistentSymbolicState, -// val valueMap: PersistentMap, -// val stackTrace: PersistentList, -// val typeInfo: PersistentMap, -// val blockPath: PersistentList, -// val nullCheckedTerms: PersistentSet, -// val boundCheckedTerms: PersistentSet>, -// val typeCheckedTerms: PersistentMap -//) { -// fun mkTerm(value: Value): Term = when (value) { -// is Constant -> term { const(value) } -// else -> valueMap.getValue(value) -// } -// -// fun copyTermInfo(from: Term, to: Term): TraverserState = this.copy( -// nullCheckedTerms = when (from) { -// in nullCheckedTerms -> nullCheckedTerms.add(to) -// else -> nullCheckedTerms -// }, -// typeCheckedTerms = when (from) { -// in typeCheckedTerms -> typeCheckedTerms.put(to, typeCheckedTerms[from]!!) -// else -> typeCheckedTerms -// } -// ) -// -// operator fun plus(state: PersistentSymbolicState): TraverserState = this.copy( -// symbolicState = this.symbolicState + state -// ) -// -// operator fun plus(clause: StateClause): TraverserState = this.copy( -// symbolicState = this.symbolicState + clause -// ) -// -// operator fun plus(clause: PathClause): TraverserState = this.copy( -// symbolicState = this.symbolicState + clause -// ) -// -// operator fun plus(basicBlock: BasicBlock): TraverserState = this.copy( -// blockPath = this.blockPath.add(basicBlock) -// ) -//} - -class ExperimentPathSelectorManager ( + +class WeightedPathSelectorManager ( override val ctx: ExecutionContext, override val targets: Set ) : ConcolicPathSelectorManager { private val targetInstructions = targets.flatMapTo(mutableSetOf()) { it.body.flatten() } private val coveredInstructions = mutableSetOf() + private var stage = 1 + private val MAX_STAGE = 3 - val weightedGraph = WeightedGraph(targets, targetInstructions) + val weightedGraph = WeightedGraph(ctx, targets, targetInstructions) fun isCovered(): Boolean { - val result = coveredInstructions.containsAll(targetInstructions) || - weightedGraph.targets.sumOf { weightedGraph.getVertex(it.body.entry.instructions.first()).score } == 0 - log.debug("Temp") - return result + val isStageCovered = coveredInstructions.containsAll(targetInstructions) || weightedGraph.targets.all { + weightedGraph.getVertex(it.body.entry.instructions.first()).score < weightedGraph.ISUFFICIENT_PATH_SCORE + } + + if (isStageCovered) { + stage++ + if (stage <= MAX_STAGE) { + weightedGraph.reassignCyclesEdgesScores(10.0.pow(stage.toDouble()).toInt()) + } + } + return stage > MAX_STAGE } fun addCoverage(trace: List) { - coveredInstructions += trace + coveredInstructions += trace.filter { it in targetInstructions } } - override fun createPathSelectorFor(target: Method): ConcolicPathSelector = ExperimentPathSelector(this) + override fun createPathSelectorFor(target: Method): ConcolicPathSelector = WeightedPathSelector(this) } -class ExperimentPathSelector( - private val manager: ExperimentPathSelectorManager +class WeightedPathSelector( + private val manager: WeightedPathSelectorManager ) : ConcolicPathSelector { override val ctx: ExecutionContext @@ -129,12 +94,15 @@ class ExperimentPathSelector( val bestMethod = manager.targets.maxBy { manager.weightedGraph.getVertex(it.body.entry.instructions.first()).score } val root = bestMethod.body.entry.instructions.first() val path = manager.weightedGraph.getPath(root) + if (path.size <= 1) { + return Pair(bestMethod, persistentSymbolicState()) + } val state = processMethod(bestMethod, path) return Pair(bestMethod, state.symbolicState) } - protected val Type.symbolicType: KexType get() = kexType.rtMapped - protected val org.vorpal.research.kfg.ir.Class.symbolicClass: KexType get() = kexType.rtMapped + private val Type.symbolicType: KexType get() = kexType.rtMapped + private val org.vorpal.research.kfg.ir.Class.symbolicClass: KexType get() = kexType.rtMapped val types: TypeFactory get() = ctx.types @@ -142,10 +110,10 @@ class ExperimentPathSelector( val values: ValueFactory get() = ctx.values - protected open suspend fun processMethod(method: Method, path: List): TraverserState { + private suspend fun processMethod(method: Method, path: List): TraverserState { val thisValue = values.getThis(method.klass) val initialArguments = buildMap { - val values = this@ExperimentPathSelector.values + val values = this@WeightedPathSelector.values if (!method.isStatic) { this[thisValue] = `this`(method.klass.symbolicClass) } @@ -193,7 +161,8 @@ class ExperimentPathSelector( return getPersistentState(method, initialState, path) } - suspend fun findFirstUnreachable(method: Method, pathStates: List): Int { + // binary search for first unreachable + private suspend fun findFirstUnreachable(method: Method, pathStates: List): Int { if (method.checkAsync(ctx, pathStates.last().symbolicState) != null) return -1 var startRange = 0 var endRange = pathStates.size - 1 @@ -210,51 +179,31 @@ class ExperimentPathSelector( return endRange } - suspend fun getPersistentState(method: Method, state: TraverserState, path: List): TraverserState { + private suspend fun getPersistentState(method: Method, state: TraverserState, path: List): TraverserState { var currentState: TraverserState = state - val instList = path.map { it.instruction } - log.debug(instList.toString()) - var pathStates = mutableListOf() + // pathStates contains history of traversal state + val pathStates = mutableListOf(state) for (i in 0 until path.size-1) { val inst = path[i].instruction val nextInst = path.getOrNull(i+1)?.instruction - val newState = traverseInstruction(currentState, inst, nextInst) - if (newState == null) return currentState + val newState = traverseInstruction(currentState, inst, nextInst) ?: break currentState = newState pathStates.add(currentState) } + // for finding first unreachable we only need everything before path clause + chosen path val lastPathClause = pathStates.indexOfFirst { it.symbolicState.path.size == currentState.symbolicState.path.size } - val firstUnreachable = findFirstUnreachable(method, pathStates.subList(0, lastPathClause+1)) - if (firstUnreachable != -1) { - // go down until vertex with multiple possible paths - // it is needed because path clause added by this vertex is causing unreachability, the inst itself may be reachable - manager.weightedGraph.unreachables.add(path.slice(0..firstUnreachable+1)) - manager.weightedGraph.getVertex(path[firstUnreachable].instruction).invalidate() + // -1 means the built path is reachable + if (firstUnreachable != -1) { + manager.weightedGraph.changePathScoreMultiplier(path.slice(0..firstUnreachable), 0.0) return pathStates[firstUnreachable-1] } -// val concreteTypes: MutableMap = mutableMapOf() -// currentState.symbolicState.clauses.forEach { clause -> -// clause.predicate.operands.forEach { term -> -// if (term.type.javaName.contains("java.util")) { -// concreteTypes[term] = -// instantiationManager.getConcreteType(term.type, manager.ctx.cm, ctx.accessLevel, ctx.random) -// } -// term.subTerms.forEach { subTerm -> -// if (subTerm.type.javaName.contains("java.util")) { -// concreteTypes[subTerm] = -// instantiationManager.getConcreteType(subTerm.type, manager.ctx.cm, ctx.accessLevel, ctx.random) -// } -// } -// } -// } - //currentState.symbolicState.concreteTypes = concreteTypes.toPersistentMap() - val resultState = pathStates.getOrNull(lastPathClause+1) ?: currentState + val resultState = pathStates.getOrNull(lastPathClause + 1) ?: currentState return resultState } - suspend fun traverseInstruction(state: TraverserState, inst: Instruction, nextInstruction: Instruction?): TraverserState? { + private fun traverseInstruction(state: TraverserState, inst: Instruction, nextInstruction: Instruction?): TraverserState? { try { return when (inst) { is ArrayLoadInst -> traverseArrayLoadInst(state, inst, nextInstruction) @@ -262,7 +211,7 @@ class ExperimentPathSelector( is BinaryInst -> traverseBinaryInst(state, inst) is CallInst -> traverseCallInst(state, inst, nextInstruction) is CastInst -> traverseCastInst(state, inst, nextInstruction) - is CatchInst -> traverseCatchInst(state, inst) + is CatchInst -> traverseCatchInst(state) is CmpInst -> traverseCmpInst(state, inst) is EnterMonitorInst -> traverseEnterMonitorInst(state, inst, nextInstruction) is ExitMonitorInst -> traverseExitMonitorInst(state, inst) @@ -280,26 +229,31 @@ class ExperimentPathSelector( is SwitchInst -> traverseSwitchInst(state, inst, nextInstruction) is TableSwitchInst -> traverseTableSwitchInst(state, inst, nextInstruction) is ThrowInst -> traverseThrowInst(state, inst, nextInstruction) - is UnreachableInst -> traverseUnreachableInst(state, inst) - is UnknownValueInst -> traverseUnknownValueInst(state, inst) + is UnreachableInst -> traverseUnreachableInst() + is UnknownValueInst -> traverseUnknownValueInst(inst) else -> unreachable("Unknown instruction ${inst.print()}") } } catch (e: Exception) { - log.debug(e.toString()) + log.debug(e.stackTraceToString()) return state } } - fun nullCheck( + sealed class CheckResult(val state: TraverserState) + + class SuccessCheck(state: TraverserState): CheckResult(state) + class UnsuccessfulCheck(state: TraverserState): CheckResult(state) + + private fun nullCheck( traverserState: TraverserState, inst: Instruction, nextInstruction: Instruction?, term: Term - ): Pair { - if (term in traverserState.nullCheckedTerms) return Pair(true, traverserState) - if (term is ConstClassTerm) return Pair(true, traverserState) - if (term is StaticClassRefTerm) return Pair(true, traverserState) - if (term.isThis) return Pair(true, traverserState) + ): CheckResult { + if (term in traverserState.nullCheckedTerms) return SuccessCheck(traverserState) + if (term is ConstClassTerm) return SuccessCheck(traverserState) + if (term is StaticClassRefTerm) return SuccessCheck(traverserState) + if (term.isThis) return SuccessCheck(traverserState) val nullityClause = PathClause( PathClauseType.NULL_CHECK, @@ -307,21 +261,21 @@ class ExperimentPathSelector( path { (term eq null) equality true } ) return if (nextInstruction is CatchInst) { - Pair(false, traverserState + nullityClause) + UnsuccessfulCheck(traverserState + nullityClause) } else { - Pair(true, traverserState + nullityClause.inverse()) + SuccessCheck(traverserState + nullityClause.inverse()) } } - fun boundsCheck( + private fun boundsCheck( traverserState: TraverserState, inst: Instruction, nextInstruction: Instruction?, index: Term, length: Term - ): Pair { - if (index to index in traverserState.boundCheckedTerms) return Pair(true, traverserState) + ): CheckResult { + if (index to index in traverserState.boundCheckedTerms) return SuccessCheck(traverserState) val zeroClause = PathClause( PathClauseType.BOUNDS_CHECK, inst, @@ -334,25 +288,25 @@ class ExperimentPathSelector( ) // TODO: think about other case return if (nextInstruction is CatchInst) { - Pair(false, traverserState + zeroClause) + UnsuccessfulCheck(traverserState + zeroClause) } else { - Pair(true, traverserState + zeroClause.inverse() + lengthClause.inverse()) + SuccessCheck(traverserState + zeroClause.inverse() + lengthClause.inverse()) } } - fun typeCheck( + private fun typeCheck( state: TraverserState, inst: Instruction, nextInstruction: Instruction?, term: Term, type: KexType - ): Pair { - if (type !is KexPointer) return Pair(true, state) + ): CheckResult { + if (type !is KexPointer) return SuccessCheck(state) val previouslyCheckedType = state.typeCheckedTerms[term] val currentlyCheckedType = type.getKfgType(ctx.types) if (previouslyCheckedType != null && currentlyCheckedType.isSubtypeOfCached(previouslyCheckedType)) { - return Pair(true, state) + return SuccessCheck(state) } val typeClause = PathClause( @@ -362,20 +316,20 @@ class ExperimentPathSelector( ) return if (nextInstruction is CatchInst) { - Pair(false, state + typeClause) + UnsuccessfulCheck(state + typeClause) } else { - Pair(true, state + typeClause.inverse()) + SuccessCheck(state + typeClause.inverse()) } } - fun newArrayBoundsCheck( + private fun newArrayBoundsCheck( state: TraverserState, inst: Instruction, nextInstruction: Instruction?, index: Term - ): Pair { - if (index to index in state.boundCheckedTerms) return Pair(true, state) + ): CheckResult { + if (index to index in state.boundCheckedTerms) return SuccessCheck(state) val zeroClause = PathClause( PathClauseType.BOUNDS_CHECK, @@ -386,44 +340,43 @@ class ExperimentPathSelector( val zeroCheckConstraints = persistentSymbolicState() + zeroClause if (nextInstruction is CatchInst) { - return Pair(false, state + zeroCheckConstraints) + return UnsuccessfulCheck(state + zeroCheckConstraints) } else { val res = state + noExceptionConstraints - return Pair(true, res.copy(boundCheckedTerms = res.boundCheckedTerms.add(index to index)) + noExceptionConstraints) + return SuccessCheck(res.copy(boundCheckedTerms = res.boundCheckedTerms.add(index to index)) + noExceptionConstraints) } } - protected open suspend fun traverseArrayLoadInst( + private fun traverseArrayLoadInst( traverserState: TraverserState, inst: ArrayLoadInst, nextInstruction: Instruction? - ): TraverserState? { + ): TraverserState { val arrayTerm = traverserState.mkTerm(inst.arrayRef) val indexTerm = traverserState.mkTerm(inst.index) val res = generate(inst.type.symbolicType) if (arrayTerm is NullTerm) { - return nullCheck(traverserState, inst, nextInstruction, arrayTerm).second + return nullCheck(traverserState, inst, nextInstruction, arrayTerm).state } val clause = StateClause(inst, state { res equality arrayTerm[indexTerm].load() }) var result = nullCheck(traverserState, inst, nextInstruction, arrayTerm) - if (!result.first) { - return result.second - } - result = boundsCheck(result.second, inst, nextInstruction, indexTerm, arrayTerm.length()) - if (!result.first) { - return result.second - } - return result.second.copy( - symbolicState = result.second.symbolicState + clause, - valueMap = result.second.valueMap.put(inst, res) + if (result is UnsuccessfulCheck) return result.state + + result = boundsCheck(result.state, inst, nextInstruction, indexTerm, arrayTerm.length()) + if (result is UnsuccessfulCheck) return result.state + val checkedState = result.state + + return checkedState.copy( + symbolicState = checkedState.symbolicState + clause, + valueMap = checkedState.valueMap.put(inst, res) ) } - protected open suspend fun traverseArrayStoreInst( + private fun traverseArrayStoreInst( traverserState: TraverserState, inst: ArrayStoreInst, nextInstruction: Instruction? @@ -433,23 +386,21 @@ class ExperimentPathSelector( val valueTerm = traverserState.mkTerm(inst.value) if (arrayTerm is NullTerm) { - return nullCheck(traverserState, inst, nextInstruction, arrayTerm).second + return nullCheck(traverserState, inst, nextInstruction, arrayTerm).state } val clause = StateClause(inst, state { arrayTerm[indexTerm].store(valueTerm) }) var result = nullCheck(traverserState, inst, nextInstruction, arrayTerm) - if (!result.first) { - return result.second - } - result = boundsCheck(result.second, inst, nextInstruction, indexTerm, arrayTerm.length()) - if (!result.first) { - return result.second - } - return result.second + clause + if (result is UnsuccessfulCheck) return result.state + + result = boundsCheck(result.state, inst, nextInstruction, indexTerm, arrayTerm.length()) + if (result is UnsuccessfulCheck) return result.state + + return result.state + clause } - protected open suspend fun traverseBinaryInst(traverserState: TraverserState, inst: BinaryInst): TraverserState { + private fun traverseBinaryInst(traverserState: TraverserState, inst: BinaryInst): TraverserState { val lhvTerm = traverserState.mkTerm(inst.lhv) val rhvTerm = traverserState.mkTerm(inst.rhv) val resultTerm = generate(inst.type.symbolicType) @@ -464,7 +415,7 @@ class ExperimentPathSelector( ) } - protected open suspend fun traverseBranchInst( + private fun traverseBranchInst( traverserState: TraverserState, inst: BranchInst, nextInstruction: Instruction? @@ -478,15 +429,14 @@ class ExperimentPathSelector( ) val falseClause = trueClause.inverse() - if (nextInstruction in inst.trueSuccessor) { - return traverserState + trueClause + inst.parent - } - else return traverserState + falseClause + inst.parent + return if (nextInstruction in inst.trueSuccessor) { + traverserState + trueClause + inst.parent + } else traverserState + falseClause + inst.parent } val callResolver: SymbolicCallResolver = DefaultCallResolver(ctx) - protected open suspend fun traverseCallInst( + private fun traverseCallInst( traverserState: TraverserState, inst: CallInst, nextInstruction: Instruction? @@ -498,14 +448,14 @@ class ExperimentPathSelector( val argumentTerms = inst.args.map { traverserState.mkTerm(it) } val candidates = callResolver.resolve(traverserState, inst) - var (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, callee) - if (!isCheckSuccess) { - return result - } + val checkResult = nullCheck(traverserState, inst, nextInstruction, callee) + if (checkResult is UnsuccessfulCheck) return checkResult.state + val checkedState = checkResult.state + val candidate = candidates.find { !it.body.entry.isEmpty && it.body.entry.instructions[0] == nextInstruction } - result = when { + return when { candidate == null -> { - var varState = result + var varState = checkedState val receiver = when { inst.isNameDefined -> { val res = generate(inst.type.symbolicType) @@ -526,12 +476,11 @@ class ExperimentPathSelector( varState + callClause } - else -> processMethodCall(result, inst, nextInstruction, candidate, callee, argumentTerms) + else -> processMethodCall(checkedState, inst, nextInstruction, candidate, callee, argumentTerms) } - return result } - protected open suspend fun traverseCastInst( + private fun traverseCastInst( traverserState: TraverserState, inst: CastInst, nextInstruction: Instruction? @@ -543,23 +492,21 @@ class ExperimentPathSelector( state { resultTerm equality (operandTerm `as` resultTerm.type) } ) - var (isCheckSuccess, result) = typeCheck(traverserState, inst, nextInstruction, operandTerm, resultTerm.type) - if (!isCheckSuccess) { - return result - } - result = result.copy( - symbolicState = result.symbolicState + clause, - valueMap = result.valueMap.put(inst, resultTerm) - ).copyTermInfo(operandTerm, resultTerm) + val checkResult = typeCheck(traverserState, inst, nextInstruction, operandTerm, resultTerm.type) + if (checkResult is UnsuccessfulCheck) return checkResult.state + val checkedState = checkResult.state - return result + return checkedState.copy( + symbolicState = checkedState.symbolicState + clause, + valueMap = checkedState.valueMap.put(inst, resultTerm) + ).copyTermInfo(operandTerm, resultTerm) } - protected open suspend fun traverseCatchInst(traverserState: TraverserState, inst: CatchInst): TraverserState { + private fun traverseCatchInst(traverserState: TraverserState): TraverserState { return traverserState } - protected open suspend fun traverseCmpInst( + private fun traverseCmpInst( traverserState: TraverserState, inst: CmpInst ): TraverserState { @@ -577,7 +524,7 @@ class ExperimentPathSelector( ) } - protected open suspend fun traverseEnterMonitorInst( + private fun traverseEnterMonitorInst( traverserState: TraverserState, inst: EnterMonitorInst, nextInstruction: Instruction? @@ -588,14 +535,12 @@ class ExperimentPathSelector( state { enterMonitor(monitorTerm) } ) - val (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, monitorTerm) - if (!isCheckSuccess) { - return result - } - return result + clause + val checkResult = nullCheck(traverserState, inst, nextInstruction, monitorTerm) + if (checkResult is UnsuccessfulCheck) return checkResult.state + return checkResult.state + clause } - protected open suspend fun traverseExitMonitorInst( + private fun traverseExitMonitorInst( traverserState: TraverserState, inst: ExitMonitorInst ): TraverserState { @@ -607,7 +552,7 @@ class ExperimentPathSelector( return traverserState + clause } - protected open suspend fun traverseFieldLoadInst( + private fun traverseFieldLoadInst( traverserState: TraverserState, inst: FieldLoadInst, nextInstruction: Instruction? @@ -619,7 +564,7 @@ class ExperimentPathSelector( } if (objectTerm is NullTerm) { - return nullCheck(traverserState, inst, nextInstruction, objectTerm).second + return nullCheck(traverserState, inst, nextInstruction, objectTerm).state } val res = generate(inst.type.symbolicType) @@ -628,26 +573,27 @@ class ExperimentPathSelector( state { res equality objectTerm.field(field.type.symbolicType, field.name).load() } ) - val (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, objectTerm) - if (!isCheckSuccess) return result + val checkResult = nullCheck(traverserState, inst, nextInstruction, objectTerm) + if (checkResult is UnsuccessfulCheck) return checkResult.state + val checkedState = checkResult.state val newNullChecked = when { field.isStatic && field.isFinal -> when (field.defaultValue) { - null -> result.nullCheckedTerms.add(res) - ctx.values.nullConstant -> result.nullCheckedTerms - else -> result.nullCheckedTerms.add(res) + null -> checkedState.nullCheckedTerms.add(res) + ctx.values.nullConstant -> checkedState.nullCheckedTerms + else -> checkedState.nullCheckedTerms.add(res) } - else -> result.nullCheckedTerms + else -> checkedState.nullCheckedTerms } - return result.copy( - symbolicState = result.symbolicState + clause, - valueMap = result.valueMap.put(inst, res), + return checkedState.copy( + symbolicState = checkedState.symbolicState + clause, + valueMap = checkedState.valueMap.put(inst, res), nullCheckedTerms = newNullChecked ) } - protected open suspend fun traverseFieldStoreInst( + private fun traverseFieldStoreInst( traverserState: TraverserState, inst: FieldStoreInst, nextInstruction: Instruction? @@ -658,7 +604,7 @@ class ExperimentPathSelector( } if (objectTerm is NullTerm) { - return nullCheck(traverserState, inst, nextInstruction, objectTerm).second + return nullCheck(traverserState, inst, nextInstruction, objectTerm).state } val valueTerm = traverserState.mkTerm(inst.value) @@ -667,16 +613,17 @@ class ExperimentPathSelector( state { objectTerm.field(inst.field.type.symbolicType, inst.field.name).store(valueTerm) } ) - val (isCheckSuccess, result) = nullCheck(traverserState, inst, nextInstruction, objectTerm) - if (!isCheckSuccess) return result + val checkResult = nullCheck(traverserState, inst, nextInstruction, objectTerm) + if (checkResult is UnsuccessfulCheck) return checkResult.state + val checkedState = checkResult.state - return result.copy( - symbolicState = result.symbolicState + clause, - valueMap = result.valueMap.put(inst, valueTerm) + return checkedState.copy( + symbolicState = checkedState.symbolicState + clause, + valueMap = checkedState.valueMap.put(inst, valueTerm) ) } - protected open suspend fun traverseInstanceOfInst( + private fun traverseInstanceOfInst( traverserState: TraverserState, inst: InstanceOfInst ): TraverserState { @@ -703,9 +650,9 @@ class ExperimentPathSelector( ) } - val invokeDynamicResolver: SymbolicInvokeDynamicResolver = DefaultCallResolver(ctx) + private val invokeDynamicResolver: SymbolicInvokeDynamicResolver = DefaultCallResolver(ctx) - protected open suspend fun traverseInvokeDynamicInst( + private fun traverseInvokeDynamicInst( traverserState: TraverserState, inst: InvokeDynamicInst ): TraverserState? { @@ -718,7 +665,7 @@ class ExperimentPathSelector( } } - protected open suspend fun processMethodCall( + private fun processMethodCall( traverserState: TraverserState, inst: Instruction, nextInstruction: Instruction?, @@ -745,11 +692,11 @@ class ExperimentPathSelector( ) else -> { - var (isCheckSuccess, result) = typeCheck(traverserState, inst, nextInstruction, callee, candidate.klass.symbolicClass) - if (!isCheckSuccess) { - return result - } - result = when { + val checkResult = typeCheck(traverserState, inst, nextInstruction, callee, candidate.klass.symbolicClass) + if (checkResult is UnsuccessfulCheck) return checkResult.state + val checkedState = checkResult.state + + return when { candidate.klass.asType.isSubtypeOfCached(callee.type.getKfgType(types)) -> { val newCalleeTerm = generate(candidate.klass.symbolicClass) val convertClause = StateClause(inst, state { @@ -761,24 +708,23 @@ class ExperimentPathSelector( else -> term } }.toPersistentMap() - result.copy( - symbolicState = result.symbolicState + convertClause + checkedState.copy( + symbolicState = checkedState.symbolicState + convertClause ).copyTermInfo(callee, newCalleeTerm) } else -> traverserState }.copy( valueMap = newValueMap, - stackTrace = result.stackTrace.add( - SymbolicStackTraceElement(inst.parent.method, inst, result.valueMap) + stackTrace = checkedState.stackTrace.add( + SymbolicStackTraceElement(inst.parent.method, inst, checkedState.valueMap) ) ) - return result } } } - protected open suspend fun traverseNewArrayInst( + private fun traverseNewArrayInst( traverserState: TraverserState, inst: NewArrayInst, nextInstruction: Instruction? @@ -789,11 +735,9 @@ class ExperimentPathSelector( var result: TraverserState = traverserState dimensions.forEach { dimension -> - val r = newArrayBoundsCheck(traverserState, inst, nextInstruction, dimension) - if (!r.first) { - return result - } - result = r.second + val checkResult = newArrayBoundsCheck(traverserState, inst, nextInstruction, dimension) + if (checkResult is UnsuccessfulCheck) return checkResult.state + result = checkResult.state } return result.copy( @@ -805,7 +749,7 @@ class ExperimentPathSelector( ) } - protected open suspend fun traverseNewInst( + private fun traverseNewInst( traverserState: TraverserState, inst: NewInst ): TraverserState { @@ -823,7 +767,7 @@ class ExperimentPathSelector( ) } - protected open suspend fun traversePhiInst( + private fun traversePhiInst( traverserState: TraverserState, inst: PhiInst ): TraverserState { @@ -834,7 +778,7 @@ class ExperimentPathSelector( ) } - protected open suspend fun traverseUnaryInst( + private fun traverseUnaryInst( traverserState: TraverserState, inst: UnaryInst, nextInstruction: Instruction? @@ -847,7 +791,7 @@ class ExperimentPathSelector( ) val result: TraverserState = when (inst.opcode) { - UnaryOpcode.LENGTH -> nullCheck(traverserState, inst, nextInstruction, operandTerm).second + UnaryOpcode.LENGTH -> nullCheck(traverserState, inst, nextInstruction, operandTerm).state else -> traverserState } @@ -857,21 +801,21 @@ class ExperimentPathSelector( ) } - protected open suspend fun traverseJumpInst( + private fun traverseJumpInst( traverserState: TraverserState, inst: JumpInst ): TraverserState { return traverserState + inst.parent } - protected open suspend fun traverseReturnInst( + private fun traverseReturnInst( traverserState: TraverserState, inst: ReturnInst ): TraverserState { val stackTrace = traverserState.stackTrace val stackTraceElement = stackTrace.lastOrNull() val receiver = stackTraceElement?.instruction - val result = when { + return when { receiver == null -> { return traverserState } @@ -889,10 +833,9 @@ class ExperimentPathSelector( stackTrace = stackTrace.removeAt(stackTrace.lastIndex) ) } - return result } - protected open suspend fun traverseSwitchInst( + private fun traverseSwitchInst( traverserState: TraverserState, inst: SwitchInst, nextInstruction: Instruction? @@ -918,11 +861,11 @@ class ExperimentPathSelector( return traverserState + defaultPath + inst.parent } - protected open suspend fun traverseTableSwitchInst( + private fun traverseTableSwitchInst( traverserState: TraverserState, inst: TableSwitchInst, nextInstruction: Instruction? - ): TraverserState? { + ): TraverserState { val key = traverserState.mkTerm(inst.index) val min = inst.range.first for ((index, branch) in inst.branches.withIndex()) { @@ -944,40 +887,32 @@ class ExperimentPathSelector( return traverserState + defaultPath + inst.parent } - protected open suspend fun traverseThrowInst( + private fun traverseThrowInst( traverserState: TraverserState, inst: ThrowInst, nextInstruction: Instruction? - ): TraverserState? { + ): TraverserState { val throwableTerm = traverserState.mkTerm(inst.throwable) val throwClause = StateClause( inst, state { `throw`(throwableTerm) } ) - var (isCheckPassed, result) = nullCheck(traverserState, inst, nextInstruction, throwableTerm) - if (!isCheckPassed) { - return result - } - return result + throwClause + val result = nullCheck(traverserState, inst, nextInstruction, throwableTerm) + if (result is UnsuccessfulCheck) return result.state + return result.state + throwClause } - protected open suspend fun traverseUnreachableInst( - traverserState: TraverserState, - inst: UnreachableInst - ): TraverserState? { - return null - } + private fun traverseUnreachableInst(): TraverserState? = null - protected open suspend fun traverseUnknownValueInst( - traverserState: TraverserState, + private fun traverseUnknownValueInst( inst: UnknownValueInst ): TraverserState? { return unreachable("Unexpected visit of $inst in symbolic traverser") } @Suppress("NOTHING_TO_INLINE") - protected inline fun PathClause.inverse(): PathClause = this.copy( + private inline fun PathClause.inverse(): PathClause = this.copy( predicate = this.predicate.inverse(ctx.random) ) } \ No newline at end of file