diff --git a/api/src/main/scala/com.olegych.scastie.api/ApiModels.scala b/api/src/main/scala/com.olegych.scastie.api/ApiModels.scala index 2e7c3682b..4639eca38 100644 --- a/api/src/main/scala/com.olegych.scastie.api/ApiModels.scala +++ b/api/src/main/scala/com.olegych.scastie.api/ApiModels.scala @@ -2,10 +2,10 @@ package com.olegych.scastie.api import play.api.libs.json._ -case object SbtPing -case object SbtPong +case object RunnerPing +case object RunnerPong -case class SbtRunnerConnect(hostname: String, port: Int) +case class RunnerConnect(hostname: String, port: Int) case object ActorConnected object SnippetSummary { @@ -127,7 +127,9 @@ case class ScalaDependency( override def toString: String = target.renderSbt(this) } -case class ScastieMetalsOptions(dependencies: Set[ScalaDependency], scalaTarget: ScalaTarget) +// Note: adding a code parameter is for the metals-runner +// so it can parse dependencies and give support for it :) +case class ScastieMetalsOptions(dependencies: Set[ScalaDependency], scalaTarget: ScalaTarget, code: Option[String] = None) object ScastieMetalsOptions { implicit val scastieMetalsOptions: OFormat[ScastieMetalsOptions] = Json.format[ScastieMetalsOptions] @@ -141,6 +143,8 @@ sealed trait FailureType { case class NoResult(msg: String) extends FailureType case class PresentationCompilerFailure(msg: String) extends FailureType +case class InvalidScalaVersion(msg: String) extends FailureType + object FailureType { implicit val failureTypeFormat: OFormat[FailureType] = Json.format[FailureType] @@ -158,6 +162,10 @@ object ScastieOffsetParams { implicit val scastieOffsetParams: OFormat[ScastieOffsetParams] = Json.format[ScastieOffsetParams] } +object InvalidScalaVersion { + implicit val InvalidScalaVersionFormat: OFormat[InvalidScalaVersion] = Json.format[InvalidScalaVersion] +} + case class LSPRequestDTO(options: ScastieMetalsOptions, offsetParams: ScastieOffsetParams) case class CompletionInfoRequest(options: ScastieMetalsOptions, completionItem: CompletionItemDTO) diff --git a/api/src/main/scala/com.olegych.scastie.api/ScalaTarget.scala b/api/src/main/scala/com.olegych.scastie.api/ScalaTarget.scala index d1e7fceaa..d89d992b2 100644 --- a/api/src/main/scala/com.olegych.scastie.api/ScalaTarget.scala +++ b/api/src/main/scala/com.olegych.scastie.api/ScalaTarget.scala @@ -76,6 +76,7 @@ object ScalaTarget { formatNative.writes(native) ++ JsObject(Seq("tpe" -> JsString("Native"))) case dotty: Scala3 => formatScala3.writes(dotty) ++ JsObject(Seq("tpe" -> JsString("Scala3"))) + case scli: ScalaCli => JsObject(Seq("tpe" -> JsString("ScalaCli"))) } } @@ -91,6 +92,7 @@ object ScalaTarget { case "Typelevel" => formatTypelevel.reads(json) case "Native" => formatNative.reads(json) case "Scala3" | "Dotty" => formatScala3.reads(json) + case "ScalaCli" => JsSuccess(ScalaCli()) case _ => JsError(Seq()) } case _ => JsError(Seq()) @@ -109,6 +111,29 @@ object ScalaTarget { ) ) + def fromScalaVersion(version: String): Option[ScalaTarget] = { + if (version.startsWith("3")) { + if (version == "3") + Some(ScalaTarget.Scala3(BuildInfo.latest3)) + else + Some(ScalaTarget.Scala3(version)) + } else if (version.startsWith("2")) { + if (version == "2") + Some(ScalaTarget.Jvm(BuildInfo.latest213)) + else if (version == "2.13") + Some(ScalaTarget.Jvm(BuildInfo.latest213)) + else if (version == "2.12") + Some(ScalaTarget.Jvm(BuildInfo.latest212)) + else if (version == "2.11") + Some(ScalaTarget.Jvm(BuildInfo.latest211)) + else if (version == "2.10") + Some(ScalaTarget.Jvm(BuildInfo.latest210)) + else + Some(ScalaTarget.Jvm(version)) + } else + None + } + object Jvm { def default: ScalaTarget = ScalaTarget.Jvm(scalaVersion = BuildInfo.latest213) } @@ -299,4 +324,37 @@ object ScalaTarget { override def toString: String = s"Scala $scalaVersion" } + + object ScalaCli { + def default: ScalaTarget = ScalaCli() + + def defaultCode: String = + """|// Hello! + |// Scastie is compatible with Scala CLI! You can use + |// directives: https://scala-cli.virtuslab.org/docs/guides/using-directives/ + | + |println("Hi Scala CLI <3") + """.stripMargin + } + + case class ScalaCli(scalaBinaryVersion0: String = "") extends ScalaTarget { + override def binaryScalaVersion: String = scalaBinaryVersion0 + + override def scalaVersion: String = "" + + override def targetType: ScalaTargetType = ScalaTargetType.ScalaCli + + override def scaladexRequest: Map[String,String] = + Map("target" -> "JVM") + + override def renderSbt(lib: ScalaDependency): String = "// Non-applicable" + + override def sbtConfig: String = "// Non-applicable" + + override def sbtPluginsConfig: String = "// Non-applicable" + + override def sbtRunCommand(worksheetMode: Boolean): String = ??? + + override def toString: String = "Scala-CLI" + } } diff --git a/api/src/main/scala/com.olegych.scastie.api/ScalaTargetType.scala b/api/src/main/scala/com.olegych.scastie.api/ScalaTargetType.scala index 7bc640e9e..74d2c5df6 100644 --- a/api/src/main/scala/com.olegych.scastie.api/ScalaTargetType.scala +++ b/api/src/main/scala/com.olegych.scastie.api/ScalaTargetType.scala @@ -15,6 +15,7 @@ object ScalaTargetType { case "JS" => Some(JS) case "NATIVE" => Some(Native) case "TYPELEVEL" => Some(Typelevel) + case "SCALACLI" => Some(ScalaCli) case _ => None } } @@ -30,7 +31,8 @@ object ScalaTargetType { Scala3, JS, Native, - Typelevel + Typelevel, + ScalaCli ).map(v => (v.toString, v)).toMap def reads(json: JsValue): JsResult[ScalaTargetType] = { @@ -65,4 +67,8 @@ object ScalaTargetType { case object Typelevel extends ScalaTargetType { def defaultScalaTarget: ScalaTarget = ScalaTarget.Typelevel.default } + + case object ScalaCli extends ScalaTargetType { + def defaultScalaTarget: ScalaTarget = ScalaTarget.ScalaCli.default + } } diff --git a/api/src/main/scala/com.olegych.scastie.api/ScliState.scala b/api/src/main/scala/com.olegych.scastie.api/ScliState.scala new file mode 100644 index 000000000..9f4186154 --- /dev/null +++ b/api/src/main/scala/com.olegych.scastie.api/ScliState.scala @@ -0,0 +1,40 @@ +package com.olegych.scastie.api + +import play.api.libs.json._ + +sealed trait ScliState extends ServerState +object ScliState { + case object Unknown extends ScliState { + override def toString: String = "Unknown" + def isReady: Boolean = true + } + + case object Disconnected extends ScliState { + override def toString: String = "Disconnected" + def isReady: Boolean = false + } + + implicit object ScliStateFormat extends Format[ScliState] { + def writes(state: ScliState): JsValue = { + JsString(state.toString) + } + + private val values = + List( + Unknown, + Disconnected + ).map(v => (v.toString, v)).toMap + + def reads(json: JsValue): JsResult[ScliState] = { + json match { + case JsString(tpe) => { + values.get(tpe) match { + case Some(v) => JsSuccess(v) + case _ => JsError(Seq()) + } + } + case _ => JsError(Seq()) + } + } + } +} \ No newline at end of file diff --git a/balancer/src/main/resources/reference.conf b/balancer/src/main/resources/reference.conf index 7d20c58d0..6088977c3 100644 --- a/balancer/src/main/resources/reference.conf +++ b/balancer/src/main/resources/reference.conf @@ -3,9 +3,13 @@ com.olegych.scastie.balancer { snippets-dir = ./target/snippets/ old-snippets-dir = ./target/old-snippets/ - remote-hostname = "127.0.0.1" + remote-sbt-hostname = "127.0.0.1" remote-sbt-ports-start = 5150 remote-sbt-ports-size = 1 + + remote-scli-hostname = "127.0.0.1" + remote-scli-ports-start = 5250 + remote-scli-ports-size = 1 } akka.actor.warn-about-java-serializer-usage = false diff --git a/balancer/src/main/scala/com.olegych.scastie.balancer/BaseDispatcher.scala b/balancer/src/main/scala/com.olegych.scastie.balancer/BaseDispatcher.scala new file mode 100644 index 000000000..ffaf27061 --- /dev/null +++ b/balancer/src/main/scala/com.olegych.scastie.balancer/BaseDispatcher.scala @@ -0,0 +1,72 @@ +package com.olegych.scastie.balancer + +import com.typesafe.config.Config +import akka.actor.ActorSelection +import com.olegych.scastie.api.ActorConnected +import akka.actor.ActorLogging +import akka.actor.Actor +import akka.actor.ActorRef +import scala.concurrent.Future +import akka.pattern.ask +import akka.util.Timeout +import scala.concurrent.duration._ +import com.olegych.scastie.api.RunnerPing + +abstract class BaseDispatcher[R, S](config: Config) extends Actor with ActorLogging { + case class SocketAddress(host: String, port: Int) + + import context._ + + private def getRemoteActorsPath( + key: String, + runnerName: String, + actorName: String + ): Map[SocketAddress, String] = { + val host = config.getString(s"remote-$key-hostname") + val portStart = config.getInt(s"remote-$key-ports-start") + val portSize = config.getInt(s"remote-$key-ports-size") + (0 until portSize).map(_ + portStart) + .map(port => { + val addr = SocketAddress(host, port) + (addr, getRemoteActorPath(runnerName, addr, actorName)) + }) + .toMap + } + + def getRemoteActorPath( + runnerName: String, + runnerAddress: SocketAddress, + actorName: String + ) = s"akka://$runnerName@${runnerAddress.host}:${runnerAddress.port}/user/$actorName" + + def connectRunner(path: String): ActorSelection = { + val selection = context.actorSelection(path) + selection ! ActorConnected + selection + } + + def getRemoteServers( + key: String, + runnerName: String, + actorName: String + ): Map[SocketAddress, ActorSelection] = { + getRemoteActorsPath(key, runnerName, actorName).map { + case (address, url) => (address, connectRunner(url)) + } + } + + def ping(servers: List[ActorSelection]): Future[List[Boolean]] = { + implicit val timeout: Timeout = Timeout(10.seconds) + val futures = servers.map { s => + (s ? RunnerPing).map { _ => + log.info(s"pinged $s") + true + }.recover { e => + log.error(e, s"could not ping $s") + false + } + } + Future.sequence(futures) + } + +} diff --git a/balancer/src/main/scala/com.olegych.scastie.balancer/DispatchActor.scala b/balancer/src/main/scala/com.olegych.scastie.balancer/DispatchActor.scala index 3f0181550..6f3b48136 100644 --- a/balancer/src/main/scala/com.olegych.scastie.balancer/DispatchActor.scala +++ b/balancer/src/main/scala/com.olegych.scastie.balancer/DispatchActor.scala @@ -6,6 +6,7 @@ import akka.actor.ActorRef import akka.actor.ActorSelection import akka.actor.OneForOneStrategy import akka.actor.SupervisorStrategy +import akka.actor.Props import akka.event import akka.pattern.ask import akka.remote.DisassociatedEvent @@ -24,6 +25,12 @@ import java.time.Instant import java.util.concurrent.Executors import scala.concurrent._ import scala.concurrent.duration._ +import com.olegych.scastie.api.ScalaTarget.Typelevel +import com.olegych.scastie.api.ScalaTarget.Native +import com.olegych.scastie.api.ScalaTarget.Jvm +import com.olegych.scastie.api.ScalaTarget.Js +import com.olegych.scastie.api.ScalaTarget.Scala3 +import com.olegych.scastie.api.ScalaTarget.ScalaCli case class Address(host: String, port: Int) case class SbtConfig(config: String) @@ -61,48 +68,36 @@ case class Done(progress: api.SnippetProgress, retries: Int) case object Ping +/** + * This Actor creates and takes care of two dispatchers: SbtDispatcher and ScliDispatcher. + * It will receive every request and forward to the proper dispatcher every request. + * + * @param progressActor + * @param statusActor + */ class DispatchActor(progressActor: ActorRef, statusActor: ActorRef) // extends PersistentActor with AtLeastOnceDelivery extends Actor with ActorLogging { - override def supervisorStrategy: SupervisorStrategy = OneForOneStrategy() { - case e => - log.error(e, "failure") - SupervisorStrategy.resume - } - private val config = ConfigFactory.load().getConfig("com.olegych.scastie.balancer") - private val host = config.getString("remote-hostname") - private val sbtPortsStart = config.getInt("remote-sbt-ports-start") - private val sbtPortsSize = config.getInt("remote-sbt-ports-size") - - private val sbtPorts = (0 until sbtPortsSize).map(sbtPortsStart + _) - - private def connectRunner( - runnerName: String, - actorName: String, - host: String - )(port: Int): ((String, Int), ActorSelection) = { - val path = s"akka://$runnerName@$host:$port/user/$actorName" - log.info(s"Connecting to ${path}") - val selection = context.actorSelection(path) - selection ! ActorConnected - (host, port) -> selection - } - private var remoteSbtSelections = - sbtPorts.map(connectRunner("SbtRunner", "SbtActor", host)).toMap + // Dispatchers + val sbtDispatcher: ActorRef = context.actorOf( + Props(new SbtDispatcher(config, progressActor, statusActor)), + "SbtDispatcher" + ) - private var sbtLoadBalancer: SbtBalancer = { - val sbtServers = remoteSbtSelections.to(Vector).map { - case (_, ref) => - val state: SbtState = SbtState.Unknown - Server(ref, Inputs.default, state) - } + val scliDispatcher: ActorRef = context.actorOf( + Props(new ScliDispatcher(config, progressActor, statusActor)), + "ScliDispatcher" + ) - LoadBalancer(servers = sbtServers) + override def supervisorStrategy: SupervisorStrategy = OneForOneStrategy() { + case e => + log.error(e, "failure") + SupervisorStrategy.resume } import context._ @@ -138,38 +133,8 @@ class DispatchActor(progressActor: ActorRef, statusActor: ActorRef) new InMemoryContainer } - private def updateSbtBalancer(newSbtBalancer: SbtBalancer): Unit = { - if (sbtLoadBalancer != newSbtBalancer) { - statusActor ! SbtLoadBalancerUpdate(newSbtBalancer) - } - sbtLoadBalancer = newSbtBalancer - () - } - - //can be called from future - private def run(inputsWithIpAndUser: InputsWithIpAndUser, snippetId: SnippetId): Unit = { + def run(inputsWithIpAndUser: InputsWithIpAndUser, snippetId: SnippetId) = self ! Run(inputsWithIpAndUser, snippetId) - } - //cannot be called from future - private def run0(inputsWithIpAndUser: InputsWithIpAndUser, snippetId: SnippetId): Unit = { - - val InputsWithIpAndUser(inputs, UserTrace(ip, user)) = inputsWithIpAndUser - - log.info("id: {}, ip: {} run inputs: {}", snippetId, ip, inputs) - - val task = Task(inputs, Ip(ip), TaskId(snippetId), Instant.now) - - sbtLoadBalancer.add(task) match { - case Some((server, newBalancer)) => - updateSbtBalancer(newBalancer) - - server.ref.tell( - SbtTask(snippetId, inputs, ip, user.map(_.login), progressActor), - self - ) - case _ => () - } - } private def logError[T](f: Future[T]) = { f.recover { @@ -178,12 +143,10 @@ class DispatchActor(progressActor: ActorRef, statusActor: ActorRef) } def receive: Receive = event.LoggingReceive(event.Logging.InfoLevel) { - case SbtPong => () + case api.RunnerPong => () case format: FormatRequest => - val server = sbtLoadBalancer.getRandomServer - server.foreach(_.ref.tell(format, sender())) - () + sbtDispatcher.tell(format, sender()) case x @ RunSnippet(inputsWithIpAndUser) => log.info(s"starting ${x}") @@ -269,14 +232,17 @@ class DispatchActor(progressActor: ActorRef, statusActor: ActorRef) val sender = this.sender() logError(container.removeUserSnippets(UserLogin(user.login)).map(sender ! _)) - case progress: api.SnippetProgress => + case x @ ReceiveStatus(requester) => sbtDispatcher.tell(x, sender()) + + case statusProgress: StatusProgress => + statusActor ! statusProgress + + case progress: SnippetProgress => val sender = this.sender() - if (progress.isDone) { - self ! Done(progress, retries = 100) - } + + logError( - container - .appendOutput(progress) + container.appendOutput(progress) .recover { case e => log.error(e, s"failed to save $progress from $sender") @@ -285,85 +251,29 @@ class DispatchActor(progressActor: ActorRef, statusActor: ActorRef) .map(sender ! _) ) - case done: Done => - done.progress.snippetId.foreach { sid => - val newBalancer = sbtLoadBalancer.done(TaskId(sid)) - newBalancer match { - case Some(newBalancer) => - updateSbtBalancer(newBalancer) - case None => - if (done.retries >= 0) { - system.scheduler.scheduleOnce(1.second) { - self ! done.copy(retries = done.retries - 1) - } - } else { - val taskIds = - sbtLoadBalancer.servers.flatMap(_.mailbox.map(_.taskId)) - log.error(s"stopped retrying to update ${taskIds} with ${done}") - } - } - } - - case event: DisassociatedEvent => - for { - host <- event.remoteAddress.host - port <- event.remoteAddress.port - ref <- remoteSbtSelections.get((host, port)) - } { - log.warning("removing disconnected: {}", ref) - val previousRemoteSbtSelections = remoteSbtSelections - remoteSbtSelections = remoteSbtSelections - ((host, port)) - if (previousRemoteSbtSelections != remoteSbtSelections) { - updateSbtBalancer(sbtLoadBalancer.removeServer(ref)) - } + case run: Run => { + run.inputsWithIpAndUser.inputs.target match { + case ScalaCli(_) => + println(s"Forwarding run to Scala-CLI dispatcher: ${run.snippetId}") + scliDispatcher ! run + case _ => + println(s"Forwarding run to SBT dispatcher: ${run.snippetId}") + sbtDispatcher ! run } + } - case SbtUp => - log.info("SbtUp") - - case Replay(SbtRun(snippetId, inputs, progressActor, snippetActor)) => - log.info("Replay: " + inputs.code) - - case SbtRunnerConnect(runnerHostname, runnerAkkaPort) => - if (!remoteSbtSelections.contains((runnerHostname, runnerAkkaPort))) { - log.info("Connected Runner {}", runnerAkkaPort) - - val sel = connectRunner("SbtRunner", "SbtActor", runnerHostname)( - runnerAkkaPort - ) - val (_, ref) = sel - - remoteSbtSelections = remoteSbtSelections + sel - - val state: SbtState = SbtState.Unknown - - updateSbtBalancer( - sbtLoadBalancer.addServer( - Server(ref, Inputs.default, state) - ) - ) - } - - case ReceiveStatus(requester) => - sender() ! LoadBalancerInfo(sbtLoadBalancer, requester) - - case statusProgress: StatusProgress => - statusActor ! statusProgress - - case run: Run => - run0(run.inputsWithIpAndUser, run.snippetId) case ping: Ping.type => implicit val timeout: Timeout = Timeout(10.seconds) - logError(Future.sequence { - sbtLoadBalancer.servers.map { s => - (s.ref ? SbtPing) - .map { _ => - log.info(s"pinged ${s.ref} server") - } - .recover { - case e => log.error(e, s"couldn't ping ${s} server") - } + val seq = Future.sequence( + List(scliDispatcher, sbtDispatcher).map { + s => (s ? Ping).map(_ => + log.info(s"Pinged ${s}") + ).recover(_ => + log.info(s"Failed to ping ${s}") + ) } - }) + ) } + + } diff --git a/balancer/src/main/scala/com.olegych.scastie.balancer/SbtDispatcher.scala b/balancer/src/main/scala/com.olegych.scastie.balancer/SbtDispatcher.scala new file mode 100644 index 000000000..ebc5d07d7 --- /dev/null +++ b/balancer/src/main/scala/com.olegych.scastie.balancer/SbtDispatcher.scala @@ -0,0 +1,157 @@ +package com.olegych.scastie.balancer + +import akka.event +import akka.actor.Actor +import akka.actor.ActorRef +import com.typesafe.config.Config +import akka.actor.ActorLogging +import com.typesafe.config.ConfigFactory +import akka.actor.ActorSelection +import com.olegych.scastie.api +import com.olegych.scastie.api._ +import com.olegych.scastie.util._ +import scala.concurrent.Future +import akka.remote.DisassociatedEvent +import com.olegych.scastie.api.SnippetId +import java.time.Instant +import com.olegych.scastie.api.TaskId +import scala.concurrent._ +import akka.pattern.ask + +import scala.concurrent.duration._ +import java.util.concurrent.Executors +import akka.actor.Address +import akka.actor.ActorSystem +import akka.util.Timeout + +class SbtDispatcher(config: Config, progressActor: ActorRef, statusActor: ActorRef) + extends BaseDispatcher[ActorSelection, SbtState](config) with Actor { + + private val parent = context.parent + + var remoteSbtSelections = getRemoteServers("sbt", "SbtRunner", "SbtActor") + + var balancer: SbtBalancer = { + val sbtServers = remoteSbtSelections.to(Vector).map { + case (_, ref) => + val state: SbtState = SbtState.Unknown + Server(ref, Inputs.default, state) + } + + LoadBalancer(servers = sbtServers) + } + + private def updateSbtBalancer(newBalancer: SbtBalancer): Unit = { + if (balancer != newBalancer) { + statusActor ! SbtLoadBalancerUpdate(newBalancer) + } + balancer = newBalancer + () + } + + import context._ + + // cannot be called from future + private def run0(inputsWithIpAndUser: InputsWithIpAndUser, snippetId: SnippetId): Unit = { + + val InputsWithIpAndUser(inputs, UserTrace(ip, user)) = inputsWithIpAndUser + + log.info("id: {}, ip: {} run inputs: {}", snippetId, ip, inputs) + + val task = Task(inputs, Ip(ip), TaskId(snippetId), Instant.now) + + balancer.add(task) match { + case Some((server, newBalancer)) => + updateSbtBalancer(newBalancer) + + server.ref.tell( + SbtTask(snippetId, inputs, ip, user.map(_.login), progressActor), + self + ) + case _ => () + } + } + + def receive: Receive = event.LoggingReceive(event.Logging.InfoLevel) { + case progress: api.SnippetProgress => + implicit val timeout: Timeout = Timeout(10.seconds) + val sender = this.sender() + if (progress.isDone) { + self ! Done(progress, retries = 100) + } + (parent ? progress).map(sender ! _) + + case done: Done => + done.progress.snippetId.foreach { sid => + val newBalancer = balancer.done(TaskId(sid)) + newBalancer match { + case Some(newBalancer) => + updateSbtBalancer(newBalancer) + case None => () // never happens + } + } + + case RunnerPong => () + + case SbtUp => + log.info("SbtUp") + + case format: FormatRequest => + val server = balancer.getRandomServer + server.foreach(_.ref.tell(format, sender())) + () + + + case event: DisassociatedEvent => + for { + host <- event.remoteAddress.host + port <- event.remoteAddress.port + ref <- remoteSbtSelections.get(SocketAddress(host, port)) + } { + log.warning("removing disconnected: {}", ref) + val previousRemoteSbtSelections = remoteSbtSelections + remoteSbtSelections = remoteSbtSelections - (SocketAddress(host, port)) + if (previousRemoteSbtSelections != remoteSbtSelections) { + updateSbtBalancer(balancer.removeServer(ref)) + } + } + + case Replay(SbtRun(snippetId, inputs, progressActor, snippetActor)) => + log.info("Replay: " + inputs.code) + + case RunnerConnect(runnerHostname, runnerAkkaPort) => + if (!remoteSbtSelections.contains(SocketAddress(runnerHostname, runnerAkkaPort))) { + log.info("Connected Runner {}", runnerAkkaPort) + + val address = SocketAddress(runnerHostname, runnerAkkaPort) + val ref = connectRunner(getRemoteActorPath("SbtRunner", address, "SbtActor")) + val sel = SocketAddress(runnerHostname, runnerAkkaPort) -> ref + + remoteSbtSelections = remoteSbtSelections + sel + + val state: SbtState = SbtState.Unknown + + updateSbtBalancer( + balancer.addServer( + Server(ref, Inputs.default, state) + ) + ) + } + + case ReceiveStatus(requester) => + sender() ! LoadBalancerInfo(balancer, requester) + + case run: Run => + run0(run.inputsWithIpAndUser, run.snippetId) + + case p: Ping.type => + val sender = this.sender() + ping(remoteSbtSelections.values.toList).andThen(s => sender ! RunnerPong) + } + + private def logError[T](f: Future[T]) = { + f.recover { + case e => log.error(e, "failed future") + } + } +} diff --git a/balancer/src/main/scala/com.olegych.scastie.balancer/ScliDispatcher.scala b/balancer/src/main/scala/com.olegych.scastie.balancer/ScliDispatcher.scala new file mode 100644 index 000000000..6c18ab8b1 --- /dev/null +++ b/balancer/src/main/scala/com.olegych.scastie.balancer/ScliDispatcher.scala @@ -0,0 +1,123 @@ +package com.olegych.scastie.balancer + +import akka.actor.Actor +import akka.actor.ActorLogging +import com.typesafe.config.Config +import akka.actor.ActorRef +import akka.actor.ActorSelection +import com.olegych.scastie.api.ActorConnected +import com.olegych.scastie.api.ScliState +import com.olegych.scastie.api.SnippetId +import com.olegych.scastie.balancer.Ping +import com.olegych.scastie.api.RunnerConnect +import com.olegych.scastie.api.RunnerPong +import com.olegych.scastie.api.TaskId +import java.time.Instant +import scala.collection.immutable.Queue +import com.olegych.scastie.util.SbtTask +import com.olegych.scastie.util.ScliActorTask +import com.olegych.scastie.api.SnippetProgress +import akka.util.Timeout +import scala.concurrent.duration._ +import akka.pattern.ask +import akka.remote.DisassociatedEvent + +class ScliDispatcher(config: Config, progressActor: ActorRef, statusActor: ActorRef) + extends BaseDispatcher[ActorSelection, ScliState](config) { + + private val parent = context.parent + + var remoteServers = getRemoteServers("scli", "ScliRunner", "ScliActor") + + var availableServersQueue = Queue[(SocketAddress, ActorSelection)](remoteServers.toSeq :_ *) + var taskQueue = Queue[Task]() + var processedSnippetsId: Map[SnippetId, (SocketAddress, ActorSelection)] = Nil.toMap + + private def run0(inputsWithIpAndUser: InputsWithIpAndUser, snippetId: SnippetId) = { + val InputsWithIpAndUser(inputs, UserTrace(ip, user)) = inputsWithIpAndUser + + log.info("id: {}, ip: {} run inputs: {}", snippetId, ip, inputs) + + val task = Task(inputs, Ip(ip), TaskId(snippetId), Instant.now) + + taskQueue = taskQueue.enqueue(task) + + giveTask + } + + private def enqueueAvailableServer(addr: SocketAddress, server: ActorSelection) = + if (remoteServers.contains(addr)) { + availableServersQueue = availableServersQueue.enqueue((addr, server)) + giveTask + } + + + private def giveTask = { + if (!taskQueue.isEmpty) { + val (task, newTaskQueue) = taskQueue.dequeue + + availableServersQueue.dequeueOption match { + case None => () + case Some(((addr, server), newQueue)) => { + log.info(s"Giving task ${task.taskId} to ${server.pathString}") + taskQueue = newTaskQueue + availableServersQueue = newQueue + server ! ScliActorTask(task.taskId.snippetId, task.config, task.ip.v, progressActor) + processedSnippetsId += task.taskId.snippetId -> (addr, server) + } + } + } + } + + import context._ + + def receive: Receive = { + case RunnerPong => () + + case p: Ping.type => + val sender = this.sender() + ping(remoteServers.values.toList).andThen(s => sender ! RunnerPong) + + case RunnerConnect(runnerHostname, runnerAkkaPort) => + if (!remoteServers.contains(SocketAddress(runnerHostname, runnerAkkaPort))) { + log.info("Connected runner {}", runnerAkkaPort) + + val address = SocketAddress(runnerHostname, runnerAkkaPort) + val ref = connectRunner(getRemoteActorPath("ScliRunner", address, "ScliActor")) + + remoteServers += address -> ref + enqueueAvailableServer(address, ref) + giveTask + } + + case progress: SnippetProgress => + implicit val timeout: Timeout = Timeout(10.seconds) + val sender = this.sender() + if (progress.isDone) { + self ! Done(progress, retries = 100) + } + (parent ? progress).map(sender ! _) + + case done: Done => + done.progress.snippetId.foreach { sid => + val (addr, server) = processedSnippetsId(sid) + log.info(s"Runner $addr has finished processing $sid.") + processedSnippetsId -= sid + enqueueAvailableServer(addr, server) + } + + case Run(inputsWithIpAndUser, snippetId) => run0(inputsWithIpAndUser, snippetId) + + case event: DisassociatedEvent => + for { + host <- event.remoteAddress.host + port <- event.remoteAddress.port + ref <- remoteServers.get(SocketAddress(host, port)) + } { + log.warning("removing disconnected: {}", ref) + remoteServers = remoteServers - (SocketAddress(host, port)) + } + + case _ => () + } +} diff --git a/build.sbt b/build.sbt index 9056a9eea..154157f63 100644 --- a/build.sbt +++ b/build.sbt @@ -8,8 +8,9 @@ def akka(module: String) = "com.typesafe.akka" %% ("akka-" + module) % "2.6.19" val akkaHttpVersion = "10.2.9" -addCommandAlias("startAll", "sbtRunner/reStart;server/reStart;metalsRunner/reStart;client/fastLinkJS") -addCommandAlias("startAllProd", "sbtRunner/reStart;metalsRunner/reStart;server/fullLinkJS/reStart") + +addCommandAlias("startAll", "scliRunner/reStart;sbtRunner/reStart;server/reStart;metalsRunner/reStart;client/fastLinkJS") +addCommandAlias("startAllProd", "scliRunner/reStart;sbtRunner/reStart;metalsRunner/reStart;server/fullLinkJS/reStart") val yarnBuild = taskKey[Unit]("builds es modules with `yarn build`") @@ -27,7 +28,8 @@ lazy val scastie = project server, storage, utils, - metalsRunner + metalsRunner, + scliRunner ).map(_.project)): _* ) .settings(baseSettings) @@ -132,7 +134,8 @@ lazy val metalsRunner = project "io.circe" %% "circe-generic" % "0.14.5", "com.evolutiongaming" %% "scache" % "4.2.3", "org.scalameta" %% "munit" % "0.7.29" % Test, - "org.typelevel" %% "munit-cats-effect-3" % "1.0.7" % Test + "org.typelevel" %% "munit-cats-effect-3" % "1.0.7" % Test, + "org.virtuslab" % "using_directives" % "0.1.0" // Used for parsing scala cli directives ) ) .enablePlugins(JavaServerAppPackaging, sbtdocker.DockerPlugin) @@ -295,3 +298,24 @@ lazy val sbtScastie = project ) .settings(version := versionRuntime) .dependsOn(api.jvm(ScalaVersions.sbt)) + +lazy val scliRunner = project + .in(file("scli-runner")) + .settings(baseNoCrossSettings) + .settings(loggingAndTest) + .settings(runnerRuntimeDependenciesInTest) + .settings( + reStart / javaOptions += "-Xmx256m", + Test / parallelExecution := false, + reStart := reStart.dependsOn(runnerRuntimeDependencies: _*).evaluated, + resolvers ++= Resolver.sonatypeOssRepos("public"), + libraryDependencies ++= Seq( + akka("actor"), + akka("testkit") % Test, + akka("cluster"), + akka("slf4j"), + "org.scalameta" %% "scalafmt-core" % "3.6.1", + "ch.epfl.scala" % "bsp4j" % "2.1.0-M3" + ) + ) + .dependsOn(api.jvm(ScalaVersions.jvm), instrumentation, utils) diff --git a/client/src/main/scala/com.olegych.scastie.client/ScastieBackend.scala b/client/src/main/scala/com.olegych.scastie.client/ScastieBackend.scala index 3b503dded..ea24a9900 100644 --- a/client/src/main/scala/com.olegych.scastie.client/ScastieBackend.scala +++ b/client/src/main/scala/com.olegych.scastie.client/ScastieBackend.scala @@ -11,6 +11,7 @@ import japgolly.scalajs.react.util.Effect.Id import org.scalajs.dom.{Position => _, _} import scala.concurrent.Future +import scala.concurrent.duration._ import scala.scalajs.concurrent.JSExecutionContext.Implicits.queue case class ScastieBackend(scastieId: UUID, serverUrl: Option[String], scope: BackendScope[Scastie, ScastieState]) { @@ -25,8 +26,32 @@ case class ScastieBackend(scastieId: UUID, serverUrl: Option[String], scope: Bac Callback(Global.subscribe(scope, scastieId)) } + val reloadStaleMetals: Reusable[Callback] = + Reusable.always { scope.modState({ state => { + previousDirectives = Some(takeDirectives(state.inputs)) + state.copyAndSave(isMetalsStale = false) + } }) } + + private def takeDirectives(inp: Inputs) = inp.code.split("\n").takeWhile(_.startsWith("//>")).toList + private var previousDirectives: Option[List[String]] = None + + val checkIfMetalsStale = scope.modState(state => { + val newDirectives = takeDirectives(state.inputs) + previousDirectives match { + case None => { previousDirectives = Some(newDirectives); state } + case Some(previousDirectives) if previousDirectives != newDirectives => state.copy(isMetalsStale = true) + case _ => state + } + }).async.rateLimit(1.second) + val codeChange: String ~=> Callback = - Reusable.fn(code => scope.modState(_.setCode(code))) + Reusable.fn(code => { + checkIfMetalsStale.runNow() + scope.modState(state => { + val newState = state.setCode(code) + newState + }) + }) val sbtConfigChange: String ~=> Callback = { Reusable.fn(newConfig => scope.modState(_.setSbtConfigExtra(newConfig))) @@ -106,7 +131,7 @@ case class ScastieBackend(scastieId: UUID, serverUrl: Option[String], scope: Bac Reusable.fn(status => scope.modState(_.setMetalsStatus(status))) val toggleMetalsStatus: Reusable[Callback] = - Reusable.always(scope.modState(_.toggleMetalsStatus)) + Reusable.always(scope.modState({ state => previousDirectives = None; state.toggleMetalsStatus })) val toggleLineNumbers: Reusable[Callback] = Reusable.always(scope.modState(_.toggleLineNumbers)) @@ -491,4 +516,8 @@ case class ScastieBackend(scastieId: UUID, serverUrl: Option[String], scope: Bac .map(_.set(Home)) .getOrElse(Callback.empty) ) + + // Convert to Scala-CLI + val convertToScalaCli: Reusable[Callback] = + Reusable.always(scope.modState(_.convertToScalaCli)) } diff --git a/client/src/main/scala/com.olegych.scastie.client/ScastieState.scala b/client/src/main/scala/com.olegych.scastie.client/ScastieState.scala index b36f21e49..1a971ebab 100644 --- a/client/src/main/scala/com.olegych.scastie.client/ScastieState.scala +++ b/client/src/main/scala/com.olegych.scastie.client/ScastieState.scala @@ -4,6 +4,7 @@ import com.olegych.scastie.api._ import org.scalajs.dom.HTMLElement import org.scalajs.dom.{Position => _} import play.api.libs.json._ +import com.olegych.scastie.client.scli.ScalaCliUtils sealed trait MetalsStatus { val info: String @@ -109,8 +110,11 @@ case class ScastieState( outputs: Outputs, status: StatusState, metalsStatus: MetalsStatus = MetalsLoading, + isMetalsStale : Boolean = false, isEmbedded: Boolean = false, transient: Boolean = false, + + scalaCliConversionError: Option[String] = None ) { def snippetId: Option[SnippetId] = snippetState.snippetId def loadSnippet: Boolean = snippetState.loadSnippet @@ -136,7 +140,9 @@ case class ScastieState( outputs: Outputs = outputs, status: StatusState = status, metalsStatus: MetalsStatus = metalsStatus, + isMetalsStale: Boolean = isMetalsStale, transient: Boolean = transient, + scalaCliConversionError: Option[String] = scalaCliConversionError ): ScastieState = { val state0 = copy( @@ -167,6 +173,8 @@ case class ScastieState( metalsStatus = metalsStatus, isEmbedded = isEmbedded, transient = transient, + scalaCliConversionError = scalaCliConversionError, + isMetalsStale = isMetalsStale ) if (!isEmbedded && !transient) { @@ -207,7 +215,7 @@ case class ScastieState( copyAndSave(metalsStatus = status) def toggleMetalsStatus: ScastieState = - copyAndSave(metalsStatus = if (metalsStatus != MetalsDisabled) MetalsDisabled else MetalsLoading) + copyAndSave(metalsStatus = if (metalsStatus != MetalsDisabled) MetalsDisabled else MetalsLoading, isMetalsStale = false) def toggleLineNumbers: ScastieState = copyAndSave(showLineNumbers = !showLineNumbers) @@ -530,5 +538,15 @@ case class ScastieState( ) } + // Returns None if it has been correctly converted. + // Returns a String explaining why the conversion failed. + // + def convertToScalaCli: ScastieState = { + ScalaCliUtils.convertInputsToScalaCli(inputs) match { + case Left(inputs) => copyAndSave(inputs = inputs, scalaCliConversionError = None, view = View.Editor) + case Right(error) => copyAndSave(scalaCliConversionError = Some(error)) + } + } + override def toString: String = Json.toJson(this).toString() } diff --git a/client/src/main/scala/com.olegych.scastie.client/components/BuildSettings.scala b/client/src/main/scala/com.olegych.scastie.client/components/BuildSettings.scala index f43c6572a..7694a53f3 100644 --- a/client/src/main/scala/com.olegych.scastie.client/components/BuildSettings.scala +++ b/client/src/main/scala/com.olegych.scastie.client/components/BuildSettings.scala @@ -5,6 +5,8 @@ import com.olegych.scastie.client.components.editor.SimpleEditor import japgolly.scalajs.react._ import vdom.all._ +import com.olegych.scastie.api.ScalaTarget._ +import japgolly.scalajs.react.feature.ReactFragment final case class BuildSettings( visible: Boolean, @@ -23,7 +25,10 @@ final case class BuildSettings( sbtConfigChange: String ~=> Callback, removeScalaDependency: ScalaDependency ~=> Callback, updateDependencyVersion: (ScalaDependency, String) ~=> Callback, - addScalaDependency: (ScalaDependency, Project) ~=> Callback + addScalaDependency: (ScalaDependency, Project) ~=> Callback, + + convertToScalaCli: Reusable[Callback], + scalaCliConversionError: Option[String] ) { @inline def render: VdomElement = BuildSettings.component(this) @@ -66,6 +71,11 @@ object BuildSettings { scalaTarget = props.scalaTarget ).render + val isScalaCli = props.scalaTarget match { + case _: ScalaCli => true + case _ => false + } + div(cls := "build-settings-container")( resetButton, h2( @@ -75,44 +85,68 @@ object BuildSettings { h2( span("Scala Version") ), - VersionSelector(props.scalaTarget, props.setTarget).render, + VersionSelector(props.scalaTarget, props.setTarget).render.unless(isScalaCli), + p()( + "To use a specific version of Scala with Scala-CLI, use directives. See ", + a(href := "https://scala-cli.virtuslab.org/docs/reference/directives/#scala-version", target := "_blank")("Scala version directive on Scala-CLI documentation"), + "." + ).when(isScalaCli), h2( span("Libraries") ), - scaladexSearch, - h2( - span("Extra Sbt Configuration") - ), - pre(cls := "configuration")( - SimpleEditor( - value = props.sbtConfigExtra, - isDarkTheme = props.isDarkTheme, - readOnly = false, - onChange = props.sbtConfigChange - ).render - ), - h2( - span("Base Sbt Configuration (readonly)") - ), - pre(cls := "configuration")( - SimpleEditor( - value = props.sbtConfig, - isDarkTheme = props.isDarkTheme, - readOnly = true, - onChange = Reusable.always(_ => Callback.empty) - ).render - ), - h2( - span("Base Sbt Plugins Configuration (readonly)") - ), - pre(cls := "configuration")( - SimpleEditor( - value = props.sbtPluginsConfig, - isDarkTheme = props.isDarkTheme, - readOnly = true, - onChange = Reusable.always(_ => Callback.empty) - ).render - ) + scaladexSearch.unless(isScalaCli), + p()( + "To use libraries with Scala-CLI, use directives. See ", + a(href := "https://scala-cli.virtuslab.org/docs/reference/directives#dependency", target := "_blank")("Dependency directive on Scala-CLI documentation"), + "." + ).when(isScalaCli), + ReactFragment( + h2( + span("Extra Sbt Configuration") + ), + pre(cls := "configuration")( + SimpleEditor( + value = props.sbtConfigExtra, + isDarkTheme = props.isDarkTheme, + readOnly = false, + onChange = props.sbtConfigChange + ).render + ), + h2( + span("Base Sbt Configuration (readonly)") + ), + pre(cls := "configuration")( + SimpleEditor( + value = props.sbtConfig, + isDarkTheme = props.isDarkTheme, + readOnly = true, + onChange = Reusable.always(_ => Callback.empty) + ).render + ), + h2( + span("Base Sbt Plugins Configuration (readonly)") + ), + pre(cls := "configuration")( + SimpleEditor( + value = props.sbtPluginsConfig, + isDarkTheme = props.isDarkTheme, + readOnly = true, + onChange = Reusable.always(_ => Callback.empty) + ).render + ), + + h2( + span("Convert to Scala-CLI") + ), + div( + title := "Convert to Scala-CLI", + onClick --> props.convertToScalaCli, + role := "button", + cls := "btn" + )("Convert to Scala-CLI"), + + props.scalaCliConversionError.map(err => p()(s"Failed to convert to Scala-CLI: $err")) + ).when(!isScalaCli) ) } diff --git a/client/src/main/scala/com.olegych.scastie.client/components/EditorTopBar.scala b/client/src/main/scala/com.olegych.scastie.client/components/EditorTopBar.scala index 7fea28cba..605dd438d 100644 --- a/client/src/main/scala/com.olegych.scastie.client/components/EditorTopBar.scala +++ b/client/src/main/scala/com.olegych.scastie.client/components/EditorTopBar.scala @@ -27,6 +27,8 @@ final case class EditorTopBar(clear: Reusable[Callback], view: StateSnapshot[View], isWorksheetMode: Boolean, metalsStatus: MetalsStatus, + isMetalsStale: Boolean, + reloadStaleMetals: Reusable[Callback], toggleMetalsStatus: Reusable[Callback], scalaTarget: ScalaTarget) { @inline def render: VdomElement = EditorTopBar.component(this) @@ -73,6 +75,10 @@ object EditorTopBar { props.view.value ).render + val reloadMetalsButton = ReloadStaleMetals( + props.reloadStaleMetals + ).render.when(props.isMetalsStale) + val metalsButton = MetalsStatusIndicator( props.metalsStatus, props.toggleMetalsStatus, @@ -123,7 +129,8 @@ object EditorTopBar { worksheetButton, downloadButton, embeddedModalButton, - metalsButton, + reloadMetalsButton, + metalsButton ) ) } @@ -134,4 +141,4 @@ object EditorTopBar { .render_P(render) .configure(Reusability.shouldComponentUpdate) .build -} +} \ No newline at end of file diff --git a/client/src/main/scala/com.olegych.scastie.client/components/MainPanel.scala b/client/src/main/scala/com.olegych.scastie.client/components/MainPanel.scala index 8bccde4c6..27b706dcb 100644 --- a/client/src/main/scala/com.olegych.scastie.client/components/MainPanel.scala +++ b/client/src/main/scala/com.olegych.scastie.client/components/MainPanel.scala @@ -81,7 +81,8 @@ object MainPanel { target = state.inputs.target, metalsStatus = state.metalsStatus, setMetalsStatus = backend.setMetalsStatus, - dependencies = state.inputs.libraries + dependencies = state.inputs.libraries, + isMetalsStale = state.isMetalsStale ).render val console = @@ -114,7 +115,10 @@ object MainPanel { sbtConfigChange = backend.sbtConfigChange, removeScalaDependency = backend.removeScalaDependency, updateDependencyVersion = backend.updateDependencyVersion, - addScalaDependency = backend.addScalaDependency + addScalaDependency = backend.addScalaDependency, + + convertToScalaCli = backend.convertToScalaCli, + scalaCliConversionError = state.scalaCliConversionError ).render val mobileBar = @@ -162,6 +166,8 @@ object MainPanel { view = backend.viewSnapshot(state.view), isWorksheetMode = state.inputs.isWorksheetMode, metalsStatus = state.metalsStatus, + isMetalsStale = state.isMetalsStale, + reloadStaleMetals = backend.reloadStaleMetals, toggleMetalsStatus = backend.toggleMetalsStatus, scalaTarget = state.inputs.target ).render.unless(props.isEmbedded || state.isPresentationMode) @@ -214,4 +220,4 @@ object MainPanel { .render_P(render) .configure(Reusability.shouldComponentUpdate) .build -} +} \ No newline at end of file diff --git a/client/src/main/scala/com.olegych.scastie.client/components/TargetSelector.scala b/client/src/main/scala/com.olegych.scastie.client/components/TargetSelector.scala index bca1c389a..4958568bb 100644 --- a/client/src/main/scala/com.olegych.scastie.client/components/TargetSelector.scala +++ b/client/src/main/scala/com.olegych.scastie.client/components/TargetSelector.scala @@ -12,6 +12,7 @@ case class TargetSelector(scalaTarget: ScalaTarget, onChange: ScalaTarget ~=> Ca object TargetSelector { val targetTypes = List[ScalaTargetType]( + ScalaTargetType.ScalaCli, ScalaTargetType.Scala3, ScalaTargetType.Scala2, ScalaTargetType.JS @@ -20,6 +21,7 @@ object TargetSelector { def labelFor(targetType: ScalaTargetType) = { targetType match { + case ScalaTargetType.ScalaCli => "Scala-CLI" case ScalaTargetType.Scala2 => "Scala 2" case ScalaTargetType.JS => "Scala.js" case ScalaTargetType.Scala3 => "Scala 3" diff --git a/client/src/main/scala/com.olegych.scastie.client/components/editor/CodeEditor.scala b/client/src/main/scala/com.olegych.scastie.client/components/editor/CodeEditor.scala index 52d70e43b..2567a08ee 100644 --- a/client/src/main/scala/com.olegych.scastie.client/components/editor/CodeEditor.scala +++ b/client/src/main/scala/com.olegych.scastie.client/components/editor/CodeEditor.scala @@ -45,6 +45,7 @@ final case class CodeEditor(visible: Boolean, codeChange: String ~=> Callback, target: api.ScalaTarget, metalsStatus: MetalsStatus, + isMetalsStale: Boolean, setMetalsStatus: MetalsStatus ~=> Callback, dependencies: Set[api.ScalaDependency]) extends Editor { diff --git a/client/src/main/scala/com.olegych.scastie.client/components/editor/InteractiveProvider.scala b/client/src/main/scala/com.olegych.scastie.client/components/editor/InteractiveProvider.scala index ecfc1f4f6..ba637c703 100644 --- a/client/src/main/scala/com.olegych.scastie.client/components/editor/InteractiveProvider.scala +++ b/client/src/main/scala/com.olegych.scastie.client/components/editor/InteractiveProvider.scala @@ -16,6 +16,7 @@ import hooks.Hooks.UseStateF case class InteractiveProvider( dependencies: Set[api.ScalaDependency], target: api.ScalaTarget, + code: Some[String], metalsStatus: MetalsStatus, updateStatus: MetalsStatus ~=> Callback, isWorksheetMode: Boolean, @@ -34,6 +35,7 @@ object InteractiveProvider { InteractiveProvider( props.dependencies, props.target, + Some(props.value), props.metalsStatus, props.setMetalsStatus, props.isWorksheetMode, @@ -77,6 +79,7 @@ object InteractiveProvider { val extension = InteractiveProvider( props.dependencies, props.target, + Some(props.value), props.metalsStatus, props.setMetalsStatus, props.isWorksheetMode, @@ -91,4 +94,3 @@ object InteractiveProvider { } } - diff --git a/client/src/main/scala/com.olegych.scastie.client/components/editor/MetalsClient.scala b/client/src/main/scala/com.olegych.scastie.client/components/editor/MetalsClient.scala index 59f7e2acf..11d90c44b 100644 --- a/client/src/main/scala/com.olegych.scastie.client/components/editor/MetalsClient.scala +++ b/client/src/main/scala/com.olegych.scastie.client/components/editor/MetalsClient.scala @@ -25,21 +25,27 @@ trait MetalsClient { val target: api.ScalaTarget val isWorksheetMode: Boolean val isEmbedded: Boolean - val scastieMetalsOptions = api.ScastieMetalsOptions(dependencies, target) + val code: Option[String] + var scastieMetalsOptions = api.ScastieMetalsOptions(dependencies, target, code) + private val isConfigurationSupported: Future[Boolean] = { if (metalsStatus == MetalsDisabled || isEmbedded) Future.successful(false) else { updateStatus(MetalsLoading).runNow() val res = makeRequest(scastieMetalsOptions, "isConfigurationSupported").map(maybeText => - parseMetalsResponse[Boolean](maybeText).getOrElse(false) + parseMetalsResponse[api.ScastieMetalsOptions](maybeText).map(Right(_)).getOrElse(Left(maybeText)) ) res.onComplete { - case Success(true) => updateStatus(MetalsReady).runNow() + case Success(Right(opt)) => { + scastieMetalsOptions = opt + updateStatus(MetalsReady).runNow() + } + case Success(Left(details)) => updateStatus(NetworkError(s"Error sent from server: $details")).runNow() case Failure(exception) => updateStatus(NetworkError(exception.getMessage)).runNow() case _ => } - res + res.map { _.isRight } } } diff --git a/client/src/main/scala/com.olegych.scastie.client/components/editor/ReloadStaleMetals.scala b/client/src/main/scala/com.olegych.scastie.client/components/editor/ReloadStaleMetals.scala new file mode 100644 index 000000000..b44872fe9 --- /dev/null +++ b/client/src/main/scala/com.olegych.scastie.client/components/editor/ReloadStaleMetals.scala @@ -0,0 +1,31 @@ +package com.olegych.scastie +package client +package components + +import japgolly.scalajs.react._ +import vdom.all._ + +final case class ReloadStaleMetals( + reload: Reusable[Callback] +) { + @inline def render: VdomElement = ReloadStaleMetals.component(this) +} + +object ReloadStaleMetals { + private def render(props: ReloadStaleMetals): VdomElement = { + li( + role := "button", + title := "Reload metals", + cls := "btn editor reload-metals-btn", + onClick --> props.reload + )( + i(cls := "fa fa-refresh"), + span("Reload metals") + ) + } + + private val component = + ScalaFnComponent + .withHooks[ReloadStaleMetals] + .render(props => ReloadStaleMetals.render(props)) +} \ No newline at end of file diff --git a/client/src/main/scala/com.olegych.scastie.client/scli/ScalaCliUtils.scala b/client/src/main/scala/com.olegych.scastie.client/scli/ScalaCliUtils.scala new file mode 100644 index 000000000..5187da01f --- /dev/null +++ b/client/src/main/scala/com.olegych.scastie.client/scli/ScalaCliUtils.scala @@ -0,0 +1,40 @@ +package com.olegych.scastie.client.scli + +import com.olegych.scastie.api.Inputs +import com.olegych.scastie.api.ScalaTargetType +import com.olegych.scastie.api.ScalaDependency +import com.olegych.scastie.api.ScalaTarget + +object ScalaCliUtils { + def convertInputsToScalaCli(in: Inputs): Either[Inputs, String] = { + val Inputs(_isWorksheetMode, code, target, libraries, librariesFromList, sbtConfigExtra, sbtConfigSaved, sbtPluginsConfigExtra, sbtPluginsConfigSaved, isShowingInUserProfile, forked) = in + + if (sbtConfigExtra.size > 0 && sbtConfigExtra != Inputs.default.sbtConfigExtra) { + Right("Custom SBT config is not supported in Scala-CLI") + } else if (target.targetType == ScalaTargetType.ScalaCli) { + Right("Already a Scala-CLI snippet.") + } else if (target.targetType == ScalaTargetType.Scala2 || target.targetType == ScalaTargetType.Scala3) { + Left( + Inputs.default.copy(_isWorksheetMode = _isWorksheetMode, + code = prependWithDirectives(target.scalaVersion, libraries, code), + target = ScalaTarget.ScalaCli() + ) + ) + } else { + Right(s"Unsupported target ${target.targetType}") + } + } + + private def prependWithDirectives(scalaVersion: String, libraries: Set[ScalaDependency], code: String): String = { + val dependencies = { + if (libraries.size == 0) "" else { + "\n" + libraries.map(dep => s"""//> using dep "${dep.groupId}::${dep.artifact}::${dep.version}"""").mkString("\n") + } + } + + s"""//> using scala "$scalaVersion"$dependencies + |//> ============= + | + |$code""".stripMargin + } +} \ No newline at end of file diff --git a/deployment/production.conf b/deployment/production.conf index bb47b2b01..487d69a78 100644 --- a/deployment/production.conf +++ b/deployment/production.conf @@ -17,6 +17,9 @@ com.olegych.scastie { remote-sbt-ports-start = 5150 remote-sbt-ports-size = 6 + + remote-scli-ports-start = 5250 + remote-scli-ports-size = 4 } web { // this is where web server will be running diff --git a/metals-runner/src/main/scala/scastie/metals/DTOCodecs.scala b/metals-runner/src/main/scala/scastie/metals/DTOCodecs.scala index 17054e229..62a644b41 100644 --- a/metals-runner/src/main/scala/scastie/metals/DTOCodecs.scala +++ b/metals-runner/src/main/scala/scastie/metals/DTOCodecs.scala @@ -4,6 +4,7 @@ import com.olegych.scastie.api._ import io.circe._ import io.circe.generic.semiauto._ import io.circe.syntax._ +import com.olegych.scastie.api.ScalaTarget._ object DTOCodecs { import JavaConverters._ @@ -14,7 +15,7 @@ object DTOCodecs { final def apply(c: HCursor): Decoder.Result[ScalaTarget] = val res = for { tpe <- c.downField("tpe").as[String] - scalaVersion <- c.downField("scalaVersion").as[String] + scalaVersion <- if (tpe == "ScalaCli") then Right("") else c.downField("scalaVersion").as[String] } yield tpe -> scalaVersion res.flatMap((tpe, scalaVersion) => @@ -24,9 +25,27 @@ object DTOCodecs { case "Typelevel" => Right(ScalaTarget.Typelevel(scalaVersion)) case "Native" => c.downField("scalaNativeVersion").as[String].map(ScalaTarget.Native(scalaVersion, _)) case "Scala3" | "Dotty" => Right(ScalaTarget.Scala3(scalaVersion)) + case "ScalaCli" => Right(ScalaTarget.ScalaCli()) } ) + } + + implicit val scalaTargetEncoder: Encoder[ScalaTarget] = new Encoder[ScalaTarget] { + def apply(a: ScalaTarget): Json = { + val supplementaryFields = a match + case Jvm(scalaVersion) => List("tpe" -> "Jvm") + case Typelevel(scalaVersion) => List("tpe" -> "Typelevel") + case Js(scalaVersion, scalaJsVersion) => List("tpe" -> "Js", "scalaJsVersion" -> scalaJsVersion) + case Native(scalaVersion, scalaNativeVersion) => List("tpe" -> "Native", "scalaNativeVersion" -> scalaNativeVersion) + case Scala3(scalaVersion) => List("tpe" -> "Scala3") + case ScalaCli(scalaBinaryVersion0) => List("tpe" -> "ScalaCli") + Json.fromFields( + (List("scalaVersion" -> a.scalaVersion) ++ supplementaryFields).map({ + case (a1, a2) => (a1, Json.fromString(a2)) + }) + ) + } } implicit val scalaDependencyDecoder: Decoder[ScalaDependency] = deriveDecoder @@ -47,10 +66,13 @@ object DTOCodecs { implicit val noResultEncoder: Encoder.AsObject[NoResult] = deriveEncoder[NoResult] implicit val presentationCompilerFailureEncoder: Encoder.AsObject[PresentationCompilerFailure] = deriveEncoder[PresentationCompilerFailure] + implicit val invalidScalaVersionEncoder: Encoder.AsObject[InvalidScalaVersion] = + deriveEncoder[InvalidScalaVersion] def apply(a: FailureType): Json = (a match case noResult: NoResult => noResult.asJsonObject case pcFailure: PresentationCompilerFailure => pcFailure.asJsonObject + case invScalaVersion: InvalidScalaVersion => invScalaVersion.asJsonObject ).+:("_type" -> a.getClass.getCanonicalName.toString.asJson).asJson } diff --git a/metals-runner/src/main/scala/scastie/metals/DTOExtensions.scala b/metals-runner/src/main/scala/scastie/metals/DTOExtensions.scala index e539be574..55c8c0ed8 100644 --- a/metals-runner/src/main/scala/scastie/metals/DTOExtensions.scala +++ b/metals-runner/src/main/scala/scastie/metals/DTOExtensions.scala @@ -17,14 +17,22 @@ object DTOExtensions { val wrapperObject = s"""|object worksheet { |$wrapperIndent""".stripMargin - val contentToOffset = offsetParams.content.take(offsetParams.offset).linesWithSeparators - val line = contentToOffset.size - 1 - val (content, position) = if offsetParams.isWorksheetMode then - val adjustedContent = s"""$wrapperObject${offsetParams.content.replace("\n", "\n" + wrapperIndent)}}""" - val adjustedPosition = wrapperObject.length + line * 2 + offsetParams.offset - (adjustedContent, adjustedPosition) + val (userDirectives, userCode) = offsetParams.content.split("\n").span(_.startsWith("//>")) + + val userDirectivesEndingWithLR = if (userDirectives.size == 0) then "" else userDirectives.mkString("", "\n", "\n") + + val adjustedContent = s"""${userDirectives.mkString("\n")}\n$wrapperObject${userCode.mkString("\n" + wrapperIndent)}}""" + + val userDirectivesLength = userDirectives.map(_.length + 1).sum + if (offsetParams.offset < userDirectivesLength) then + // cursor is in directives + (adjustedContent, offsetParams.offset) + else + // cursor is in code + (adjustedContent, wrapperObject.length + offsetParams.offset + (userCode.length - 1) * wrapperIndent.length + 1) + else (offsetParams.content, offsetParams.offset) new CompilerOffsetParams(noSourceFilePath.toUri, content, position) diff --git a/metals-runner/src/main/scala/scastie/metals/MetalsDispatcher.scala b/metals-runner/src/main/scala/scastie/metals/MetalsDispatcher.scala index 007131da5..de37c1f23 100644 --- a/metals-runner/src/main/scala/scastie/metals/MetalsDispatcher.scala +++ b/metals-runner/src/main/scala/scastie/metals/MetalsDispatcher.scala @@ -21,6 +21,8 @@ import com.olegych.scastie.api.ScalaTarget._ import coursierapi.{Dependency, Fetch} import org.slf4j.LoggerFactory +import scastie.metals.ScalaCliParser + /* * MetalsDispatcher is responsible for managing the lifecycle of presentation compilers. * @@ -45,35 +47,55 @@ class MetalsDispatcher[F[_]: Async](cache: Cache[F, ScastieMetalsOptions, Scasti * @param configuration - scastie client configuration * @returns `EitherT[F, FailureType, ScastiePresentationCompiler]` */ - def getCompiler(configuration: ScastieMetalsOptions): EitherT[F, FailureType, ScastiePresentationCompiler] = EitherT { - if !isSupportedVersion(configuration) then - Async[F].pure( - Left( - PresentationCompilerFailure( - s"Interactive features are not supported for Scala ${configuration.scalaTarget.binaryScalaVersion}." - ) - ) - ) - else - Sync[F].blocking( - mtagsResolver - .resolve(configuration.scalaTarget.scalaVersion) - .toRight( - PresentationCompilerFailure( - s"Mtags couldn't be resolved for target: ${configuration.scalaTarget.scalaVersion}." + def getCompiler(configuration: ScastieMetalsOptions): EitherT[F, FailureType, ScastiePresentationCompiler] = + if (configuration.scalaTarget.targetType == ScalaTargetType.ScalaCli) { + convertConfigurationFromScalaCli(configuration).flatMap(getCompiler(_)) + } else { + EitherT { + if !isSupportedVersion(configuration) then + Async[F].delay( + Left( + PresentationCompilerFailure( + s"Interactive features are not supported for Scala ${configuration.scalaTarget.binaryScalaVersion}." + ) ) ) - ) >>= (_.traverse(mtags => - cache.getOrUpdateReleasable(configuration) { - initializeCompiler(configuration, mtags).map { newPC => - Releasable(newPC, Sync[F].delay(newPC.underlyingPC.shutdown())) - } - } - ).recoverWith { case NonFatal(e) => - logger.error(e.getMessage) - PresentationCompilerFailure(e.getMessage).asLeft.pure[F] - }) - } + else + Sync[F].delay( + mtagsResolver + .resolve(configuration.scalaTarget.scalaVersion) + .toRight( + PresentationCompilerFailure( + s"Mtags couldn't be resolved for target: ${configuration.scalaTarget.scalaVersion}." + ) + ) + ) >>= (_.traverse(mtags => + cache.getOrUpdateReleasable(configuration) { + initializeCompiler(configuration, mtags).map { newPC => + Releasable(newPC, Sync[F].delay(newPC.underlyingPC.shutdown())) + } + } + ).recoverWith { case NonFatal(e) => + logger.error(e.getMessage) + PresentationCompilerFailure(e.getMessage).asLeft.pure[F] + }) + } + + } + + /** + * This converts a configuration that targets Scala-CLI + * It extracts directives and expose a new configuration understandable + * by the runner. + * If it is not a Scala-CLI target, it will return the untouched configuration. + */ + def convertConfigurationFromScalaCli(configuration: ScastieMetalsOptions): EitherT[F, FailureType, ScastieMetalsOptions] = + if (configuration.scalaTarget.targetType == ScalaTargetType.ScalaCli && configuration.code.isDefined) then + val res = Async[F].delay ( ScalaCliParser.getScalaTarget(configuration.code.get) ) + EitherT(res) + else + EitherT.rightT(configuration) + /* * Checks if given configuration is supported. Currently it is based on scala binary version. @@ -88,10 +110,13 @@ class MetalsDispatcher[F[_]: Async](cache: Cache[F, ScastieMetalsOptions, Scasti * In sbt it is automatically resolved but here, we manually specify scala target. */ def areDependenciesSupported(configuration: ScastieMetalsOptions): EitherT[F, FailureType, Boolean] = - def scalaTargetString(scalaTarget: ScalaTarget): String = - s"${scalaTarget.scalaVersion}" ++ (if scalaTarget.targetType == ScalaTargetType.JS then - s" ${scalaTarget.targetType}" - else "") + if (configuration.scalaTarget.targetType == ScalaTargetType.ScalaCli) then + convertConfigurationFromScalaCli(configuration).flatMap(areDependenciesSupported(_)) + else { + def scalaTargetString(scalaTarget: ScalaTarget): String = + s"${scalaTarget.scalaVersion}" ++ (if scalaTarget.targetType == ScalaTargetType.JS then + s" ${scalaTarget.targetType}" + else "") def checkScalaVersionCompatibility(scalaTarget: ScalaTarget): Boolean = SemVer.isCompatibleVersion(scalaTarget.scalaVersion, configuration.scalaTarget.scalaVersion) @@ -102,20 +127,21 @@ class MetalsDispatcher[F[_]: Async](cache: Cache[F, ScastieMetalsOptions, Scasti else scalaTarget.isJVMTarget - val misconfiguredLibraries = configuration.dependencies - .filterNot(l => checkScalaVersionCompatibility(l.target) && checkScalaJsCompatibility(l.target)) + val misconfiguredLibraries = configuration.dependencies + .filterNot(l => checkScalaVersionCompatibility(l.target) && checkScalaJsCompatibility(l.target)) - Option - .when(misconfiguredLibraries.nonEmpty) { - val errorString = misconfiguredLibraries - .map(l => - s"${l.toString} dependency binary version is: ${scalaTargetString(l.target)} while scastie is set to: ${scalaTargetString(configuration.scalaTarget)}" - ) - .mkString("\n") - PresentationCompilerFailure(s"Misconfigured dependencies: $errorString") - } - .toLeft(true) - .toEitherT + Option + .when(misconfiguredLibraries.nonEmpty) { + val errorString = misconfiguredLibraries + .map(l => + s"${l.toString} dependency binary version is: ${scalaTargetString(l.target)} while scastie is set to: ${scalaTargetString(configuration.scalaTarget)}" + ) + .mkString("\n") + PresentationCompilerFailure(s"Misconfigured dependencies: $errorString") + } + .toLeft(true) + .toEitherT + } /* * Initializes the compiler with proper classpath and version diff --git a/metals-runner/src/main/scala/scastie/metals/ScalaCliParser.scala b/metals-runner/src/main/scala/scastie/metals/ScalaCliParser.scala new file mode 100644 index 000000000..5f6fbb3df --- /dev/null +++ b/metals-runner/src/main/scala/scastie/metals/ScalaCliParser.scala @@ -0,0 +1,88 @@ +package scastie.metals + +import scala.jdk.CollectionConverters._ +import com.virtuslab.using_directives.custom.utils.Source +import com.virtuslab.using_directives.config.Settings +import com.virtuslab.using_directives.Context +import com.virtuslab.using_directives.reporter.PersistentReporter +import com.virtuslab.using_directives.custom.Parser +import com.virtuslab.using_directives.custom.utils.ast.UsingDef +import com.virtuslab.using_directives.reporter.ConsoleReporter +import com.virtuslab.using_directives.custom.SimpleCommentExtractor +import com.olegych.scastie.api.ScalaTarget +import com.olegych.scastie.api.FailureType +import com.olegych.scastie.api.InvalidScalaVersion +import com.olegych.scastie.buildinfo.BuildInfo +import com.olegych.scastie.api.ScastieMetalsOptions +import com.virtuslab.using_directives.custom.utils.ast.SettingDefOrUsingValue +import com.virtuslab.using_directives.custom.utils.ast.NumericLiteral +import com.virtuslab.using_directives.custom.utils.ast.StringLiteral +import com.olegych.scastie.api.PresentationCompilerFailure +import com.olegych.scastie.api.ScalaDependency +import com.olegych.scastie.api.ScalaTarget.ScalaCli + +object ScalaCliParser { + + def getScliDirectives(string: String) = + SimpleCommentExtractor(string.toCharArray(), true).extractComments() + + def parse(string: String) = + val source = new Source(getScliDirectives(string)) + val reporter = new PersistentReporter() + val ctx = new Context(reporter) + val parser = new Parser(source, ctx) + + val defs = parser.parse().getUsingDefs().asScala.toList + val allDefs = defs.flatMap(_.getSettingDefs().getSettings().asScala) + + allDefs + + private def extractValue(k: SettingDefOrUsingValue): Option[String] = { + k match + case k: NumericLiteral => Some(k.getValue()) + case k: StringLiteral => Some(k.getValue()) + case _ => None + } + + def getScalaTarget(code: String): Either[FailureType, ScastieMetalsOptions] = { + val defs: Map[String, List[String]] = parse(code).groupMapReduce( + _.getKey() + )( + t => { + val option = extractValue(t.getValue()) + option.toList + } + )(_ ++ _) + + // get the scala version + var scalaVersion = defs.get("scala").getOrElse(List(BuildInfo.latest3)).headOption.getOrElse("3") + + // now we have the scala version + // get the target + val scalaTarget: Either[FailureType, ScalaTarget] = + ScalaTarget.fromScalaVersion(scalaVersion) match + case None => Left(InvalidScalaVersion(s"Invalid Scala version $scalaVersion")) + case Some(target) => Right(target) + + scalaTarget.map { scalaTarget => { + val dependencies = defs.get("dep").getOrElse(List()) ++ defs.get("lib").getOrElse(List()) + + val actualDependencies = dependencies.map(_.split(":").toList).flatMap { + // "groupId::artifact:version" + case List(groupId, "", artifactId, version) => List(ScalaDependency(groupId, artifactId, scalaTarget, version)) + // "groupId:artifact:version" + case List(groupId, artifactId, version) => { + val split = artifactId.split("_") + val scalaLibVersion = split.last + val libname = split.init.mkString("_") + List(ScalaDependency(groupId, libname, ScalaCli(scalaLibVersion), version)) + } + + case _ => List() + } + + ScastieMetalsOptions(actualDependencies.toSet, scalaTarget) + } } + } + +} diff --git a/metals-runner/src/main/scala/scastie/metals/ScastieMetals.scala b/metals-runner/src/main/scala/scastie/metals/ScastieMetals.scala index a89f823b5..e3b457d46 100644 --- a/metals-runner/src/main/scala/scastie/metals/ScastieMetals.scala +++ b/metals-runner/src/main/scala/scastie/metals/ScastieMetals.scala @@ -6,6 +6,7 @@ import cats.effect.Async import cats.syntax.all._ import com.evolutiongaming.scache.Cache import com.olegych.scastie.api._ +import scastie.metals.DTOCodecs._ import org.eclipse.lsp4j._ trait ScastieMetals[F[_]]: @@ -13,7 +14,7 @@ trait ScastieMetals[F[_]]: def completionInfo(request: CompletionInfoRequest): EitherT[F, FailureType, String] def hover(request: LSPRequestDTO): EitherT[F, FailureType, Hover] def signatureHelp(request: LSPRequestDTO): EitherT[F, FailureType, SignatureHelp] - def isConfigurationSupported(config: ScastieMetalsOptions): EitherT[F, FailureType, Boolean] + def isConfigurationSupported(config: ScastieMetalsOptions): EitherT[F, FailureType, ScastieMetalsOptions] object ScastieMetalsImpl: @@ -33,8 +34,12 @@ object ScastieMetalsImpl: def signatureHelp(request: LSPRequestDTO): EitherT[F, FailureType, SignatureHelp] = dispatcher.getCompiler(request.options) >>= (_.signatureHelp(request.offsetParams)) - def isConfigurationSupported(config: ScastieMetalsOptions): EitherT[F, FailureType, Boolean] = - dispatcher.areDependenciesSupported(config) >>= - (_ => dispatcher.getCompiler(config).map(_ => true)) + def isConfigurationSupported(config: ScastieMetalsOptions): EitherT[F, FailureType, ScastieMetalsOptions] = + dispatcher.convertConfigurationFromScalaCli(config) >>= + (config => + dispatcher.areDependenciesSupported(config) >>= + (_ => dispatcher.getCompiler(config).map(_ => config)) + ) + } diff --git a/metals-runner/src/test/scala/scastie/metals/MetalsServerTest.scala b/metals-runner/src/test/scala/scastie/metals/MetalsServerTest.scala index 7385da91d..15d042a4b 100644 --- a/metals-runner/src/test/scala/scastie/metals/MetalsServerTest.scala +++ b/metals-runner/src/test/scala/scastie/metals/MetalsServerTest.scala @@ -368,4 +368,42 @@ class MetalsServerTest extends CatsEffectSuite { expected = Set().asRight ) } + + test("Scala-CLI: Completion with dependency given with `import dep` directives") { + testCompletion( + testTargets = List(ScalaTarget.ScalaCli()), + code = """//> using dep "com.lihaoyi::os-lib:0.9.1" + |object M { + | os.pw@@ + |} + """.stripMargin, + expected = Set("pwd: Path").asRight + ) + } + + test("Scala-CLI: Completion with dependency given with `import lib` directives") { + testCompletion( + testTargets = List(ScalaTarget.ScalaCli()), + code = """//> using lib "com.lihaoyi::os-lib:0.9.1" + |object M { + | os.pw@@ + |} + """.stripMargin, + expected = Set("pwd: Path").asRight + ) + } + + test("Scala-CLI: Hover on a dependency function works") { + testCompletionInfo( + testTargets = List(ScalaTarget.ScalaCli()), + code = """//> using lib "com.lihaoyi::os-lib:0.9.1" + |object M { + | os.pw@@d + |} + """.stripMargin, + expected = List( + "The current working directory for this process.".asRight + ) + ) + } } diff --git a/metals-runner/src/test/scala/scastie/metals/TestUtils.scala b/metals-runner/src/test/scala/scastie/metals/TestUtils.scala index 878d7ae9d..78bc2240a 100644 --- a/metals-runner/src/test/scala/scastie/metals/TestUtils.scala +++ b/metals-runner/src/test/scala/scastie/metals/TestUtils.scala @@ -38,7 +38,7 @@ object TestUtils extends Assertions with CatsEffectAssertions { ): LSPRequestDTO = val offsetParamsComplete = testCode(code) val dependencies0 = dependencies.map(_.apply(scalaTarget)) - LSPRequestDTO(ScastieMetalsOptions(dependencies0, scalaTarget), offsetParamsComplete) + LSPRequestDTO(ScastieMetalsOptions(dependencies0, scalaTarget, code = Some(code)), offsetParamsComplete) def getCompat[A](scalaTarget: ScalaTarget, compat: Map[String, A], default: A): A = val binaryScalaVersion = scalaTarget.binaryScalaVersion diff --git a/sbt-runner/src/main/scala/com.olegych.scastie.sbt/SbtActor.scala b/sbt-runner/src/main/scala/com.olegych.scastie.sbt/SbtActor.scala index 64e875b06..76579bb8a 100644 --- a/sbt-runner/src/main/scala/com.olegych.scastie.sbt/SbtActor.scala +++ b/sbt-runner/src/main/scala/com.olegych.scastie.sbt/SbtActor.scala @@ -21,7 +21,7 @@ class SbtActor(system: ActorSystem, def balancer(context: ActorContext, info: ReconnectInfo): ActorSelection = { import info._ context.actorSelection( - s"akka://Web@$serverHostname:$serverAkkaPort/user/DispatchActor" + s"akka://Web@$serverHostname:$serverAkkaPort/user/DispatchActor/SbtDispatcher" ) } @@ -29,7 +29,7 @@ class SbtActor(system: ActorSystem, if (isProduction) { reconnectInfo.foreach { info => import info._ - balancer(context, info) ! SbtRunnerConnect(actorHostname, actorAkkaPort) + balancer(context, info) ! RunnerConnect(actorHostname, actorAkkaPort) } } } @@ -64,8 +64,8 @@ class SbtActor(system: ActorSystem, ) override def receive: Receive = reconnectBehavior orElse [Any, Unit] { - case SbtPing => { - sender() ! SbtPong + case RunnerPing => { + sender() ! RunnerPong } case format: FormatRequest => { diff --git a/scli-runner/src/main/resources/application.conf b/scli-runner/src/main/resources/application.conf new file mode 100644 index 000000000..65b143f11 --- /dev/null +++ b/scli-runner/src/main/resources/application.conf @@ -0,0 +1,31 @@ +com.olegych.scastie { + sbt { + hostname = "127.0.0.1" + hostname = ${?RUNNER_HOSTNAME} + akka-port = 5250 + akka-port = ${?RUNNER_PORT} + + reconnect = false + reconnect = ${?RUNNER_RECONNECT} + + production = false + production = ${?RUNNER_PRODUCTION} + } +} + +akka { + loggers = ["akka.event.slf4j.Slf4jLogger"] + loglevel = "INFO" + actor { + provider = cluster + warn-about-java-serializer-usage = false + allow-java-serialization = on + } + remote { + artery.canonical { + hostname = ${com.olegych.scastie.sbt.hostname} + port = ${com.olegych.scastie.sbt.akka-port} + } + } +} +akka.remote.artery.advanced.maximum-frame-size = 5 MiB diff --git a/scli-runner/src/main/resources/logback.xml b/scli-runner/src/main/resources/logback.xml new file mode 100644 index 000000000..922666d6e --- /dev/null +++ b/scli-runner/src/main/resources/logback.xml @@ -0,0 +1,15 @@ + + + + System.out + + %d{HH:mm:ss.SSS} %-5level %logger{5}:%line | %msg%n + + + + + + + + + diff --git a/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/BspClient.scala b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/BspClient.scala new file mode 100644 index 000000000..151adf451 --- /dev/null +++ b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/BspClient.scala @@ -0,0 +1,298 @@ +package com.olegych.scastie.sclirunner + +import com.typesafe.scalalogging.Logger +import scala.collection.mutable.{Map, HashMap} +import org.eclipse.lsp4j.jsonrpc.Launcher +import ch.epfl.scala.bsp4j._ +import java.util.Collections +import java.util.concurrent.Executors +import java.util.concurrent.CompletableFuture +import java.io.{InputStream, OutputStream} +import java.nio.file.Path +import scala.concurrent.Future + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.jdk.FutureConverters._ +import scala.jdk.CollectionConverters._ +import scala.jdk.OptionConverters._ +import scala.sys.process.{ Process, ProcessBuilder } +import java.util.Optional +import com.olegych.scastie.api.Problem +import com.olegych.scastie.api.Severity +import com.olegych.scastie.api +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException +import org.eclipse.lsp4j.jsonrpc.messages.CancelParams +import java.net.URI +import java.nio.file.Paths + + +object BspClient { + case class BspRun(process: ProcessBuilder, warnings: List[Diagnostic], logMessages: List[String]) { + def toProblemList: List[Problem] = convertDiagListToProblemList(warnings) + } + + trait BspError extends Exception { + val logs: List[String] = List() + } + case class FailedRunError(err: String, override val logs: List[String] = List()) extends BspError + case class NoTargetsFoundException(err: String, override val logs: List[String] = List()) extends BspError + case class NoMainClassFound(err: String, override val logs: List[String] = List()) extends BspError + case class CompilationError(err: List[Diagnostic], override val logs: List[String] = List()) extends BspError { + def toProblemList: List[Problem] = convertDiagListToProblemList(err) + } + + private def diagSeverityToSeverity(severity: DiagnosticSeverity): Severity = { + if (severity == DiagnosticSeverity.ERROR) api.Error + else if (severity == DiagnosticSeverity.INFORMATION) api.Info + else if (severity == DiagnosticSeverity.HINT) api.Info + else if (severity == DiagnosticSeverity.WARNING) api.Warning + else api.Error + } + + + private def convertDiagListToProblemList(list: List[Diagnostic]) = + list.map(diagnostic => + Problem( + diagSeverityToSeverity(diagnostic.getSeverity()), + Some(diagnostic.getRange().getStart().getLine()), + diagnostic.getMessage() + )) + + + type Callback = String => Any +} + +trait ScalaCliServer extends BuildServer with ScalaBuildServer + with JvmBuildServer + +class BspClient(private val workingDir: Path, + private val inputStream: InputStream, + private val outputStream: OutputStream) { + import BspClient._ + + private val log = Logger("BspClient") + + private val localClient = new InnerClient() + private val es = Executors.newFixedThreadPool(1) + + private val bspLauncher = new Launcher.Builder[ScalaCliServer]() + .setOutput(outputStream) + .setInput(inputStream) + .setLocalService(localClient) + .setExecutorService(es) + .setRemoteInterface(classOf[ScalaCliServer]) + .create() + + private val bspServer = bspLauncher.getRemoteProxy() + + private val listeningThread = new Thread { + override def run() = { + try { + bspLauncher.startListening().get() + } catch { + case _: Throwable => log.info("Listening thread down.") + } + } + } + listeningThread.start() + + def resetInternalBuffers = { + logMessages = List() + } + + def reloadWorkspace = bspServer.workspaceReload().asScala + + def getBuildTargetId: Future[Either[BuildTargetIdentifier, NoTargetsFoundException]] = + bspServer.workspaceBuildTargets().asScala + .map(_.getTargets().stream().filter(t => { + val isTestBuild = t.getTags().stream().anyMatch(tag => tag.equals("test")) + !isTestBuild + }).findFirst()) + .map(_.toScala) + .map( + _.map(target => Left(target.getId())) + .getOrElse(Right(NoTargetsFoundException("No build target found."))) + ) + + def compile(id: String, buildTargetId: BuildTargetIdentifier): Future[Either[CompileResult, CompilationError]] = + bspServer.buildTargetCompile({ + val param = new CompileParams(Collections.singletonList(buildTargetId)) + param.setOriginId(s"$id-compile") + param + }) + .orTimeout(10, TimeUnit.SECONDS) + .asScala + .map(compileResult => { + if (compileResult.getStatusCode() == StatusCode.ERROR) { + log.info(s"Error while compiling $diagnostics.") + Right(CompilationError(diagnostics, logMessages.map(_.getMessage()))) + } else { + Left(compileResult) + } + }) + .recover({ + case _: TimeoutException => + log.warn(s"Compilation timeout on snippet $id") + sys.exit(-1) + case k => throw k + }) + + // Throws either NoMainClassFound or UnexpectedError on unexpected result. + def getMainClass(id: BuildTargetIdentifier): Future[Either[ScalaMainClass, BspError]] = + bspServer.buildTargetScalaMainClasses({ + val param = new ScalaMainClassesParams(Collections.singletonList(id)) + param.setOriginId(s"$id-main-classes") + param + }) + .asScala + .map(_.getItems().asScala.toList) + .map({ + case Nil => Right(NoMainClassFound("No main class found.", logMessages.map(_.getMessage()))) + case item :: _ => item.getClasses.asScala.toList match { + case class_ :: _ => Left(class_) + case _ => Right(NoMainClassFound("No main class found.", logMessages.map(_.getMessage()))) + } + }) + + def getJvmRunEnvironment(id: BuildTargetIdentifier): Future[Either[JvmEnvironmentItem, FailedRunError]] = + bspServer.jvmRunEnvironment(new JvmRunEnvironmentParams(java.util.List.of(id))) + .asScala + .map(_.getItems().asScala.toList) + .map({ + case head :: next => Left(head) + case Nil => Right(FailedRunError("No JvmRunEnvironmentResult available.", logMessages.map(_.getMessage()))) + }) + + def makeProcess(mainClass: String, runSettings: JvmEnvironmentItem) = { + val classpath = runSettings.getClasspath.asScala.map(uri => Paths.get(new URI(uri))).mkString(":") + val envVars = Map( + "CLASSPATH" -> classpath + ) ++ runSettings.getEnvironmentVariables.asScala + + val process = Process( + Seq("java", mainClass) ++ runSettings.getJvmOptions().asScala, + cwd = new java.io.File(runSettings.getWorkingDirectory()), + envVars.toSeq : _* + ) + + process + } + + // Forward Right if res is Right + // or returns the future if Left + def withShortCircuit[T, U](res: Either[T, BspError], f: T => Future[Either[U, BspError]]): Future[Either[U, BspError]] = { + res match { + case Left(value) => f(value) + case Right(k) => Future.successful(Right(k)) + } + } + + // Returns a (T, U) if the two are left + // if any of the two is Right, then returns Right of the first + def combineEither[T, U](a: Either[T, BspError], b: Either[U, BspError]) = { + a match { + case Left(value) => b.fold(bLeft => Left(value,bLeft), Right(_)) + case Right(value) => Right(value) + } + } + + def build(id: String): Future[Either[BspRun, BspError]]= { + resetInternalBuffers + + for ( + r <- reloadWorkspace; + buildTarget <- getBuildTargetId; + + // Compile + compilationResult <- withShortCircuit(buildTarget, target => compile(id, target)); + + // Get main class + // Note: it is combined to compilationResult so if compilationResult fails, + // then we do not continue + mainClass <- withShortCircuit[(BuildTargetIdentifier, CompileResult), ScalaMainClass]( + combineEither(buildTarget, compilationResult), + { + case ((tId: BuildTargetIdentifier, _)) => getMainClass(tId) + } + ); + + // Get JvmRunEnv + jvmRunEnv <- withShortCircuit[(BuildTargetIdentifier, ScalaMainClass), JvmEnvironmentItem]( + combineEither(buildTarget, mainClass), + { + case ((tId: BuildTargetIdentifier, _)) => getJvmRunEnvironment(tId) + } + ) + ) yield { + val ret = combineEither(mainClass, jvmRunEnv) match { + case Left((mainClass, jvmRunEnv)) => { + Left(BspRun( + makeProcess(mainClass.getClassName(), jvmRunEnv), + diagnostics, + logMessages.map(_.getMessage()) + )) + } + case Right(value) => Right(value) + } + ret + } + } + + // Kills the BSP connection and makes this object + // un-usable. + def end = { + bspServer.buildShutdown().get(2, TimeUnit.SECONDS) + bspServer.onBuildExit() + listeningThread.interrupt() // This will stop the thread + } + + def initWorkspace: Unit = { + val r = bspServer.buildInitialize(new InitializeBuildParams( + "BspClient", + "1.0.0", // TODO: maybe not hard-code the version? not really much important + "2.1.0-M4", // TODO: same + workingDir.toAbsolutePath().normalize().toUri().toString(), + new BuildClientCapabilities(Collections.singletonList("scala")) + )).get // Force to wait + log.info(s"Initialized workspace: $r") + bspServer.onBuildInitialized() + } + + initWorkspace + + var diagnostics: List[Diagnostic] = List() + + // Note, log messages is not really useful now but if we want to forward + // execution progress, could be a good idea + var logMessages: List[LogMessageParams] = List() + + class InnerClient extends BuildClient { + def onBuildLogMessage(params: LogMessageParams): Unit = { + logMessages = params :: logMessages + } + def onBuildPublishDiagnostics(params: PublishDiagnosticsParams): Unit = { + if (params.getReset()) + diagnostics = List() + + diagnostics = params.getDiagnostics().asScala.toList ++ diagnostics + } + def onBuildShowMessage(params: ShowMessageParams): Unit = () // log.info(s"ShowMessageParams: $params") + def onBuildTargetDidChange(params: DidChangeBuildTarget): Unit = () // log.info(s"DidChangeBuildTarget: $params") + def onBuildTaskFinish(params: TaskFinishParams): Unit = () // log.info(s"TaskFinishParams: $params") + def onBuildTaskProgress(params: TaskProgressParams): Unit = () // log.info(s"TaskProgressParams: $params") + def onBuildTaskStart(params: TaskStartParams): Unit = () // log.info(s"TaskStartParams: $params") + } + + private def wrapTimeout[T](id: String, cf: CompletableFuture[T]) = { + cf.orTimeout(30, TimeUnit.SECONDS).asScala.recover((throwable => { + throwable match { + case _: TimeoutException => { + // TODO: cancel + throw FailedRunError("Timeout exceeded.") + } + case _ => throw throwable + } + })) + } +} \ No newline at end of file diff --git a/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliActor.scala b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliActor.scala new file mode 100644 index 000000000..b06a5bc4d --- /dev/null +++ b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliActor.scala @@ -0,0 +1,180 @@ +package `com.olegych.scastie.sclirunner` + +import akka.actor.ActorSystem +import akka.actor.ActorRef +import akka.actor.Actor +import akka.actor.ActorLogging +import akka.actor.ActorContext +import com.olegych.scastie.util.ActorReconnecting +import com.olegych.scastie.util.ReconnectInfo +import com.olegych.scastie.util.SbtTask +import com.olegych.scastie.api.SnippetId +import com.olegych.scastie.api.Inputs +import com.olegych.scastie.api.RunnerPing +import com.olegych.scastie.api.RunnerPong +import com.olegych.scastie.api.SnippetProgress +import com.olegych.scastie.api.ProcessOutput +import com.olegych.scastie.api.ProcessOutputType +import com.olegych.scastie.sclirunner.ScliRunner +import com.olegych.scastie.sclirunner.BspClient +import scala.concurrent.ExecutionContext.Implicits.global +import scala.util.Failure +import scala.util.Success + +import java.time.Instant + +import scala.sys.process._ +import java.nio.charset.StandardCharsets +import scala.io.{Source => IOSource} +import com.olegych.scastie.api.Problem +import com.olegych.scastie.api +import play.api.libs.json.Reads +import play.api.libs.json.Json +import scala.util.control.NonFatal +import akka.actor.ActorSelection +import scala.concurrent.duration.FiniteDuration +import com.olegych.scastie.util.ScliActorTask +import akka.util.Timeout +import scala.concurrent.duration._ +import akka.pattern.ask + +object ScliActor { + // States + sealed trait ScliState + + case object Available extends ScliState + case object Running extends ScliState + + def sbtTaskToScliActorTask(sbtTask: SbtTask): ScliActorTask = { + sbtTask match { + case SbtTask(snippetId, inputs, ip, _, progressActor) => + ScliActorTask(snippetId, inputs, ip, progressActor) + } + } +} + + +class ScliActor(system: ActorSystem, + isProduction: Boolean, + runTimeout: FiniteDuration, + readyRef: Option[ActorRef], + override val reconnectInfo: Option[ReconnectInfo]) + extends Actor + with ActorLogging + with ActorReconnecting { + + import ScliActor._ + + // Runner + private val runner: ScliRunner = new ScliRunner + + // Initial state + var currentState: ScliState = Available + override def receive: Receive = whenAvailable + + + // FSM + // Available state (no running scala cli instance) + val whenAvailable: Receive = reconnectBehavior orElse { message => message match { + case task: SbtTask => { + // log.warning("Should not receive an SbtTask, converting to ScliActorTask") + runTask(sbtTaskToScliActorTask(task), sender()) + } + case task: ScliActorTask => runTask(task, sender()) + + case RunnerPing => sender() ! RunnerPong + } } + + + def makeOutput(str: String) = Some(ProcessOutput(str, tpe = ProcessOutputType.StdOut, id = None)) + def makeOutput(str: List[String]): Option[ProcessOutput] = makeOutput(str.mkString("\n")) + + // Run task + def runTask(task: ScliActorTask, author: ActorRef): Unit = { + val ScliActorTask(snipId, inp, ip, progressActor) = task + + val r = runner.runTask(ScliRunner.ScliTask(snipId, inp, ip), runTimeout, output => { + sendProgress(progressActor, author, SnippetProgress.default.copy( + ts = Some(Instant.now.toEpochMilli), + snippetId = Some(snipId), + userOutput = makeOutput(output), + isDone = false + )) + }) + + r.onComplete({ + case Failure(exception) => { + // Unexpected exception + log.error(exception, s"Could not run $snipId") + sendProgress(progressActor, author, buildErrorProgress(snipId, s"Unexpected exception while running: $exception")) + } + case Success(Right(error)) => error match { + // TODO: handle every possible exception + case ScliRunner.InstrumentationException(report) => sendProgress(progressActor, author, report.toProgress(snippetId = snipId)) + case ScliRunner.ErrorFromBsp(x: BspClient.NoTargetsFoundException, logs) => sendProgress(progressActor, author, buildErrorProgress(snipId, x.err, logs)) + case ScliRunner.ErrorFromBsp(x: BspClient.NoMainClassFound, logs) => sendProgress(progressActor, author, buildErrorProgress(snipId, x.err, logs)) + case ScliRunner.ErrorFromBsp(x: BspClient.FailedRunError, logs) => sendProgress(progressActor, author, buildErrorProgress(snipId, x.err, logs)) + case ScliRunner.CompilationError(problems, logs) => { + sendProgress(progressActor, author, SnippetProgress.default.copy( + ts = Some(Instant.now.toEpochMilli), + snippetId = Some(snipId), + compilationInfos = problems, + userOutput = makeOutput(logs), + isDone = true + )) + } + } + case Success(Left(value)) => sendProgress(progressActor, author, SnippetProgress.default.copy( + ts = Some(Instant.now.toEpochMilli), + snippetId = Some(snipId), + isDone = true, + instrumentations = value.instrumentation.getOrElse(List()), + compilationInfos = value.diagnostics + )) + }) + } + + // Progress + private var progressId = 0L + + private def sendProgress(progressActor: ActorRef, author: ActorRef, _p: SnippetProgress): Unit = { + progressId = progressId + 1 + val p: SnippetProgress = _p.copy(id = Some(progressId)) + progressActor ! p + implicit val tm = Timeout(10.seconds) + (author ? p) + .recover { + case e => + log.error(e, s"error while saving progress $p") + } + } + + private def buildErrorProgress(snipId: SnippetId, err: String, logs: List[String] = List()) = { + SnippetProgress.default.copy( + ts = Some(Instant.now.toEpochMilli), + snippetId = Some(snipId), + isDone = true, + compilationInfos = List(Problem(api.Error, line = None, message = err)), + userOutput = Some(ProcessOutput(logs.mkString("\n") + "\n" + err, ProcessOutputType.StdErr, None)) + ) + } + + + // Reconnection + def balancer(context: ActorContext, info: ReconnectInfo): ActorSelection = { + import info._ + context.actorSelection( + s"akka://Web@$serverHostname:$serverAkkaPort/user/DispatchActor/ScliDispatcher" + ) + } + + override def tryConnect(context: ActorContext): Unit = { + if (isProduction) { + reconnectInfo.foreach { info => + import info._ + balancer(context, info) ! api.RunnerConnect(actorHostname, actorAkkaPort) + } + } + } + +} \ No newline at end of file diff --git a/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliMain.scala b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliMain.scala new file mode 100644 index 000000000..43858009e --- /dev/null +++ b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliMain.scala @@ -0,0 +1,79 @@ +package `com.olegych.scastie.sclirunner` + +import com.olegych.scastie.util.ScastieFileUtil.writeRunningPid +import com.olegych.scastie.util.ReconnectInfo + +import akka.actor.{ActorSystem, Props} +import com.typesafe.config.ConfigFactory + +import scala.concurrent.Await +import scala.concurrent.duration._ +import java.util.concurrent.TimeUnit + +import org.slf4j.LoggerFactory + +/** + * This object provides the main endpoint for the Scala-CLI runner. + * Its role is to create and setup the ActorSystem and create the ScalaCli Actor + */ +object SbtMain { + def main(args: Array[String]): Unit = { + val logger = LoggerFactory.getLogger(getClass) + + val system = ActorSystem("ScliRunner") + + val config2 = ConfigFactory.load().getConfig("akka.remote.artery.canonical") + logger.info("akka tcp config") + logger.info(" '" + config2.getString("hostname") + "'") + logger.info(" " + config2.getInt("port")) + + val config = ConfigFactory.load().getConfig("com.olegych.scastie") + + val serverConfig = config.getConfig("web") + val sbtConfig = config.getConfig("sbt") + + val isProduction = true + + // TODO: check if production + // Create appropriate config files + if (isProduction) { + val pid = writeRunningPid("RUNNING_PID") + logger.info(s"Starting scliRunner pid: $pid") + } + + val runTimeout = { + val timeunit = TimeUnit.SECONDS + FiniteDuration( + sbtConfig.getDuration("runTimeout", timeunit), + timeunit + ) + } + + // Reconnect info + val reconnectInfo = ReconnectInfo( + serverHostname = serverConfig.getString("hostname"), + serverAkkaPort = serverConfig.getInt("akka-port"), + actorHostname = sbtConfig.getString("hostname"), + actorAkkaPort = sbtConfig.getInt("akka-port") + ) + + system.actorOf( + Props( + new ScliActor( + system = system, + isProduction = isProduction, + readyRef = None, + runTimeout = runTimeout, + reconnectInfo = Some(reconnectInfo) + ) + ), + name = "ScliActor" + ) + + logger.info("ScliActor started") + + Await.result(system.whenTerminated, Duration.Inf) + + () + } +} diff --git a/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliRunner.scala b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliRunner.scala new file mode 100644 index 000000000..9ff060bab --- /dev/null +++ b/scli-runner/src/main/scala/com.olegych.scastie.sclirunner/ScliRunner.scala @@ -0,0 +1,250 @@ +package com.olegych.scastie.sclirunner + +import com.olegych.scastie.api.{SnippetId, Inputs, ScalaDependency} +import com.olegych.scastie.instrumentation.{InstrumentedInputs, InstrumentationFailureReport} +import com.typesafe.scalalogging.Logger +import java.nio.file.{Files, Path, StandardOpenOption} +import java.util.concurrent.CompletableFuture +import java.io.{InputStream, OutputStream} +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits.global +import scala.sys.process._ +import com.olegych.scastie.instrumentation.InstrumentationFailure +import com.olegych.scastie.api.Problem +import com.olegych.scastie.instrumentation.Instrument +import play.api.libs.json.Reads +import play.api.libs.json.Json +import scala.util.control.NonFatal +import com.olegych.scastie.api.Instrumentation +import scala.concurrent.duration.Duration +import java.util.concurrent.TimeUnit +import scala.concurrent.Await +import scala.util.Try +import scala.util.Success +import scala.util.Failure +import java.util.concurrent.TimeoutException +import scala.collection.concurrent.TrieMap +import com.olegych.scastie.api.ScalaTarget +import com.olegych.scastie.buildinfo.BuildInfo +import scala.concurrent.duration.FiniteDuration + + +object ScliRunner { + case class ScliRun( + output: List[String], + instrumentation: Option[List[Instrumentation]] = None, + diagnostics: List[Problem] = List() + ) + + case class ScliTask(snippetId: SnippetId, inputs: Inputs, ip: String) + + // Errors + abstract class ScliRunnerError extends Exception + case class InvalidScalaVersion(version: String) extends ScliRunnerError + case class InstrumentationException(failure: InstrumentationFailureReport) extends ScliRunnerError + case class CompilationError(problems: List[Problem], logs: List[String] = List()) extends ScliRunnerError + // From Bsp + case class ErrorFromBsp(err: BspClient.BspError, logs: List[String] = List()) extends ScliRunnerError +} + +class ScliRunner { + import ScliRunner._ + + private val log = Logger("ScliRunner") + + // Files + private val workingDir = Files.createTempDirectory("scastie") + private val scalaMain = workingDir.resolve("src/main/scala/main.scala") + + private def initFiles : Unit = { + Files.createDirectories(scalaMain.getParent()) + writeFile(scalaMain, "@main def main = { println(\"hello world!\") } ") + } + + private def writeFile(path: Path, content: String): Unit = { + if (Files.exists(path)) { + Files.delete(path) + } + + Files.write(path, content.getBytes, StandardOpenOption.CREATE_NEW) + + () + } + + def runTask(task: ScliTask, timeout: FiniteDuration, onOutput: String => Any): Future[Either[ScliRun, ScliRunnerError]] = { + log.info(s"Running task with snippetId=${task.snippetId}") + + // Extract directives from user code + val (userDirectives, userCode) = task.inputs.code.split("\n") + .span(line => line.startsWith("//>")) + + var userTarget = userDirectives.map(_.split(" ")).find(_.contains("scala")).map(_.last).getOrElse("3") + // removing quotes + userTarget = userTarget.replaceAll("\"", "") + + val scalaTarget = ScalaTarget.fromScalaVersion(userTarget) + + scalaTarget match { + case None => Future.successful(Right(InvalidScalaVersion(userTarget))) + case Some(scalaTarget) => + // Instrument + InstrumentedInputs(task.inputs.copy(code = userCode.mkString("\n"), target = scalaTarget)) match { + case Left(failure) => Future.failed(InstrumentationException(failure)) + case Right(InstrumentedInputs(inputs, isForcedProgramMode, _)) => + buildAndRun(task.snippetId, inputs, isForcedProgramMode, userDirectives, userCode, timeout, onOutput) + } + } + } + + def buildAndRun( + snippetId: SnippetId, + inputs: Inputs, + isForcedProgramMode: Boolean, + userDirectives: Array[String], + userCode: Array[String], + timeout: FiniteDuration, + onOutput: String => Any + ) + : Future[Either[ScliRun, ScliRunnerError]] = { + val runtimeDependency = inputs.target.runtimeDependency.map(Set(_)).getOrElse(Set()) ++ inputs.libraries + val allDirectives = (runtimeDependency.map(scalaDepToFullName).map(libraryDirective) ++ userDirectives) + val totalLineOffset = -runtimeDependency.size + Instrument.getExceptionLineOffset(inputs) + + val charOffsetInstrumentation = userDirectives.map(_.length() + 1).sum + + val code = allDirectives.mkString("\n") + "\n" + inputs.code + writeFile(scalaMain, code) + + var instrumentationMem: Option[List[Instrumentation]] = None + var outputBuffer: List[String] = List() + + def forwardPrint(str: String) = { + outputBuffer = str :: outputBuffer + onOutput(str) + } + + def mapProblems(list: List[Problem]): List[Problem] = { + list.map(pb => + pb.copy(line = pb.line.map(line => { + if (line <= allDirectives.size) // if issue is on directive, then do not map. + line + else + (line + totalLineOffset + 1) + }) + // removing invalid lines + // NOTE: somehow, BSP can report lines ≤ 0 and are usually + // duplicates of previous reports. + .filter(_ > 0) + ) + ) + } + + def handleError(bspError: BspClient.BspError): ScliRunnerError = bspError match { + case x: BspClient.CompilationError => CompilationError(mapProblems(x.toProblemList), x.logs) + case _ => ErrorFromBsp(bspError, bspError.logs) + } + + // Should be executed asynchronously due to the timeout (executed synchronously) + def runProcess(bspRun: BspClient.BspRun) = { + // print log messages + bspRun.logMessages.foreach(forwardPrint) + + val runProcess = bspRun.process.run( + ProcessLogger({ line: String => { + // extract instrumentation + extract[List[Instrumentation]](line) match { + case None => forwardPrint(line) + case Some(value) => { + instrumentationMem = Some(value.map(inst => inst.copy( + position = inst.position.copy(inst.position.start + charOffsetInstrumentation, inst.position.end + charOffsetInstrumentation) + ))) + } + } + + }}) + ) + javaProcesses.put(snippetId, runProcess) + + // Wait + val f = Future { runProcess.exitValue() } + val didSucceed = + Try(Await.result(f, timeout)) match { + case Success(value) => { + forwardPrint(s"Process exited with error code $value") + true + } + case Failure(_: TimeoutException) => { + forwardPrint("Timeout exceeded.") + false + } + case Failure(e) => { + forwardPrint(s"Unknown exception $e") + false + } + } + + if (!didSucceed) { + runProcess.destroy() + } + javaProcesses.remove(snippetId) + + ScliRun(outputBuffer, instrumentationMem, mapProblems(bspRun.toProblemList)) + } + + val build = bspClient.build(snippetId.base64UUID) + build.map { result => + result match { + case Right(bspError) => Right(handleError(bspError)) + case Left(x: BspClient.BspRun) => Left(runProcess(x)) + } + } + } + + def end: Unit = { + bspClient.end + javaProcesses.values.foreach(_.destroy()) + process.map(_.destroy()) + } + + // Java processes + private val javaProcesses = TrieMap[SnippetId, Process]() // mutable and concurrent HashMap + + // Process streams + private var pStdin: Option[OutputStream] = None + private var pStdout: Option[InputStream] = None + private var pStderr: Option[InputStream] = None + private var process: Option[Process] = None + + // Bsp + private val bspClient = { + log.info(s"Starting Scala-CLI BSP in folder ${workingDir.toAbsolutePath().normalize().toString()}") + val processBuilder: ProcessBuilder = Process(Seq("scala-cli", "bsp", ".", "-deprecation"), workingDir.toFile()) + val io = BasicIO.standard(true) + .withInput(i => pStdin = Some(i)) + .withError(e => pStderr = Some(e)) + .withOutput(o => pStdout = Some(o)) + + process = Some(processBuilder.run(io)) + + // TODO: really bad + while (pStdin.isEmpty || pStdout.isEmpty || pStderr.isEmpty) Thread.sleep(100) + + // Create BSP connection + new BspClient(workingDir, pStdout.get, pStdin.get) + } + + private val runTimeScala = "//> using lib \"org.scastie::runtime-scala\"" + + private def scalaDepToFullName = (dep: ScalaDependency) => s"${dep.groupId}::${dep.artifact}:${dep.version}" + private def libraryDirective = (lib: String) => s"//> using lib \"$lib\"".mkString + + initFiles + + private def extract[T: Reads](line: String) = { + try { + Json.fromJson[T](Json.parse(line)).asOpt + } catch { + case NonFatal(e) => None + } + } +} \ No newline at end of file diff --git a/scli-runner/src/test/resources/application.conf b/scli-runner/src/test/resources/application.conf new file mode 100644 index 000000000..f83d1b46d --- /dev/null +++ b/scli-runner/src/test/resources/application.conf @@ -0,0 +1,6 @@ +com.olegych.scastie { + sbt { + akka-port = 15150 + } +} +akka.actor.provider = akka.actor.LocalActorRefProvider \ No newline at end of file diff --git a/scli-runner/src/test/resources/directive-1.scala b/scli-runner/src/test/resources/directive-1.scala new file mode 100644 index 000000000..4ba926418 --- /dev/null +++ b/scli-runner/src/test/resources/directive-1.scala @@ -0,0 +1,3 @@ +//> using lib "com.lihaoyi::os-lib:0.9.1" + +import os.Source \ No newline at end of file diff --git a/scli-runner/src/test/resources/directive-2.scala b/scli-runner/src/test/resources/directive-2.scala new file mode 100644 index 000000000..018a33d23 --- /dev/null +++ b/scli-runner/src/test/resources/directive-2.scala @@ -0,0 +1,8 @@ +//> using scala "2" +//> using lib "ch.epfl.scala::bsp4s:2.0.0" + +import ch.epfl.scala.bsp.BuildTarget + +BuildTarget + +println("hello") \ No newline at end of file diff --git a/scli-runner/src/test/resources/instrumentation-test.scala b/scli-runner/src/test/resources/instrumentation-test.scala new file mode 100644 index 000000000..572cc0295 --- /dev/null +++ b/scli-runner/src/test/resources/instrumentation-test.scala @@ -0,0 +1,10 @@ +//> using scala "2" +//> blabla +// some annoying comments to be sure… + +123+3 + +// hello, if you read this, you are probably wondering why do i write this? +// because we need tests in our lives. + +42+2 \ No newline at end of file diff --git a/scli-runner/src/test/resources/non-compilable.scala b/scli-runner/src/test/resources/non-compilable.scala new file mode 100644 index 000000000..b10ef8854 --- /dev/null +++ b/scli-runner/src/test/resources/non-compilable.scala @@ -0,0 +1 @@ +@main def main = nonExistingFunction("scala") \ No newline at end of file diff --git a/scli-runner/src/test/resources/normal.scala b/scli-runner/src/test/resources/normal.scala new file mode 100644 index 000000000..c25ffa258 --- /dev/null +++ b/scli-runner/src/test/resources/normal.scala @@ -0,0 +1 @@ +@main def main = println("hello!") \ No newline at end of file diff --git a/scli-runner/src/test/resources/too-long.scala b/scli-runner/src/test/resources/too-long.scala new file mode 100644 index 000000000..a71e711f8 --- /dev/null +++ b/scli-runner/src/test/resources/too-long.scala @@ -0,0 +1 @@ +Thread.sleep(40000) \ No newline at end of file diff --git a/scli-runner/src/test/scala/com.olegych.scastie.sclirunner/ScliRunnerTest.scala b/scli-runner/src/test/scala/com.olegych.scastie.sclirunner/ScliRunnerTest.scala new file mode 100644 index 000000000..622afc0f4 --- /dev/null +++ b/scli-runner/src/test/scala/com.olegych.scastie.sclirunner/ScliRunnerTest.scala @@ -0,0 +1,90 @@ +package com.olegych.scastie.sclirunner + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.BeforeAndAfterAll +import com.olegych.scastie.util.ScastieFileUtil +import java.nio.file.Paths +import com.olegych.scastie.api.SnippetId +import com.olegych.scastie.api.Inputs +import scala.concurrent.Future +import scala.concurrent.duration._ +import com.olegych.scastie.api.Value + +class ScliRunnerTest extends AnyFunSuite with BeforeAndAfterAll { + + var scliRunner: Option[ScliRunner] = None + + override protected def beforeAll(): Unit = { + scliRunner = Some(new ScliRunner) + } + test("forward compile errors") { + TestUtils.shouldNotCompile( + run("non-compilable") + ) + } + + test("should timeout on >30 seconds scripts") { + TestUtils.shouldTimeout( + run("too-long") + ) + } + + test("directives are updated") { + TestUtils.shouldRun( + run("directive-1") + ) + TestUtils.shouldRun( + run("directive-2") + ) + } + + test("instrumentation is correct") { + val r = TestUtils.shouldRun( + run("instrumentation-test") + ) + + assert(r.instrumentation.isDefined) + val content = r.instrumentation.get + assert(content.exists( + p => p.position.start == 193 && p.position.end == 197 + && { p.render match { + case Value("44", "Int") => true + case _ => false + } } + )) + assert(content.exists( + p => p.position.start == 70 && p.position.end == 75 + && { p.render match { + case Value("126", "Int") => true + case _ => false + } } + )) + } + + test("do not instrument if not need") { + val r = TestUtils.shouldRun( + run("normal") + ) + assert(r.output.mkString.contains("hello!")) + } + + override protected def afterAll(): Unit = { + scliRunner.map(_.end) + scliRunner = None + } + + def run(file: String, isWorksheet: Boolean = true, onOutput: String => Any = str => ()): Future[Either[ScliRunner.ScliRun, ScliRunner.ScliRunnerError]] = { + val f = ScastieFileUtil.slurp(Paths.get("scli-runner", "src", "test", "resources", s"$file.scala")) + + if (scliRunner.isEmpty) throw new IllegalStateException("scli-runner is not defined") + if (f.isEmpty) throw new IllegalArgumentException(s"Test file $file does not exist.") + + scliRunner.get.runTask( + ScliRunner.ScliTask( + SnippetId("1", None), + Inputs.default.copy(_isWorksheetMode = isWorksheet, code = f.get), + "1.1.1.1" + ) + , 30.seconds, onOutput) + } +} diff --git a/scli-runner/src/test/scala/com.olegych.scastie.sclirunner/TestUtils.scala b/scli-runner/src/test/scala/com.olegych.scastie.sclirunner/TestUtils.scala new file mode 100644 index 000000000..08dc0764c --- /dev/null +++ b/scli-runner/src/test/scala/com.olegych.scastie.sclirunner/TestUtils.scala @@ -0,0 +1,44 @@ +package com.olegych.scastie.sclirunner + +import scala.jdk.FutureConverters._ +import scala.concurrent.Future +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import java.util.concurrent.TimeUnit +import scala.util.Failure +import scala.util.Success +import scala.util.Try +import com.olegych.scastie.api.Problem + +object TestUtils { + def getResultWithTimeout[T](run: Future[T]) = { + Try(Await.result(run, Duration(35, TimeUnit.SECONDS))) + } + + def shouldNotCompile(run: Future[Either[ScliRunner.ScliRun, ScliRunner.ScliRunnerError]]): List[Problem] = { + val result = getResultWithTimeout(run) + result match { + case Success(Right(ScliRunner.CompilationError(problems, _))) => problems + case _ => throw new AssertionError(s"Expected the code to not compile. Instead, got $result") + } + } + + def shouldOutputString(run: Future[Either[ScliRunner.ScliRun, ScliRunner.ScliRunnerError]], str: String): ScliRunner.ScliRun = { + val result = getResultWithTimeout(run) + result match { + case Success(Left(x @ ScliRunner.ScliRun(output, instrumentations, _))) => { + if (output.exists(_.contains(str))) x + else throw new AssertionError(s"Expected the output to contain at least $str. Contained only $output") + } + case _ => throw new AssertionError(s"Expected the run to have been run. Got $result") + } + } + + def shouldRun(run: Future[Either[ScliRunner.ScliRun, ScliRunner.ScliRunnerError]]) = { + shouldOutputString(run, "") + } + + def shouldTimeout(run: Future[Either[ScliRunner.ScliRun, ScliRunner.ScliRunnerError]]): Unit = { + shouldOutputString(run, "Timeout exceeded.") + } +} diff --git a/utils/src/main/scala/com.olegych.scastie/util/ScliTask.scala b/utils/src/main/scala/com.olegych.scastie/util/ScliTask.scala new file mode 100644 index 000000000..d6fd5067b --- /dev/null +++ b/utils/src/main/scala/com.olegych.scastie/util/ScliTask.scala @@ -0,0 +1,7 @@ +package com.olegych.scastie.util + +import com.olegych.scastie.api._ + +import akka.actor.ActorRef + +case class ScliActorTask(snippetId: SnippetId, inputs: Inputs, ip: String, progressActor: ActorRef) \ No newline at end of file