Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an API for generating quotes at compile-time using macros #35

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions bench/src/test/scala/datalog/benchmarks/BenchMacro.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package datalog.benchmarks

import java.util.concurrent.TimeUnit
import org.openjdk.jmh.annotations.*
import org.openjdk.jmh.infra.Blackhole
import test.{
AckermannWorstMacroCompiler, AckermannOptimizedMacroCompiler,
SimpleMacroCompiler, SimpleProgram
}

import scala.compiletime.uninitialized
import datalog.execution.ir.InterpreterContext
import datalog.execution.{Backend, Granularity, Mode, StagedExecutionEngine}
import datalog.storage.DefaultStorageManager

import java.nio.file.Paths

/**
 * Caches the macro-compiled solvers in vals so that each `compile()` macro
 * expands exactly once (at class-file initialization), rather than being
 * re-run on every benchmark invocation. Pass these to `runCompiled`.
 */
object BenchMacro {
  val simpleCompiled = SimpleMacroCompiler.compile()
  val ackermannOptCompiled = AckermannOptimizedMacroCompiler.compile()
  val ackermannWorstCompiled = AckermannWorstMacroCompiler.compile()
}
import BenchMacro.*

@Fork(1)
@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS, batchSize = 100)
@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS, batchSize = 100)
@State(Scope.Thread)
// @BenchmarkMode(Array(Mode.AverageTime))
class BenchMacro {
  /** Add extra facts at runtime (shared fixture for the `simple_*` benchmarks). */
  def addExtraFacts(program: SimpleProgram): Unit =
    program.edge("b", "c") :- ()
    program.edge("d", "e") :- ()
    program.edge("e", "f") :- ()
    program.edge("f", "g") :- ()
    program.edge("g", "h") :- ()
    program.edge("h", "i") :- ()
    program.edge("i", "j") :- ()
    program.edge("j", "k") :- ()
    program.edge("k", "l") :- ()
    program.edge("l", "m") :- ()

  /** Macro-precompiled "simple" program with extra runtime facts. */
  @Benchmark
  def simple_macro = {
    SimpleMacroCompiler.runCompiled(simpleCompiled)(addExtraFacts)
  }

  /** Macro-precompiled Ackermann (optimized rule order), facts loaded at runtime. */
  @Benchmark
  def ackermann_opt_macro = {
    val facts = Paths.get(AckermannOptimizedMacroCompiler.factDir)
    // Return the result so JMH consumes it and the JVM cannot dead-code-eliminate the solve.
    AckermannOptimizedMacroCompiler.runCompiled(ackermannOptCompiled)(
      program => program.loadFromFactDir(facts.toString)
    )
  }

  /** Macro-precompiled Ackermann (worst rule order), facts loaded at runtime. */
  @Benchmark
  def ackermann_worst_macro = {
    val facts = Paths.get(AckermannWorstMacroCompiler.factDir)
    // Return the result so JMH consumes it and the JVM cannot dead-code-eliminate the solve.
    AckermannWorstMacroCompiler.runCompiled(ackermannWorstCompiled)(
      program => program.loadFromFactDir(facts.toString)
    )
  }

  /** Baseline: same worst-ordered program, compiled at runtime with the Lambda backend. */
  @Benchmark
  def ackermann_worst_lambda = {
    val engine = StagedExecutionEngine(DefaultStorageManager(), AckermannWorstMacroCompiler.jitOptions.copy(backend = Backend.Lambda))
    val facts = Paths.get(AckermannWorstMacroCompiler.factDir)
    val program = AckermannWorstMacroCompiler.makeProgram(engine)
    program.loadFromFactDir(facts.toString)
    // Return the solve result so JMH consumes it (was previously discarded in an unused val).
    program.namedRelation(program.toSolve).solve()
  }

  /** Baseline: optimized-order program, compiled at runtime with the Lambda backend. */
  @Benchmark
  def ackermann_opt_lambda = {
    val engine = StagedExecutionEngine(DefaultStorageManager(), AckermannOptimizedMacroCompiler.jitOptions.copy(backend = Backend.Lambda))
    val facts = Paths.get(AckermannOptimizedMacroCompiler.factDir)
    val program = AckermannOptimizedMacroCompiler.makeProgram(engine)
    program.loadFromFactDir(facts.toString)
    // Return the solve result so JMH consumes it (was previously discarded in an unused val).
    program.namedRelation(program.toSolve).solve()
  }

  /** Baseline: "simple" program run fully interpreted (no JIT). */
  @Benchmark
  def simple_interpreter = {
    val engine = StagedExecutionEngine(DefaultStorageManager(), SimpleMacroCompiler.jitOptions.copy(
      mode = Mode.Interpreted, granularity = Granularity.NEVER))
    val program = SimpleMacroCompiler.makeProgram(engine)
    addExtraFacts(program)
    program.namedRelation(program.toSolve).solve()
  }

  /** Baseline: "simple" program compiled at runtime with the Lambda backend. */
  @Benchmark
  def simple_lambda = {
    val engine = StagedExecutionEngine(DefaultStorageManager(), SimpleMacroCompiler.jitOptions.copy(backend = Backend.Lambda))
    val program = SimpleMacroCompiler.makeProgram(engine)
    addExtraFacts(program)
    program.namedRelation(program.toSolve).solve()
  }
}
3 changes: 1 addition & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import java.nio.file.{Files, Paths}

inThisBuild(Seq(
organization := "ch.epfl.lamp",
scalaVersion := "3.3.1-RC4",
// scalaVersion := "3.3.1-RC1-bin-SNAPSHOT",
scalaVersion := "3.4.0-RC1-bin-20230818-932c10d-NIGHTLY",
version := "0.1",
))

Expand Down
1 change: 1 addition & 0 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ addSbtPlugin("org.xerial.sbt" % "sbt-pack" % "0.17")
addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.15.0")
addSbtPlugin("org.scalameta" % "sbt-native-image" % "0.3.4")
// NOTE(review): removed a duplicate addSbtPlugin("org.xerial.sbt" % "sbt-pack" % "0.17")
// added here — the same plugin is already declared at the top of this file.
17 changes: 16 additions & 1 deletion src/main/scala/datalog/dsl/Program.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ class Program(engine: ExecutionEngine) extends AbstractProgram {
// TODO: also provide solve for multiple/all predicates, or return table so users can query over the derived DB
def solve(rId: Int): Set[Seq[Term]] = ee.solve(rId).map(s => s.toSeq).toSet

/**
 * Pre-registers an empty EDB relation for every regular file found directly
 * inside `directory`, using each file's base name (extension stripped) as the
 * relation name. No fact contents are read here; use `loadFromFactDir` to
 * actually load facts.
 *
 * @param directory path of the fact directory to scan (depth 1, non-recursive)
 * @throws Exception if `directory` does not exist
 */
def initializeEmptyFactsFromDir(directory: String): Unit = {
  import scala.util.Using
  val factdir = Path.of(directory)
  if (Files.exists(factdir)) {
    // Files.walk returns a Stream backed by open directory handles;
    // close it deterministically instead of leaking it.
    Using.resource(Files.walk(factdir, 1)) { paths =>
      paths
        .filter(p => Files.isRegularFile(p))
        .forEach(f => {
          val edbName = f.getFileName.toString.replaceFirst("[.][^.]+$", "")
          // Registering the relation is the desired side effect; the handle itself is unused.
          relation[Constant](edbName)
        })
    }
  } else throw new Exception(s"Fact directory $factdir does not exist")
}

def loadFromFactDir(directory: String): Unit = {
val factdir = Path.of(directory)
if (Files.exists(factdir)) {
Expand All @@ -45,7 +58,9 @@ class Program(engine: ExecutionEngine) extends AbstractProgram {
val firstLine = reader.readLine()
if (firstLine != null) { // empty file, empty EDB
val headers = firstLine.split("\t")
val fact = relation[Constant](edbName)
val fact =
if (ee.storageManager.ns.contains(edbName)) namedRelation[Constant](edbName)
else relation[Constant](edbName)
reader.lines()
.forEach(l => {
val factInput = l
Expand Down
9 changes: 7 additions & 2 deletions src/main/scala/datalog/execution/JITOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ enum CompileSync:
// Join-ordering strategies used when (re)sorting rule bodies; the concrete
// heuristics live in JITOptions.getSortFn. Worst deliberately picks a bad
// order (used for benchmarking) — see presortSelectWorst.
enum SortOrder:
  case Sel, IntMax, Mixed, Badluck, Unordered, Worst
// Code-generation backends. MacroQuotes is quote-based compilation performed
// at Scala compile time via MacroCompiler/MacroQuoteCompiler; Quotes stages
// at runtime, Bytecode emits JVM bytecode, Lambda composes closures.
enum Backend:
  case MacroQuotes, Quotes, Bytecode, Lambda
enum Granularity(val flag: OpCode):
case ALL extends Granularity(OpCode.EVAL_RULE_SN)
case RULE extends Granularity(OpCode.EVAL_RULE_BODY)
Expand Down Expand Up @@ -42,7 +42,7 @@ case class JITOptions(
throw new Exception(s"Do you really want to set JIT options with $mode?")
if (
(mode == Mode.Interpreted && backend != Backend.Quotes) ||
(mode == Mode.Compiled && sortOrder != SortOrder.Unordered) ||
// (mode == Mode.Compiled && sortOrder != SortOrder.Unordered) ||
(fuzzy != DEFAULT_FUZZY && compileSync == CompileSync.Blocking) ||
(compileSync != CompileSync.Async && !useGlobalContext))
throw new Exception(s"Weird options for mode $mode ($backend, $sortOrder, or $compileSync), are you sure?")
Expand All @@ -57,6 +57,11 @@ case class JITOptions(
s"${programStr}_${granStr}_$backendStr"

def getSortFn(storageManager: StorageManager): (Atom, Boolean) => (Boolean, Int) =
JITOptions.getSortFn(sortOrder, storageManager)
}

object JITOptions {
def getSortFn(sortOrder: SortOrder, storageManager: StorageManager): (Atom, Boolean) => (Boolean, Int) =
sortOrder match
case SortOrder.IntMax =>
(a: Atom, isDelta: Boolean) => if (storageManager.edbContains(a.rId))
Expand Down
106 changes: 106 additions & 0 deletions src/main/scala/datalog/execution/MacroCompiler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package datalog
package execution

import datalog.dsl.*
import datalog.execution.JITOptions
import datalog.execution.ir.*
import datalog.storage.{DefaultStorageManager, StorageManager, RelationId}

import scala.quoted.*
import scala.compiletime.uninitialized

/**
 * A program that specifies a relation `toSolve` to be solved.
 *
 * Subclasses define their relations in the constructor body and set `toSolve`
 * to the name of the relation whose result `MacroCompiler.runCompiled` returns.
 */
abstract class SolvableProgram(engine: ExecutionEngine) extends Program(engine) {
  /** Name of the relation to be solved when running the program (a `String`, not the relation itself). */
  val toSolve: String
}

/**
 * A base class to pre-compile program during the Scala compiler compile-time.
 *
 * This API is a bit awkward to use since the Scala macro system requires us
 * to define the macro in a separate file from where it's used.
 * To pre-compile a program, first define it as a subclass of SolvableProgram:
 *
 *   class MyProgram(engine: ExecutionEngine) extends SolvableProgram(engine) {
 *     val edge = relation[Constant]("edge")
 *     // ...
 *     override val toSolve = "edge" // NOTE: toSolve is the relation's name (a String)
 *   }
 *
 * Then define an object to hold the pre-compiled lambda:
 *
 *   object MyMacroCompiler extends MacroCompiler(MyProgram(_)) {
 *     inline def compile(): StorageManager => Any = ${compileImpl()}
 *   }
 *
 * Then in a separate file, call the macro to cache its result:
 *
 *   val compiled = MyMacroCompiler.compile()
 *
 * You can then run the program, with extra facts loaded at runtime if desired:
 *
 *   MyMacroCompiler.runCompiled(compiled)(p => p.edge("b", "c") :- ())
 */
abstract class MacroCompiler[T <: SolvableProgram](val makeProgram: ExecutionEngine => T) {
  /** Generate an engine suitable for use with the output of `compile()`. */
  def makeEngine(): StagedExecutionEngine = {
    val storageManager = DefaultStorageManager()
    StagedExecutionEngine(storageManager, JITOptions(
      mode = Mode.JIT, granularity = Granularity.DELTA,
      // FIXME: make the dotty parameter optional, maybe by making it a
      // parameter of Backend.Quotes and having a separate Backend.Macro.
      dotty = null,
      compileSync = CompileSync.Blocking, sortOrder = SortOrder.Sel,
      backend = Backend.MacroQuotes))
  }
  // Macro-time engine/program pair: used only by compileImpl to build the IR tree.
  private val engine: StagedExecutionEngine = makeEngine()
  val jitOptions: JITOptions = engine.defaultJITOptions
  private val program: T = makeProgram(engine)

  /**
   * Macro implementation: builds the IR tree for `program.toSolve` and stages
   * it into a lambda over the StorageManager that will exist at runtime.
   */
  protected def compileImpl()(using Quotes): Expr[StorageManager => Any] = {
    val irTree = engine.generateProgramTree(program.namedRelation(program.toSolve).id)._1
    // TODO: more precise type for engine.compiler to avoid the cast.
    val compiler = engine.compiler.asInstanceOf[QuoteCompiler]
    val x = '{ (sm: StorageManager) =>
      ${compiler.compileIR(irTree)(using 'sm)}
    }
    // println(x.show)
    x
  }

  /**
   * Generate the macro-compiled program solver.
   *
   * Cache the result in a val to avoid running the macro multiple times,
   * then pass it to `runCompiled`.
   *
   * Subclasses should implement this by just calling `${compileImpl()}`.
   */
  inline def compile(): StorageManager => Any

  /**
   * Run a macro-compiled program solver with a fresh Program at runtime.
   *
   * @param compiled The output of a call to `this.compile()`.
   * @param op Operations to run on the fresh program, this
   *           can be used to add extra facts at runtime.
   *           TODO: Find a nice way to restrict this to only allow
   *           adding extra facts and nothing else.
   */
  def runCompiled(compiled: StorageManager => Any)(op: T => Any): Any = {
    val runtimeEngine = makeEngine()
    val runtimeProgram = makeProgram(runtimeEngine)

    // Even though we don't use the generated tree at runtime,
    // we still need to generate it to find the de-aliased irCtx.toSolve
    // and to populate runtimeEngine.storageManager.allRulesAllIndexes
    // NOTE(review): this reads the relation id from `program` (the macro-time
    // instance) rather than `runtimeProgram` — presumably the ids line up
    // because both are built by the same `makeProgram` on a fresh engine;
    // confirm this invariant holds.
    val (_, irCtx) = runtimeEngine.generateProgramTree(program.namedRelation(program.toSolve).id)

    op(runtimeProgram)

    // Runs the staged solver for its side effects on the runtime storage manager.
    compiled(runtimeEngine.storageManager)

    runtimeEngine.storageManager.getNewIDBResult(irCtx.toSolve)
  }
}
68 changes: 67 additions & 1 deletion src/main/scala/datalog/execution/QuoteCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,66 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.{immutable, mutable}
import scala.quoted.*

/**
 * QuoteCompiler variant used by MacroCompiler when pre-compiling a program
 * inside a Scala macro (Backend.MacroQuotes). For join operators it stages
 * code that re-sorts the join at runtime using the runtime storage manager's
 * JoinIndexes cache, instead of fixing the order at staging time.
 */
class MacroQuoteCompiler(storageManager: StorageManager)(using JITOptions) extends QuoteCompiler(storageManager) {
  override def compileIRRelOp(irTree: IROp[EDB])(using stagedSM: Expr[StorageManager])(using Quotes): Expr[EDB] = {
    irTree match {
      case ProjectJoinFilterOp(rId, k, children: _*) =>
        // Index of the delta-scan child, needed by presortSelect* at runtime.
        val deltaIdx = Expr(children.indexWhere(op =>
          op match
            case o: ScanOp => o.db == DB.Delta
            case _ => false)
        )
        val unorderedChildren = Expr.ofSeq(children.map(compileIRRelOp))
        val sortOrder = Expr(jitOptions.sortOrder)
        val stagedId = Expr(rId)
        // NOTE(review): removed unused locals `childrenLength` and `ruleHash`
        // from the original; neither was referenced in the staged code.

        jitOptions.sortOrder match {
          // Only re-sort at the granularity the JIT options select for this op.
          case SortOrder.Sel | SortOrder.Mixed | SortOrder.IntMax | SortOrder.Worst if jitOptions.granularity.flag == irTree.code => '{
            val originalK = $stagedSM.allRulesAllIndexes($stagedId).apply(${Expr(k.hash)})
            val sortBy = JITOptions.getSortFn($sortOrder, $stagedSM)
            val (newBody, newHash) =
              if ($sortOrder == SortOrder.Worst)
                JoinIndexes.presortSelectWorst(sortBy, originalK, $stagedSM, $deltaIdx)
              else
                JoinIndexes.presortSelect(sortBy, originalK, $stagedSM, $deltaIdx)
            val newK = $stagedSM.allRulesAllIndexes($stagedId).getOrElseUpdate(
              newHash,
              JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns))
            )
            val unordered = $unorderedChildren
            // Reorder the compiled child EDBs to match the newly sorted atom order.
            val orderedSeq = newK.atoms.view.drop(1).map(a => unordered(originalK.atoms.view.drop(1).indexOf(a))).to(immutable.ArraySeq)
            $stagedSM.joinProjectHelper_withHash(
              orderedSeq,
              $stagedId,
              newK.hash,
              ${ Expr(jitOptions.onlineSort) }
            )
          }
          // No runtime re-sorting: join the children in their original order.
          case _ => '{
            $stagedSM.joinProjectHelper_withHash(
              $unorderedChildren,
              $stagedId,
              ${ Expr(k.hash) },
              ${ Expr(jitOptions.onlineSort) }
            )
          }
        }
      case _ =>
        super.compileIRRelOp(irTree)
    }
  }
}

/**
* Separate out compile logic from StagedExecutionEngine
*/
class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extends StagedCompiler(storageManager) {
given staging.Compiler = jitOptions.dotty
clearDottyThread()
// FIXME: Make dotty an optional parameter, this is null in MacroCompiler.
if (jitOptions.dotty != null)
clearDottyThread()

given MutableMapToExpr[T: Type : ToExpr, U: Type : ToExpr]: ToExpr[mutable.Map[T, U]] with {
def apply(map: mutable.Map[T, U])(using Quotes): Expr[mutable.Map[T, U]] =
Expand Down Expand Up @@ -77,6 +131,18 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
}
}

/** Lifts a staging-time SortOrder value into staged code that reconstructs the same case at runtime. */
given ToExpr[SortOrder] with {
  def apply(order: SortOrder)(using Quotes): Expr[SortOrder] = order match {
    case SortOrder.Sel       => '{ SortOrder.Sel }
    case SortOrder.IntMax    => '{ SortOrder.IntMax }
    case SortOrder.Mixed     => '{ SortOrder.Mixed }
    case SortOrder.Badluck   => '{ SortOrder.Badluck }
    case SortOrder.Unordered => '{ SortOrder.Unordered }
    case SortOrder.Worst     => '{ SortOrder.Worst }
  }
}

/**
* Compiles a relational operator into a quote that returns an EDB. Future TODO: merge with compileIR when dotty supports.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp
// var threadpool: ExecutorService = null

val compiler: StagedCompiler = defaultJITOptions.backend match
case Backend.MacroQuotes => MacroQuoteCompiler(storageManager)
case Backend.Quotes => QuoteCompiler(storageManager)
case Backend.Bytecode => BytecodeCompiler(storageManager)
case Backend.Lambda => LambdaCompiler(storageManager)
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/datalog/storage/DefaultStorageManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager
}

override def joinProjectHelper_withHash(inputsEDB: Seq[EDB], rId: Int, hash: String, onlineSort: Boolean): CollectionsEDB = {
if (!allRulesAllIndexes.contains(rId)) throw new Exception(s"Missing relation ${ns(rId)} from JoinIndexes cache. Existing keys ${allRulesAllIndexes.keys.map(k => ns(k)).mkString("[", ", ", "]")}")
if (!allRulesAllIndexes(rId).contains(hash)) throw new Exception(s"Missing hash for ${ns(rId)}: $hash from JoinIndexes cache. # hashes: ${allRulesAllIndexes(rId).size}")
val originalK = allRulesAllIndexes(rId)(hash)
val inputs = asCollectionsSeqEDB(inputsEDB)
// var intermediateCardinalities = Seq[Int]()
Expand Down
Loading