Skip to content

Commit

Permalink
Add an API for generating quotes at compile-time using macros
Browse files Browse the repository at this point in the history
Because the carac DSL is implemented in Scala, it is possible to dynamically
compose programs based on information only known at runtime, but most of the
time this level of dynamism is not necessary because the program is known at
compile-time and only facts need to be loaded at runtime.

We can take advantage of this by leveraging the existing quote backend to
generate programs at compile-time using the standard Scala macro mechanism.
The result is faster than both the interpreter and lambda backends on at least
one simple benchmark:

    BenchMacro.simple_interpreter  thrpt   10  28,511 ± 1,442  ops/s
    BenchMacro.simple_lambda       thrpt   10  27,863 ± 0,397  ops/s
    BenchMacro.simple_macro        thrpt   10  31,917 ± 0,334  ops/s

It'd be interesting to try to extend this system to generate code that can
still be re-optimized at runtime with the JIT.
  • Loading branch information
smarter authored and aherlihy committed Aug 13, 2023
1 parent 50f7077 commit 6e09812
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 1 deletion.
60 changes: 60 additions & 0 deletions bench/src/test/scala/datalog/benchmarks/BenchMacro.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package datalog.benchmarks

import java.util.concurrent.TimeUnit
import org.openjdk.jmh.annotations.*
import org.openjdk.jmh.infra.Blackhole

import test.{SimpleProgram, SimpleMacroCompiler as simple}

import scala.compiletime.uninitialized

import datalog.execution.ir.InterpreterContext
import datalog.execution.{Backend, Granularity, Mode, StagedExecutionEngine}
import datalog.storage.DefaultStorageManager

/** Companion holding the macro-generated solver shared by the benchmarks below. */
object BenchMacro {
  /**
   * Output of `simple.compile()` (`simple` is the import alias for
   * `test.SimpleMacroCompiler`), stored in a val so the call site appears once.
   */
  val simpleCompiled = simple.compile()
}
import BenchMacro.*

@Fork(1)
@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS, batchSize = 100)
@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS, batchSize = 100)
@State(Scope.Thread)
// @BenchmarkMode(Array(Mode.AverageTime))
// NOTE(review): `Mode` is explicitly imported from datalog.execution in this file,
// so re-enabling the line above would presumably need the JMH `Mode` qualified.
/**
 * JMH benchmarks that solve the same `SimpleProgram` three ways: through the
 * macro-compiled solver, through the interpreter, and through the lambda backend.
 * Every variant loads the same extra facts before solving.
 */
class BenchMacro {
  /**
   * Adds ten extra `edge` facts to `program`; used by every benchmark so that
   * some facts are only supplied at runtime.
   */
  def addExtraFacts(program: SimpleProgram): Unit = {
    program.edge("b", "c") :- ()
    program.edge("d", "e") :- ()
    program.edge("e", "f") :- ()
    program.edge("f", "g") :- ()
    program.edge("g", "h") :- ()
    program.edge("h", "i") :- ()
    program.edge("i", "j") :- ()
    program.edge("j", "k") :- ()
    program.edge("k", "l") :- ()
    program.edge("l", "m") :- ()
  }

  /** Runs the solver generated at compile time by the macro. */
  @Benchmark
  def simple_macro =
    simple.runCompiled(simpleCompiled)(addExtraFacts)

  /** Solves with the engine in interpreted mode and granularity NEVER. */
  @Benchmark
  def simple_interpreter = {
    val storage = DefaultStorageManager()
    val interpreterOptions =
      simple.jitOptions.copy(mode = Mode.Interpreted, granularity = Granularity.NEVER)
    val interpreterEngine = StagedExecutionEngine(storage, interpreterOptions)
    val interpretedProgram = simple.makeProgram(interpreterEngine)
    addExtraFacts(interpretedProgram)
    interpretedProgram.toSolve.solve()
  }

  /** Solves with the macro compiler's default options, switched to the lambda backend. */
  @Benchmark
  def simple_lambda = {
    val storage = DefaultStorageManager()
    val lambdaOptions = simple.jitOptions.copy(backend = Backend.Lambda)
    val lambdaEngine = StagedExecutionEngine(storage, lambdaOptions)
    val lambdaProgram = simple.makeProgram(lambdaEngine)
    addExtraFacts(lambdaProgram)
    lambdaProgram.toSolve.solve()
  }
}
106 changes: 106 additions & 0 deletions src/main/scala/datalog/execution/MacroCompiler.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package datalog
package execution

import datalog.dsl.*
import datalog.execution.JITOptions
import datalog.execution.ir.*
import datalog.storage.{DefaultStorageManager, StorageManager, RelationId}

import scala.quoted.*
import scala.compiletime.uninitialized

/**
 * A [[Program]] that additionally names the relation it is meant to compute.
 *
 * @param engine the execution engine this program is attached to; it is
 *               forwarded unchanged to the `Program` constructor.
 */
abstract class SolvableProgram(engine: ExecutionEngine) extends Program(engine) {

  /** The relation to be solved when running the program. */
  val toSolve: Relation[?]

}

/**
 * A base class to pre-compile a program during the Scala compiler's compile-time.
 *
 * This API is a bit awkward to use since the Scala macro system requires us
 * to define the macro in a separate file from where it's used.
 * To pre-compile a program, first define it as a subclass of SolvableProgram:
 *
 * {{{
 * class MyProgram(engine: ExecutionEngine) extends SolvableProgram(engine) {
 *   val edge = relation[Constant]("edge")
 *   // ...
 *   override val toSolve = edge
 * }
 * }}}
 *
 * Then define an object to hold the pre-compiled lambda:
 *
 * {{{
 * object MyMacroCompiler extends MacroCompiler(MyProgram(_)) {
 *   inline def compile(): StorageManager => Any = ${compileImpl()}
 * }
 * }}}
 *
 * Then in a separate file, call the macro to cache its result:
 *
 * {{{
 * val compiled = MyMacroCompiler.compile()
 * }}}
 *
 * You can then run the program, with extra facts loaded at runtime if desired:
 *
 * {{{
 * MyMacroCompiler.runCompiled(compiled)(p => p.edge("b", "c") :- ())
 * }}}
 *
 * @tparam T          the concrete program type produced by `makeProgram`.
 * @param makeProgram factory that instantiates the program on a given engine;
 *                    it is invoked once when this class is constructed and
 *                    once more on every call to `runCompiled`.
 */
abstract class MacroCompiler[T <: SolvableProgram](val makeProgram: ExecutionEngine => T) {
  /** Generate an engine suitable for use with the output of `compile()`. */
  def makeEngine(): StagedExecutionEngine = {
    val storageManager = DefaultStorageManager()
    StagedExecutionEngine(storageManager, JITOptions(
      mode = Mode.JIT, granularity = Granularity.ALL,
      // FIXME: make the dotty parameter optional, maybe by making it a
      // parameter of Backend.Quotes and having a separate Backend.Macro.
      dotty = null,
      compileSync = CompileSync.Blocking, sortOrder = SortOrder.Unordered,
      backend = Backend.Quotes))
  }

  /** Engine on which `program` is instantiated; used by `compileImpl`. */
  private val engine: StagedExecutionEngine = makeEngine()

  /** The JIT options of the engine created by `makeEngine()`. */
  val jitOptions: JITOptions = engine.defaultJITOptions

  /** Program instance whose IR tree is turned into code by `compileImpl`. */
  private val program: T = makeProgram(engine)

  /**
   * Macro implementation: generates the IR tree for `program.toSolve` and
   * compiles it into a quoted function taking the storage manager to run on.
   *
   * @return the quoted `StorageManager => Any` to splice into `compile()`.
   */
  protected def compileImpl()(using Quotes): Expr[StorageManager => Any] = {
    val irTree = engine.generateProgramTree(program.toSolve.id)._1
    // TODO: more precise type for engine.compiler to avoid the type test.
    val compiler = engine.compiler match {
      case quoteCompiler: QuoteCompiler => quoteCompiler
      case other =>
        // Report a regular compilation error rather than letting a
        // ClassCastException escape from the macro expansion.
        quotes.reflect.report.errorAndAbort(
          s"MacroCompiler requires the engine to use a QuoteCompiler, but found ${other.getClass.getName}")
    }
    '{ (sm: StorageManager) =>
      ${compiler.compileIR(irTree)(using 'sm)}
    }
  }

  /**
   * Generate the macro-compiled program solver.
   *
   * Cache the result in a val to avoid running the macro multiple times,
   * then pass it to `runCompiled`.
   *
   * Subclasses should implement this by just calling `${compileImpl()}`.
   */
  inline def compile(): StorageManager => Any

  /**
   * Run a macro-compiled program solver with a fresh Program at runtime.
   *
   * @param compiled The output of a call to `this.compile()`.
   * @param op       Operations to run on the fresh program, this
   *                 can be used to add extra facts at runtime.
   *                 TODO: Find a nice way to restrict this to only allow
   *                 adding extra facts and nothing else.
   * @return         The result of `getNewIDBResult` for the solved relation,
   *                 read from the fresh engine's storage manager.
   */
  def runCompiled(compiled: StorageManager => Any)(op: T => Any): Any = {
    val runtimeEngine = makeEngine()
    val runtimeProgram = makeProgram(runtimeEngine)

    // Even though we don't use the generated tree at runtime,
    // we still need to generate it to find the de-aliased irCtx.toSolve
    // and to populate runtimeEngine.storageManager.allRulesAllIndexes.
    // The relation id is taken from the program that lives on runtimeEngine
    // rather than from the compile-time `program` instance, so we do not
    // depend on two separate instances having been assigned identical ids.
    val (_, irCtx) = runtimeEngine.generateProgramTree(runtimeProgram.toSolve.id)

    op(runtimeProgram)

    compiled(runtimeEngine.storageManager)

    runtimeEngine.storageManager.getNewIDBResult(irCtx.toSolve)
  }
}
4 changes: 3 additions & 1 deletion src/main/scala/datalog/execution/QuoteCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import scala.quoted.*
*/
class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extends StagedCompiler(storageManager) {
given staging.Compiler = jitOptions.dotty
clearDottyThread()
// FIXME: Make dotty an optional parameter, this is null in MacroCompiler.
if (jitOptions.dotty != null)
clearDottyThread()

given MutableMapToExpr[T: Type : ToExpr, U: Type : ToExpr]: ToExpr[mutable.Map[T, U]] with {
def apply(map: mutable.Map[T, U])(using Quotes): Expr[mutable.Map[T, U]] =
Expand Down
31 changes: 31 additions & 0 deletions src/test/scala/test/MacroCompilerPrograms.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package test

import datalog.execution.{ExecutionEngine, SolvableProgram, MacroCompiler}
import datalog.dsl.*
import datalog.execution.ir.*
import datalog.storage.{DefaultStorageManager, StorageManager}

/// Used in MacroCompilerTest and in the BenchMacro benchmark

/**
 * Small reachability program: `path` is derived from `edge`, and `pathFromA`
 * collects every node that has a path starting at "a".
 * Two `edge` facts are built in; callers may add more before solving.
 */
class SimpleProgram(engine: ExecutionEngine) extends SolvableProgram(engine) {
  // Relations (declared in the same order as before).
  val edge = relation[Constant]("edge")
  val path = relation[Constant]("path")
  val pathFromA = relation[Constant]("pathFromA")

  // Variables used by the rules below.
  val x = variable()
  val y = variable()
  val z = variable()

  // Built-in facts.
  edge("a", "b") :- ()
  edge("c", "d") :- ()

  // A path is a single edge, or an edge followed by a path.
  path(x, y) :- edge(x, y)
  path(x, z) :- (edge(x, y), path(y, z))

  // Nodes with a path starting at "a".
  pathFromA(x) :- path("a", x)

  /** The relation solved when this program is run. */
  override val toSolve = pathFromA
}

/** Macro compiler for [[SimpleProgram]]; call `compile()` from a different file. */
object SimpleMacroCompiler extends MacroCompiler(SimpleProgram(_)) {

  /** Splices in the solver produced by `compileImpl()`. */
  inline def compile(): StorageManager => Any = ${compileImpl()}

}
20 changes: 20 additions & 0 deletions src/test/scala/test/MacroCompilerTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package test

import datalog.execution.{ExecutionEngine, SolvableProgram, MacroCompiler}
import datalog.dsl.*
import datalog.execution.ir.*
import datalog.storage.{DefaultStorageManager, StorageManager}

/** Companion holding the macro output used by the test suite below. */
object MacroCompilerTest {
  /** Result of `SimpleMacroCompiler.compile()`, stored in a val so it can be reused. */
  val simpleCompiled = SimpleMacroCompiler.compile()
}
import MacroCompilerTest.*

/** Checks that a macro-compiled program still accepts facts supplied at runtime. */
class MacroCompilerTest extends munit.FunSuite {
  test("can add facts at runtime to macro-compiled program") {
    // edge("b", "c") is only added here, at runtime; it links the built-in
    // edges a -> b and c -> d, so "c" and "d" become reachable from "a".
    val derived =
      SimpleMacroCompiler.runCompiled(simpleCompiled)(p => p.edge("b", "c") :- ())

    assertEquals(derived, Set(Seq("b"), Seq("c"), Seq("d")))
  }
}

0 comments on commit 6e09812

Please sign in to comment.