From 80076269693bfb7bb4639c9898220ce47b966d87 Mon Sep 17 00:00:00 2001
From: Adrien Piquerez <adrien.piquerez@gmail.com>
Date: Wed, 19 Jun 2024 16:40:02 +0200
Subject: [PATCH] Test products of compilation to jar

---
 .../xsbt/ExtractUsedNamesSpecification.scala  |   5 +-
 .../test/xsbt/ProductsSpecification.scala     |  34 ++++++
 .../xsbt/ScalaCompilerForUnitTesting.scala    | 102 +++++++++---------
 3 files changed, 89 insertions(+), 52 deletions(-)
 create mode 100644 sbt-bridge/test/xsbt/ProductsSpecification.scala

diff --git a/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala b/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala
index e47371175de6..0abefe2985c3 100644
--- a/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala
+++ b/sbt-bridge/test/xsbt/ExtractUsedNamesSpecification.scala
@@ -1,7 +1,6 @@
 package xsbt
 
 import xsbti.UseScope
-import ScalaCompilerForUnitTesting.Callbacks
 
 import org.junit.{ Test, Ignore }
 import org.junit.Assert._
@@ -227,9 +226,9 @@ class ExtractUsedNamesSpecification {
 
     def findPatMatUsages(in: String): Set[String] = {
       val compilerForTesting = new ScalaCompilerForUnitTesting
-      val (_, Callbacks(callback, _)) =
+      val output =
         compilerForTesting.compileSrcs(List(List(sealedClass, in)))
-      val clientNames = callback.usedNamesAndScopes.view.filterKeys(!_.startsWith("base."))
+      val clientNames = output.analysis.usedNamesAndScopes.view.filterKeys(!_.startsWith("base."))
 
       val names: Set[String] = clientNames.flatMap {
         case (_, usages) =>
diff --git a/sbt-bridge/test/xsbt/ProductsSpecification.scala b/sbt-bridge/test/xsbt/ProductsSpecification.scala
new file mode 100644
index 000000000000..b13defecc4cc
--- /dev/null
+++ b/sbt-bridge/test/xsbt/ProductsSpecification.scala
@@ -0,0 +1,34 @@
+package xsbt
+
+import org.junit.Assert.*
+import org.junit.Ignore
+import org.junit.Test
+
+import java.io.File
+import java.nio.file.Path
+import java.nio.file.Paths
+
+class ProductsSpecification {
+
+  @Test
+  def extractProductsFromJar = {
+    val src =
+      """package example
+        |
+        |class A {
+        |  class B
+        |  def foo =
+        |    class C
+        |}""".stripMargin
+    val output = compiler.compileSrcsToJar(src)
+    val srcFile = output.srcFiles.head
+    val products = output.analysis.productClassesToSources.filter(_._2 == srcFile).keys.toSet
+    
+    def toPathInJar(className: String): Path =
+      Paths.get(s"${output.classesOutput}!${className.replace('.', File.separatorChar)}.class")
+    val expected = Set("example.A", "example.A$B", "example.A$C$1").map(toPathInJar)
+    assertEquals(products, expected)
+  }
+
+  private def compiler = new ScalaCompilerForUnitTesting
+}
\ No newline at end of file
diff --git a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala
index f17be692ee50..fd125f25560b 100644
--- a/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala
+++ b/sbt-bridge/test/xsbt/ScalaCompilerForUnitTesting.scala
@@ -1,22 +1,19 @@
 /** Adapted from https://github.com/sbt/sbt/blob/0.13/compile/interface/src/test/scala/xsbt/ScalaCompilerForUnitTesting.scala */
 package xsbt
 
-import xsbti.compile.{CompileProgress, SingleOutput}
-import java.io.File
-import xsbti._
-import sbt.io.IO
-import xsbti.api.{ ClassLike, Def, DependencyContext }
-import DependencyContext._
-import xsbt.api.SameAPI
-import sbt.internal.util.ConsoleLogger
-import dotty.tools.io.PlainFile.toPlainFile
 import dotty.tools.xsbt.CompilerBridge
+import sbt.io.IO
+import xsbti.*
+import xsbti.api.ClassLike
+import xsbti.api.DependencyContext.*
+import xsbti.compile.SingleOutput
+
+import java.io.File
+import java.nio.file.Path
 
 import TestCallback.ExtractedClassDependencies
-import ScalaCompilerForUnitTesting.Callbacks
 
-object ScalaCompilerForUnitTesting:
-  case class Callbacks(analysis: TestCallback, progress: TestCompileProgress)
+case class CompileOutput(srcFiles: Seq[VirtualFileRef], classesOutput: Path, analysis: TestCallback, progress: TestCompileProgress)
 
 /**
  * Provides common functionality needed for unit tests that require compiling
@@ -25,29 +22,24 @@ object ScalaCompilerForUnitTesting:
 class ScalaCompilerForUnitTesting {
 
   def extractEnteredPhases(srcs: String*): Seq[List[String]] = {
-    val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(srcs*)
-    val run = testProgress.runs.head
-    tempSrcFiles.map(src => run.unitPhases(src.id))
+    val output = compileSrcs(srcs*)
+    val run = output.progress.runs.head
+    output.srcFiles.map(src => run.unitPhases(src.id))
   }
 
-  def extractTotal(srcs: String*)(extraSourcePath: String*): Int = {
-    val (tempSrcFiles, Callbacks(_, testProgress)) = compileSrcs(List(srcs.toList), extraSourcePath.toList)
-    val run = testProgress.runs.head
-    run.total
-  }
+  def extractTotal(srcs: String*)(extraSourcePath: String*): Int =
+    compileSrcs(List(srcs.toList), extraSourcePath.toList).progress.runs.head.total
 
-  def extractProgressPhases(srcs: String*): List[String] = {
-    val (_, Callbacks(_, testProgress)) = compileSrcs(srcs*)
-    testProgress.runs.head.phases
-  }
+  def extractProgressPhases(srcs: String*): List[String] =
+    compileSrcs(srcs*).progress.runs.head.phases
 
   /**
    * Compiles given source code using Scala compiler and returns API representation
    * extracted by ExtractAPI class.
    */
   def extractApiFromSrc(src: String): Seq[ClassLike] = {
-    val (Seq(tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(src)
-    analysisCallback.apis(tempSrcFile)
+    val output = compileSrcs(src)
+    output.analysis.apis(output.srcFiles.head)
   }
 
   /**
@@ -55,8 +47,8 @@ class ScalaCompilerForUnitTesting {
    * extracted by ExtractAPI class.
    */
   def extractApisFromSrcs(srcs: List[String]*): Seq[Seq[ClassLike]] = {
-    val (tempSrcFiles, Callbacks(analysisCallback, _)) = compileSrcs(srcs.toList)
-    tempSrcFiles.map(analysisCallback.apis)
+    val output = compileSrcs(srcs.toList)
+    output.srcFiles.map(output.analysis.apis)
   }
 
   /**
@@ -73,15 +65,16 @@ class ScalaCompilerForUnitTesting {
       assertDefaultScope: Boolean = true
   ): Map[String, Set[String]] = {
     // we drop temp src file corresponding to the definition src file
-    val (Seq(_, tempSrcFile), Callbacks(analysisCallback, _)) = compileSrcs(definitionSrc, actualSrc)
+    val output = compileSrcs(definitionSrc, actualSrc)
+    val analysis = output.analysis
 
     if (assertDefaultScope) for {
-      (className, used) <- analysisCallback.usedNamesAndScopes
-      analysisCallback.TestUsedName(name, scopes) <- used
+      (className, used) <- analysis.usedNamesAndScopes
+      analysis.TestUsedName(name, scopes) <- used
     } assert(scopes.size() == 1 && scopes.contains(UseScope.Default), s"$className uses $name in $scopes")
 
-    val classesInActualSrc = analysisCallback.classNames(tempSrcFile).map(_._1)
-    classesInActualSrc.map(className => className -> analysisCallback.usedNames(className)).toMap
+    val classesInActualSrc = analysis.classNames(output.srcFiles.head).map(_._1)
+    classesInActualSrc.map(className => className -> analysis.usedNames(className)).toMap
   }
 
   /**
@@ -91,11 +84,11 @@ class ScalaCompilerForUnitTesting {
    * Only the names used in the last src file are returned.
    */
   def extractUsedNamesFromSrc(sources: String*): Map[String, Set[String]] = {
-    val (srcFiles, Callbacks(analysisCallback, _)) = compileSrcs(sources*)
-    srcFiles
+    val output = compileSrcs(sources*)
+    output.srcFiles
       .map { srcFile =>
-        val classesInSrc = analysisCallback.classNames(srcFile).map(_._1)
-        classesInSrc.map(className => className -> analysisCallback.usedNames(className)).toMap
+        val classesInSrc = output.analysis.classNames(srcFile).map(_._1)
+        classesInSrc.map(className => className -> output.analysis.usedNames(className)).toMap
       }
       .reduce(_ ++ _)
   }
@@ -113,15 +106,15 @@ class ScalaCompilerForUnitTesting {
    * file system-independent way of testing dependencies between source code "files".
    */
   def extractDependenciesFromSrcs(srcs: List[List[String]]): ExtractedClassDependencies = {
-    val (_, Callbacks(testCallback, _)) = compileSrcs(srcs)
+    val analysis = compileSrcs(srcs).analysis
 
-    val memberRefDeps = testCallback.classDependencies collect {
+    val memberRefDeps = analysis.classDependencies collect {
       case (target, src, DependencyByMemberRef) => (src, target)
     }
-    val inheritanceDeps = testCallback.classDependencies collect {
+    val inheritanceDeps = analysis.classDependencies collect {
       case (target, src, DependencyByInheritance) => (src, target)
     }
-    val localInheritanceDeps = testCallback.classDependencies collect {
+    val localInheritanceDeps = analysis.classDependencies collect {
       case (target, src, LocalDependencyByInheritance) => (src, target)
     }
     ExtractedClassDependencies.fromPairs(memberRefDeps, inheritanceDeps, localInheritanceDeps)
@@ -142,12 +135,20 @@ class ScalaCompilerForUnitTesting {
    * The sequence of temporary files corresponding to passed snippets and analysis
    * callback is returned as a result.
    */
-  def compileSrcs(groupedSrcs: List[List[String]], sourcePath: List[String] = Nil): (Seq[VirtualFile], Callbacks) = {
+  def compileSrcs(groupedSrcs: List[List[String]], sourcePath: List[String] = Nil, compileToJar: Boolean = false): CompileOutput = {
       val temp = IO.createTemporaryDirectory
       val analysisCallback = new TestCallback
       val testProgress = new TestCompileProgress
-      val classesDir = new File(temp, "classes")
-      classesDir.mkdir()
+      val classesOutput = 
+        if (compileToJar) {
+          val jar = new File(temp, "classes.jar")
+          jar.createNewFile()
+          jar
+        } else {
+          val dir = new File(temp, "classes")
+          dir.mkdir()
+          dir
+        }
 
       val bridge = new CompilerBridge
 
@@ -164,16 +165,16 @@ class ScalaCompilerForUnitTesting {
         }
 
         val virtualSrcFiles = srcFiles.toArray
-        val classesDirPath = classesDir.getAbsolutePath.toString
+        val classesOutputPath = classesOutput.getAbsolutePath()
         val output = new SingleOutput:
-          def getOutputDirectory() = classesDir
+          def getOutputDirectory() = classesOutput
 
         val maybeSourcePath = if extraFiles.isEmpty then Nil else List("-sourcepath", temp.getAbsolutePath.toString)
 
         bridge.run(
           virtualSrcFiles,
           new TestDependencyChanges,
-          Array("-Yforce-sbt-phases", "-classpath", classesDirPath, "-usejavacp", "-d", classesDirPath) ++ maybeSourcePath,
+          Array("-Yforce-sbt-phases", "-classpath", classesOutputPath, "-usejavacp", "-d", classesOutputPath) ++ maybeSourcePath,
           output,
           analysisCallback,
           new TestReporter,
@@ -185,13 +186,16 @@ class ScalaCompilerForUnitTesting {
 
         srcFiles
       }
-      (files.flatten.toSeq, Callbacks(analysisCallback, testProgress))
+      CompileOutput(files.flatten.toSeq, classesOutput.toPath, analysisCallback, testProgress)
   }
 
-  def compileSrcs(srcs: String*): (Seq[VirtualFile], Callbacks) = {
+  def compileSrcs(srcs: String*): CompileOutput = {
     compileSrcs(List(srcs.toList))
   }
 
+  def compileSrcsToJar(srcs: String*): CompileOutput =
+    compileSrcs(List(srcs.toList), compileToJar = true)
+
   private def prepareSrcFile(baseDir: File, fileName: String, src: String): VirtualFile = {
     val srcFile = new File(baseDir, fileName)
     IO.write(srcFile, src)