-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an API for generating quotes at compile-time using macros
Because the carac dsl is implemented in Scala, it is possible to dynamically compose programs based on information only known at runtime, but most of the time this level of dynamicity is not necessary because the program is known at compile-time and only facts need to be loaded at runtime. We can take advantage of this by leveraging the existing quote backend to generate programs at compile-time using the standard Scala macro mechanism, the result is faster than both the interpreter and lambda backend on at least one simple benchmark: BenchMacro.simple_interpreter thrpt 10 28,511 ± 1,442 ops/s BenchMacro.simple_lambda thrpt 10 27,863 ± 0,397 ops/s BenchMacro.simple_macro thrpt 10 31,917 ± 0,334 ops/s It'd be interesting to try to extend this system to generate code that can still be re-optimized at runtime with the JIT.
- Loading branch information
Showing
5 changed files
with
220 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package datalog.benchmarks | ||
|
||
import java.util.concurrent.TimeUnit | ||
import org.openjdk.jmh.annotations.* | ||
import org.openjdk.jmh.infra.Blackhole | ||
|
||
import test.{SimpleProgram, SimpleMacroCompiler as simple} | ||
|
||
import scala.compiletime.uninitialized | ||
|
||
import datalog.execution.ir.InterpreterContext | ||
import datalog.execution.{Backend, Granularity, Mode, StagedExecutionEngine} | ||
import datalog.storage.DefaultStorageManager | ||
|
||
object BenchMacro { | ||
val simpleCompiled = simple.compile() | ||
} | ||
import BenchMacro.* | ||
|
||
@Fork(1) | ||
@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS, batchSize = 100) | ||
@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS, batchSize = 100) | ||
@State(Scope.Thread) | ||
// @BenchmarkMode(Array(Mode.AverageTime)) | ||
class BenchMacro { | ||
/** Add extra facts at runtime. */ | ||
def addExtraFacts(program: SimpleProgram): Unit = | ||
program.edge("b", "c") :- () | ||
program.edge("d", "e") :- () | ||
program.edge("e", "f") :- () | ||
program.edge("f", "g") :- () | ||
program.edge("g", "h") :- () | ||
program.edge("h", "i") :- () | ||
program.edge("i", "j") :- () | ||
program.edge("j", "k") :- () | ||
program.edge("k", "l") :- () | ||
program.edge("l", "m") :- () | ||
|
||
@Benchmark | ||
def simple_macro = { | ||
simple.runCompiled(simpleCompiled)(addExtraFacts) | ||
} | ||
|
||
@Benchmark | ||
def simple_interpreter = { | ||
val engine = StagedExecutionEngine(DefaultStorageManager(), simple.jitOptions.copy( | ||
mode = Mode.Interpreted, granularity = Granularity.NEVER)) | ||
val program = simple.makeProgram(engine) | ||
addExtraFacts(program) | ||
program.toSolve.solve() | ||
} | ||
|
||
@Benchmark | ||
def simple_lambda = { | ||
val engine = StagedExecutionEngine(DefaultStorageManager(), simple.jitOptions.copy(backend = Backend.Lambda)) | ||
val program = simple.makeProgram(engine) | ||
addExtraFacts(program) | ||
program.toSolve.solve() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
package datalog | ||
package execution | ||
|
||
import datalog.dsl.* | ||
import datalog.execution.JITOptions | ||
import datalog.execution.ir.* | ||
import datalog.storage.{DefaultStorageManager, StorageManager, RelationId} | ||
|
||
import scala.quoted.* | ||
import scala.compiletime.uninitialized | ||
|
||
/** A program that specifies a relation `toSolve` to be solved. */ | ||
abstract class SolvableProgram(engine: ExecutionEngine) extends Program(engine) { | ||
/** The relation to be solved when running the program. */ | ||
val toSolve: Relation[?] | ||
} | ||
|
||
/** | ||
* A base class to pre-compile program during the Scala compiler compile-time. | ||
* | ||
* This API is a bit awkward to use since the Scala macro system requires us | ||
* to define the macro in a separate file from where it's used. | ||
* To pre-compile a program, first define it as a subclass of SolvableProgram: | ||
* | ||
* class MyProgram(engine: ExecutionEngine) extends SolvableProgram(engine) { | ||
* val edge = relation[Constant]("edge") | ||
* // ... | ||
* override val toSolve = edge | ||
* } | ||
* | ||
* Then define an object to hold the pre-compiled lambda: | ||
* | ||
* object MyMacroCompiler extends MacroCompiler(MyProgram(_)) { | ||
* inline def compile(): StorageManager => Any = ${compileImpl()} | ||
* } | ||
* | ||
* Then in a separate file, call the macro to cache its result: | ||
* | ||
* val compiled = MyMacroCompiler.compile() | ||
* | ||
* You can then run the program, with extra facts loaded at runtime if desired: | ||
* | ||
* MyMacroCompiler.runCompiled(compiled)(p => p.edge("b", "c") :- ()) | ||
*/ | ||
abstract class MacroCompiler[T <: SolvableProgram](val makeProgram: ExecutionEngine => T) { | ||
/** Generate an engine suitable for use with the output of `compile()`. */ | ||
def makeEngine(): StagedExecutionEngine = { | ||
val storageManager = DefaultStorageManager() | ||
StagedExecutionEngine(storageManager, JITOptions( | ||
mode = Mode.JIT, granularity = Granularity.ALL, | ||
// FIXME: make the dotty parameter optional, maybe by making it a | ||
// parameter of Backend.Quotes and having a separate Backend.Macro. | ||
dotty = null, | ||
compileSync = CompileSync.Blocking, sortOrder = SortOrder.Unordered, | ||
backend = Backend.Quotes)) | ||
} | ||
private val engine: StagedExecutionEngine = makeEngine() | ||
val jitOptions: JITOptions = engine.defaultJITOptions | ||
private val program: T = makeProgram(engine) | ||
|
||
protected def compileImpl()(using Quotes): Expr[StorageManager => Any] = { | ||
val irTree = engine.generateProgramTree(program.toSolve.id)._1 | ||
// TODO: more precise type for engine.compiler to avoid the cast. | ||
val compiler = engine.compiler.asInstanceOf[QuoteCompiler] | ||
val x = '{ (sm: StorageManager) => | ||
${compiler.compileIR(irTree)(using 'sm)} | ||
} | ||
// println(x.show) | ||
x | ||
} | ||
|
||
/** | ||
* Generate the macro-compiled program solver. | ||
* | ||
* Cache the result in a val to avoid running the macro multiple times, | ||
* then pass it to `runCompiled`. | ||
* | ||
* Subclasses should implement this by just calling `${compileImpl()}`. | ||
*/ | ||
inline def compile(): StorageManager => Any | ||
|
||
/** | ||
* Run a macro-compiled program solver with a fresh Program at runtime. | ||
* | ||
* @param compiled The output of a call to `this.compile()`. | ||
* @param op Operations to run on the fresh program, this | ||
* can be used to add extra facts at runtime. | ||
* TODO: Find a nice way to restrict this to only allow | ||
* adding extra facts and nothing else. | ||
*/ | ||
def runCompiled(compiled: StorageManager => Any)(op: T => Any): Any = { | ||
val runtimeEngine = makeEngine() | ||
val runtimeProgram = makeProgram(runtimeEngine) | ||
|
||
// Even though we don't use the generated tree at runtime, | ||
// we still need to generate it to find the de-aliased irCtx.toSolve | ||
// and to populate runtimeEngine.storageManager.allRulesAllIndexes | ||
val (_, irCtx) = runtimeEngine.generateProgramTree(program.toSolve.id) | ||
|
||
op(runtimeProgram) | ||
|
||
compiled(runtimeEngine.storageManager) | ||
|
||
runtimeEngine.storageManager.getNewIDBResult(irCtx.toSolve) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package test | ||
|
||
import datalog.execution.{ExecutionEngine, SolvableProgram, MacroCompiler} | ||
import datalog.dsl.* | ||
import datalog.execution.ir.* | ||
import datalog.storage.{DefaultStorageManager, StorageManager} | ||
|
||
/// Used in MacroCompilerTest | ||
|
||
class SimpleProgram(engine: ExecutionEngine) extends SolvableProgram(engine) { | ||
val edge = relation[Constant]("edge") | ||
val path = relation[Constant]("path") | ||
val pathFromA = relation[Constant]("pathFromA") | ||
|
||
val x, y, z = variable() | ||
|
||
edge("a", "b") :- () | ||
|
||
edge("c", "d") :- () | ||
|
||
path(x, y) :- edge(x, y) | ||
path(x, z) :- (edge(x, y), path(y, z)) | ||
|
||
pathFromA(x) :- path("a", x) | ||
|
||
override val toSolve = pathFromA | ||
} | ||
|
||
object SimpleMacroCompiler extends MacroCompiler(SimpleProgram(_)) { | ||
inline def compile(): StorageManager => Any = ${compileImpl()} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package test | ||
|
||
import datalog.execution.{ExecutionEngine, SolvableProgram, MacroCompiler} | ||
import datalog.dsl.* | ||
import datalog.execution.ir.* | ||
import datalog.storage.{DefaultStorageManager, StorageManager} | ||
|
||
object MacroCompilerTest { | ||
val simpleCompiled = SimpleMacroCompiler.compile() | ||
} | ||
import MacroCompilerTest.* | ||
|
||
class MacroCompilerTest extends munit.FunSuite { | ||
test("can add facts at runtime to macro-compiled program") { | ||
val res = SimpleMacroCompiler.runCompiled(simpleCompiled) { program => | ||
program.edge("b", "c") :- () | ||
} | ||
assertEquals(res, Set(Seq("b"), Seq("c"), Seq("d"))) | ||
} | ||
} |