Skip to content

Commit

Permalink
improvement: allow to run main class from deps with no inputs (#3079)
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek authored Aug 9, 2024
1 parent 7c08b92 commit 6503d2b
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 13 deletions.
11 changes: 8 additions & 3 deletions modules/build/src/main/scala/scala/build/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,14 @@ object Build {
def outputOpt: Some[os.Path] = Some(output)
def dependencyClassPath: Seq[os.Path] = sources.resourceDirs ++ artifacts.classPath
def fullClassPath: Seq[os.Path] = Seq(output) ++ dependencyClassPath
def foundMainClasses(): Seq[String] =
MainClass.find(output).sorted ++
options.classPathOptions.extraClassPath.flatMap(MainClass.find).sorted
def foundMainClasses(): Seq[String] = {
val found =
MainClass.find(output).sorted ++
options.classPathOptions.extraClassPath.flatMap(MainClass.find).sorted
if (inputs.isEmpty && found.isEmpty)
artifacts.jarsForUserExtraDependencies.flatMap(MainClass.findInDependency).sorted
else found
}
def retainedMainClass(
mainClasses: Seq[String],
commandString: String,
Expand Down
11 changes: 11 additions & 0 deletions modules/build/src/main/scala/scala/build/internal/MainClass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.objectweb.asm
import org.objectweb.asm.ClassReader

import java.io.{ByteArrayInputStream, InputStream}
import java.util.jar.{Attributes, JarFile, JarInputStream, Manifest}
import java.util.zip.ZipEntry

import scala.build.input.Element
Expand Down Expand Up @@ -67,6 +68,16 @@ object MainClass {
)
}

def findInDependency(jar: os.Path): Option[String] =
jar match {
case jar if os.isFile(jar) && jar.last.endsWith(".jar") =>
val jarFile = new JarFile(jar.toIO)
val manifest = jarFile.getManifest()
val mainClass = manifest.getMainAttributes().getValue(Attributes.Name.MAIN_CLASS)
Option(mainClass).map(_.asInstanceOf[String])
case _ => None
}

def find(output: os.Path): Seq[String] =
output match {
case o if os.isFile(o) && o.last.endsWith(".class") =>
Expand Down
11 changes: 10 additions & 1 deletion modules/cli/src/main/scala/scala/cli/commands/run/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,22 @@ object Run extends ScalaCommand[RunOptions] with BuildCommandHelpers {
}

def runCommand(
options: RunOptions,
options0: RunOptions,
inputArgs: Seq[String],
programArgs: Seq[String],
defaultInputs: () => Option[Inputs],
logger: Logger,
invokeData: ScalaCliInvokeData
): Unit = {
val shouldDefaultServerFalse =
inputArgs.isEmpty && options0.shared.compilationServer.server.isEmpty &&
!options0.shared.hasSnippets
val options = if (shouldDefaultServerFalse) options0.copy(shared =
options0.shared.copy(compilationServer =
options0.shared.compilationServer.copy(server = Some(false))
)
)
else options0
val initialBuildOptions = {
val buildOptions = buildOptionsOrExit(options)
if (invokeData.subCommand == SubCommand.Shebang) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ final case class SharedOptions(
))
.extractedClassPath

def extraClasspathWasPassed: Boolean = extraJarsAndClassPath.exists(!_.hasSourceJarSuffix)
def extraClasspathWasPassed: Boolean =
extraJarsAndClassPath.exists(!_.hasSourceJarSuffix) || dependencies.dependency.nonEmpty

def extraCompileOnlyClassPath: List[os.Path] = extraCompileOnlyJars.extractedClassPath

Expand Down Expand Up @@ -627,6 +628,10 @@ final case class SharedOptions(
def allJavaSnippets: List[String] = snippet.javaSnippet ++ snippet.executeJava
def allMarkdownSnippets: List[String] = snippet.markdownSnippet ++ snippet.executeMarkdown

def hasSnippets =
allScriptSnippets.nonEmpty || allScalaSnippets.nonEmpty || allJavaSnippets
.nonEmpty || allMarkdownSnippets.nonEmpty

def validateInputArgs(
args: Seq[String]
)(using ScalaCliInvokeData): Seq[Either[String, Seq[Element]]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import scala.util.Properties

trait RunScalacCompatTestDefinitions {
_: RunTestDefinitions =>

final val smithyVersion = "1.50.0"
private def shutdownBloop() =
os.proc(TestUtil.cli, "bloop", "exit", "--power").call(mergeErrIntoOut = true)

def commandLineScalacXOption(): Unit = {
val inputs = TestInputs(
os.rel / "Test.scala" ->
Expand Down Expand Up @@ -274,6 +279,42 @@ trait RunScalacCompatTestDefinitions {
expect(runRes.out.trim() == expectedOutput)
}
}

test("run main class from --dep even when no explicit inputs are passed") {
shutdownBloop()
val output = os.proc(
TestUtil.cli,
"--dep",
s"software.amazon.smithy:smithy-cli:$smithyVersion",
"--main-class",
"software.amazon.smithy.cli.SmithyCli",
"--",
"--version"
).call()
assert(output.exitCode == 0)
assert(output.out.text().contains(smithyVersion))

// assert bloop wasn't started
assertNoDiff(shutdownBloop().out.text(), "No running Bloop server found.")
}

test("find and run main class from --dep even when no explicit inputs are passed") {
shutdownBloop()
val output = os.proc(
TestUtil.cli,
"run",
"--dep",
s"software.amazon.smithy:smithy-cli:$smithyVersion",
"--",
"--version"
).call()
assert(output.exitCode == 0)
assert(output.out.text().contains(smithyVersion))

// assert bloop wasn't started
assertNoDiff(shutdownBloop().out.text(), "No running Bloop server found.")
}

test("dont clear output dir") {
val expectedOutput = "Hello"
val `lib.scala` = os.rel / "lib.scala"
Expand Down
28 changes: 24 additions & 4 deletions modules/options/src/main/scala/scala/build/Artifacts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ import scala.build.errors.{
import scala.build.internal.Constants
import scala.build.internal.Constants.*
import scala.build.internal.CsLoggerUtil.*
import scala.build.internal.Util.PositionedScalaDependencyOps
import scala.build.internal.Util.{PositionedScalaDependencyOps, ScalaModuleOps}
import scala.collection.mutable

final case class Artifacts(
javacPluginDependencies: Seq[(AnyDependency, String, os.Path)],
extraJavacPlugins: Seq[os.Path],
userDependencies: Seq[AnyDependency],
defaultDependencies: Seq[AnyDependency],
extraDependencies: Seq[AnyDependency],
userCompileOnlyDependencies: Seq[AnyDependency],
internalDependencies: Seq[AnyDependency],
detailedArtifacts: Seq[(CsDependency, csCore.Publication, csUtil.Artifact, os.Path)],
Expand All @@ -41,6 +42,22 @@ final case class Artifacts(
hasJvmRunner: Boolean,
resolution: Option[Resolution]
) {

def userDependencies = defaultDependencies ++ extraDependencies
lazy val jarsForUserExtraDependencies = {
val extraDependenciesMap =
extraDependencies.map(dep => dep.module.name -> dep.version).toMap
detailedArtifacts
.iterator
.collect {
case (dep, pub, _, path)
if pub.classifier != Classifier.sources &&
extraDependenciesMap.get(dep.module.name.value).contains(dep.version) => path
}
.toVector
.distinct
}

lazy val artifacts: Seq[(String, os.Path)] =
detailedArtifacts
.iterator
Expand Down Expand Up @@ -93,7 +110,8 @@ object Artifacts {
scalaArtifactsParamsOpt: Option[ScalaArtifactsParams],
javacPluginDependencies: Seq[Positioned[AnyDependency]],
extraJavacPlugins: Seq[os.Path],
dependencies: Seq[Positioned[AnyDependency]],
defaultDependencies: Seq[Positioned[AnyDependency]],
extraDependencies: Seq[Positioned[AnyDependency]],
compileOnlyDependencies: Seq[Positioned[AnyDependency]],
extraClassPath: Seq[os.Path],
extraCompileOnlyJars: Seq[os.Path],
Expand All @@ -109,6 +127,7 @@ object Artifacts {
logger: Logger,
maybeRecoverOnError: BuildException => Option[BuildException]
): Either[BuildException, Artifacts] = either {
val dependencies = defaultDependencies ++ extraDependencies

val jvmTestRunnerDependencies =
if (addJvmTestRunner)
Expand Down Expand Up @@ -428,7 +447,8 @@ object Artifacts {
Artifacts(
javacPlugins0,
extraJavacPlugins,
dependencies.map(_.value) ++ scalaOpt.toSeq.flatMap(_.extraDependencies),
defaultDependencies.map(_.value),
extraDependencies.map(_.value) ++ scalaOpt.toSeq.flatMap(_.extraDependencies),
compileOnlyDependencies.map(_.value),
internalDependencies.map(_.value),
fetchRes.fullDetailedArtifacts.collect { case (d, p, a, Some(f)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,11 @@ final case class BuildOptions(
}
else Nil
}
private def dependencies: Either[BuildException, Seq[Positioned[AnyDependency]]] = either {
private def defaultDependencies: Either[BuildException, Seq[Positioned[AnyDependency]]] = either {
value(maybeJsDependencies).map(Positioned.none(_)) ++
value(maybeNativeDependencies).map(Positioned.none(_)) ++
value(scalaLibraryDependencies).map(Positioned.none(_)) ++
value(scalaCompilerDependencies).map(Positioned.none(_)) ++
classPathOptions.extraDependencies.toSeq
value(scalaCompilerDependencies).map(Positioned.none(_))
}

private def semanticDbPlugins(logger: Logger): Either[BuildException, Seq[AnyDependency]] =
Expand Down Expand Up @@ -451,7 +450,8 @@ final case class BuildOptions(
scalaArtifactsParamsOpt,
javacPluginDependencies = value(javacPluginDependencies),
extraJavacPlugins = javaOptions.javacPlugins.map(_.value),
dependencies = value(dependencies),
defaultDependencies = value(defaultDependencies),
extraDependencies = classPathOptions.extraDependencies.toSeq,
compileOnlyDependencies = classPathOptions.extraCompileOnlyDependencies.toSeq,
extraClassPath = allExtraJars,
extraCompileOnlyJars = allExtraCompileOnlyJars,
Expand Down

0 comments on commit 6503d2b

Please sign in to comment.