From 526c32a952a58fc885d0976e7d09aa0834cca47a Mon Sep 17 00:00:00 2001
From: Guillaume Martres <smarter@ubuntu.com>
Date: Tue, 22 Aug 2023 15:58:14 +0200
Subject: [PATCH] MacroCompiler: Support runtime sorting (at granularity DELTA)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Compute the sort order at runtime and use it directly, without an extra
compilation step: instead of re-ordering the code we generate, we re-order the
data that the generated code operates on.

Benchmarks on my unplugged laptop:

Before:
    Benchmark                          Mode   Cnt  Score   Error  Units
    BenchMacro.ackermann_opt_macro     thrpt   10  1,026 ± 0,036  ops/s
    BenchMacro.ackermann_worst_macro   thrpt   10  0,773 ± 0,022  ops/s

After:
    BenchMacro.ackermann_opt_macro     thrpt   10  2,386 ± 0,128  ops/s
    BenchMacro.ackermann_worst_macro   thrpt   10  2,548 ± 0,144  ops/s

This is comparable to results on the lambda backend:
    BenchMacro.ackermann_opt_lambda    thrpt   10  2,556 ± 0,075  ops/s
    BenchMacro.ackermann_worst_lambda  thrpt   10  2,636 ± 0,093  ops/s
---
 .../scala/datalog/benchmarks/BenchMacro.scala | 20 ++++++
 .../scala/datalog/execution/JITOptions.scala  |  9 ++-
 .../datalog/execution/MacroCompiler.scala     |  4 +-
 .../datalog/execution/QuoteCompiler.scala     | 64 +++++++++++++++++++
 .../execution/StagedExecutionEngine.scala     |  1 +
 5 files changed, 94 insertions(+), 4 deletions(-)

diff --git a/bench/src/test/scala/datalog/benchmarks/BenchMacro.scala b/bench/src/test/scala/datalog/benchmarks/BenchMacro.scala
index 9f61dced..91874ef2 100644
--- a/bench/src/test/scala/datalog/benchmarks/BenchMacro.scala
+++ b/bench/src/test/scala/datalog/benchmarks/BenchMacro.scala
@@ -64,6 +64,26 @@ class BenchMacro {
     // println(res)
   }
 
+  /** Solves the AckermannWorstMacroCompiler program on an engine whose backend is overridden to Backend.Lambda. */
+  @Benchmark
+  def ackermann_worst_lambda = {
+    val lambdaOptions = AckermannWorstMacroCompiler.jitOptions.copy(backend = Backend.Lambda)
+    val program = AckermannWorstMacroCompiler.makeProgram(StagedExecutionEngine(DefaultStorageManager(), lambdaOptions))
+    program.loadFromFactDir(Paths.get(AckermannWorstMacroCompiler.factDir).toString)
+    val res = program.namedRelation(program.toSolve).solve()
+    // println(res)
+  }
+
+  /** Solves the AckermannOptimizedMacroCompiler program on an engine whose backend is overridden to Backend.Lambda. */
+  @Benchmark
+  def ackermann_opt_lambda = {
+    val lambdaOptions = AckermannOptimizedMacroCompiler.jitOptions.copy(backend = Backend.Lambda)
+    val program = AckermannOptimizedMacroCompiler.makeProgram(StagedExecutionEngine(DefaultStorageManager(), lambdaOptions))
+    program.loadFromFactDir(Paths.get(AckermannOptimizedMacroCompiler.factDir).toString)
+    val res = program.namedRelation(program.toSolve).solve()
+    // println(res)
+  }
+
   @Benchmark
   def simple_interpreter = {
     val engine = StagedExecutionEngine(DefaultStorageManager(), SimpleMacroCompiler.jitOptions.copy(
diff --git a/src/main/scala/datalog/execution/JITOptions.scala b/src/main/scala/datalog/execution/JITOptions.scala
index 11b9badf..1c6332b1 100644
--- a/src/main/scala/datalog/execution/JITOptions.scala
+++ b/src/main/scala/datalog/execution/JITOptions.scala
@@ -11,7 +11,7 @@ enum CompileSync:
 enum SortOrder:
   case Sel, IntMax, Mixed, Badluck, Unordered, Worst
 enum Backend:
-  case Quotes, Bytecode, Lambda
+  case MacroQuotes, Quotes, Bytecode, Lambda
 enum Granularity(val flag: OpCode):
   case ALL extends Granularity(OpCode.EVAL_RULE_SN)
   case RULE extends Granularity(OpCode.EVAL_RULE_BODY)
@@ -42,7 +42,7 @@ case class JITOptions(
     throw new Exception(s"Do you really want to set JIT options with $mode?")
   if (
     (mode == Mode.Interpreted && backend != Backend.Quotes) ||
-      (mode == Mode.Compiled && sortOrder != SortOrder.Unordered) ||
+      // (mode == Mode.Compiled && sortOrder != SortOrder.Unordered) ||
       (fuzzy != DEFAULT_FUZZY && compileSync == CompileSync.Blocking) ||
       (compileSync != CompileSync.Async && !useGlobalContext))
     throw new Exception(s"Weird options for mode $mode ($backend, $sortOrder, or $compileSync), are you sure?")
@@ -57,6 +57,11 @@ case class JITOptions(
     s"${programStr}_${granStr}_$backendStr"
 
   def getSortFn(storageManager: StorageManager): (Atom, Boolean) => (Boolean, Int) =
+    JITOptions.getSortFn(sortOrder, storageManager)
+}
+
+object JITOptions {
+  def getSortFn(sortOrder: SortOrder, storageManager: StorageManager): (Atom, Boolean) => (Boolean, Int) =
       sortOrder match
         case SortOrder.IntMax =>
           (a: Atom, isDelta: Boolean) => if (storageManager.edbContains(a.rId))
diff --git a/src/main/scala/datalog/execution/MacroCompiler.scala b/src/main/scala/datalog/execution/MacroCompiler.scala
index 88d2cd51..8e4b9ee8 100644
--- a/src/main/scala/datalog/execution/MacroCompiler.scala
+++ b/src/main/scala/datalog/execution/MacroCompiler.scala
@@ -47,12 +47,12 @@ abstract class MacroCompiler[T <: SolvableProgram](val makeProgram: ExecutionEng
   def makeEngine(): StagedExecutionEngine = {
     val storageManager = DefaultStorageManager()
     StagedExecutionEngine(storageManager, JITOptions(
-      mode = Mode.JIT, granularity = Granularity.ALL,
+      mode = Mode.JIT, granularity = Granularity.DELTA,
       // FIXME: make the dotty parameter optional, maybe by making it a
       // parameter of Backend.Quotes and having a separate Backend.Macro.
       dotty = null,
       compileSync = CompileSync.Blocking, sortOrder = SortOrder.Sel,
-      backend = Backend.Quotes))
+      backend = Backend.MacroQuotes))
   }
   private val engine: StagedExecutionEngine = makeEngine()
   val jitOptions: JITOptions = engine.defaultJITOptions
diff --git a/src/main/scala/datalog/execution/QuoteCompiler.scala b/src/main/scala/datalog/execution/QuoteCompiler.scala
index e4f0a649..dc625a5f 100644
--- a/src/main/scala/datalog/execution/QuoteCompiler.scala
+++ b/src/main/scala/datalog/execution/QuoteCompiler.scala
@@ -11,6 +11,58 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.{immutable, mutable}
 import scala.quoted.*
 
+/** [[QuoteCompiler]] variant for the macro backend: the join order is chosen by the generated code when it runs. */
+class MacroQuoteCompiler(storageManager: StorageManager)(using JITOptions) extends QuoteCompiler(storageManager) {
+  override def compileIRRelOp(irTree: IROp[EDB])(using stagedSM: Expr[StorageManager])(using Quotes): Expr[EDB] = {
+    irTree match {
+      case ProjectJoinFilterOp(rId, k, children: _*) =>
+        // Position of the first delta scan among the join inputs (indexWhere yields -1 when there is none).
+        val deltaIdx = Expr(children.indexWhere {
+          case o: ScanOp => o.db == DB.Delta
+          case _ => false
+        })
+        val unorderedChildren = Expr.ofSeq(children.map(compileIRRelOp))
+        val sortOrder = Expr(jitOptions.sortOrder)
+        val stagedId = Expr(rId)
+
+        // Re-sort only for the listed sort orders and only at the configured granularity; otherwise keep k's order.
+        jitOptions.sortOrder match {
+          case SortOrder.Sel | SortOrder.Mixed | SortOrder.IntMax | SortOrder.Worst if jitOptions.granularity.flag == irTree.code => '{
+            val originalK = $stagedSM.allRulesAllIndexes($stagedId).apply(${Expr(k.hash)})
+            val sortBy = JITOptions.getSortFn($sortOrder, $stagedSM)
+            val (newBody, newHash) =
+              if ($sortOrder == SortOrder.Worst)
+                JoinIndexes.presortSelectWorst(sortBy, originalK, $stagedSM, $deltaIdx)
+              else
+                JoinIndexes.presortSelect(sortBy, originalK, $stagedSM, $deltaIdx)
+            val newK = $stagedSM.allRulesAllIndexes($stagedId).getOrElseUpdate(
+              newHash,
+              JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns))
+            )
+            val unordered = $unorderedChildren
+            // Permute the compiled children to follow newK's atoms; drop(1) skips the first atom, which newK keeps in place.
+            val orderedSeq = newK.atoms.view.drop(1).map(a => unordered(originalK.atoms.view.drop(1).indexOf(a))).to(immutable.ArraySeq)
+            $stagedSM.joinProjectHelper_withHash(
+              orderedSeq,
+              $stagedId,
+              newK.hash,
+              ${ Expr(jitOptions.onlineSort) }
+            )
+          }
+          case _ => '{
+            $stagedSM.joinProjectHelper_withHash(
+              $unorderedChildren,
+              $stagedId,
+              ${ Expr(k.hash) },
+              ${ Expr(jitOptions.onlineSort) }
+            )
+          }
+        }
+      case _ => super.compileIRRelOp(irTree) // every other operator is compiled by QuoteCompiler unchanged
+    }
+  }
+}
+
 /**
  * Separate out compile logic from StagedExecutionEngine
  */
@@ -79,6 +131,18 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
     }
   }
 
+  /** Lifts a [[SortOrder]] known at compile time into a quoted reference to the same enum case. */
+  given ToExpr[SortOrder] with {
+    def apply(x: SortOrder)(using Quotes) = x match {
+      case SortOrder.Sel => '{ SortOrder.Sel }
+      case SortOrder.IntMax => '{ SortOrder.IntMax }
+      case SortOrder.Mixed => '{ SortOrder.Mixed }
+      case SortOrder.Badluck => '{ SortOrder.Badluck }
+      case SortOrder.Unordered => '{ SortOrder.Unordered }
+      case SortOrder.Worst => '{ SortOrder.Worst }
+    }
+  }
+
   /**
    * Compiles a relational operator into a quote that returns an EDB. Future TODO: merge with compileIR when dotty supports.
    */
diff --git a/src/main/scala/datalog/execution/StagedExecutionEngine.scala b/src/main/scala/datalog/execution/StagedExecutionEngine.scala
index b75ba02b..93eae2f2 100644
--- a/src/main/scala/datalog/execution/StagedExecutionEngine.scala
+++ b/src/main/scala/datalog/execution/StagedExecutionEngine.scala
@@ -27,6 +27,7 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp
 //  var threadpool: ExecutorService = null
 
   val compiler: StagedCompiler = defaultJITOptions.backend match
+    case Backend.MacroQuotes => MacroQuoteCompiler(storageManager)
     case Backend.Quotes => QuoteCompiler(storageManager)
     case Backend.Bytecode => BytecodeCompiler(storageManager)
     case Backend.Lambda => LambdaCompiler(storageManager)