diff --git a/shared/src/main/scala/async/AsyncOperations.scala b/shared/src/main/scala/async/AsyncOperations.scala index 0eb5c23..ab7228d 100644 --- a/shared/src/main/scala/async/AsyncOperations.scala +++ b/shared/src/main/scala/async/AsyncOperations.scala @@ -35,7 +35,7 @@ object AsyncOperations: * [[java.util.concurrent.TimeoutException]] is thrown. */ def withTimeout[T](timeout: FiniteDuration)(op: Async ?=> T)(using AsyncOperations, Async): T = - Async.group: + Async.group: spawn ?=> Async.select( Future(op).handle(_.get), Future(sleep(timeout)).handle: _ => diff --git a/shared/src/main/scala/async/futures.scala b/shared/src/main/scala/async/futures.scala index e8a4d85..ea9b8e7 100644 --- a/shared/src/main/scala/async/futures.scala +++ b/shared/src/main/scala/async/futures.scala @@ -1,17 +1,17 @@ package gears.async +import language.experimental.captureChecking + import java.util.concurrent.CancellationException import java.util.concurrent.atomic.AtomicBoolean import scala.annotation.tailrec -import scala.annotation.unchecked.uncheckedCaptures import scala.annotation.unchecked.uncheckedVariance import scala.collection.mutable import scala.compiletime.uninitialized import scala.util import scala.util.control.NonFatal import scala.util.{Failure, Success, Try} - -import language.experimental.captureChecking +import gears.async.Async.SourceSymbol /** Futures are [[Async.Source Source]]s that has the following properties: * - They represent a single value: Once resolved, [[Async.await await]]-ing on a [[Future]] should always return the @@ -51,10 +51,11 @@ object Future: * - withResolver: Completion is done by external request set up from a block of code. */ private class CoreFuture[+T] extends Future[T]: + @volatile protected var hasCompleted: Boolean = false protected var cancelRequest = AtomicBoolean(false) private var result: Try[T] = uninitialized // guaranteed to be set if hasCompleted = true - private val waiting = mutable.Set[(Listener[Try[T]]^) @uncheckedCaptures]() + private val waiting: mutable.Set[Listener[Try[T]]^] = mutable.Set() // Async.Source method implementations @@ -107,49 +108,11 @@ object Future: end CoreFuture - private class CancelSuspension[U](val src: Async.Source[U]^)(val ac: Async, val suspension: ac.support.Suspension[Try[U], Unit]) extends Cancellable: - self: CancelSuspension[U]^{src, ac} => - val listener: Listener[U]^{ac} = Listener.acceptingListener[U]: (x, _) => - val completedBefore = complete() - if !completedBefore then - ac.support.resumeAsync(suspension)(Success(x)) - unlink() - var completed = false - - def complete() = synchronized: - val completedBefore = completed - completed = true - completedBefore - - override def cancel() = - val completedBefore = complete() - if !completedBefore then - src.dropListener(listener) - ac.support.resumeAsync(suspension)(Failure(new CancellationException())) - - private class FutureAsync(val group: CompletionGroup)(using ac: Async, label: ac.support.Label[Unit]) extends Async(using ac.support): - /** Await a source first by polling it, and, if that fails, by suspending in a onComplete call. - */ - override def await[U](src: Async.Source[U]^): U = - if group.isCancelled then throw new CancellationException() - src - .poll() - .getOrElse: - val res = ac.support.suspend[Try[U], Unit](k => - val cancellable: CancelSuspension[U]^{src, ac} = CancelSuspension(src)(ac, k) - // val listener: Listener[U] = Listener.acceptingListener[U]: (x, _) => ??? - // val completedBefore = cancellable.complete() - // if !completedBefore then ac.support.resumeAsync(k)(Success(x)) - cancellable.link(group) // may resume + remove listener immediately - src.onComplete(cancellable.listener) - )(using label) - res.get - - override def withGroup(group: CompletionGroup): Async = FutureAsync(group) - /** A future that is completed by evaluating `body` as a separate asynchronous operation in the given `scheduler` */ private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T]: + private given acSupport: ac.support.type = ac.support + private given acScheduler: ac.support.Scheduler = ac.scheduler /** RunnableFuture maintains its own inner [[CompletionGroup]], that is separated from the provided Async * instance's. When the future is cancelled, we only cancel this CompletionGroup. This effectively means any * `.await` operations within the future is cancelled *only if they link into this group*. The future body run with @@ -160,6 +123,47 @@ object Future: private def checkCancellation(): Unit = if cancelRequest.get() then throw new CancellationException() + private class FutureAsync(val group: CompletionGroup)(using label: acSupport.Label[Unit]) + extends Async(using acSupport, acScheduler): + /** Await a source first by polling it, and, if that fails, by suspending in a onComplete call. + */ + override def await[U](src: Async.Source[U]^): U = + class CancelSuspension extends Cancellable: + var suspension: acSupport.Suspension[Try[U], Unit] = uninitialized + var listener: Listener[U]^{this} = uninitialized + var completed = false + + def complete() = synchronized: + val completedBefore = completed + completed = true + completedBefore + + override def cancel() = + val completedBefore = complete() + if !completedBefore then + src.dropListener(listener) + acSupport.resumeAsync(suspension)(Failure(new CancellationException())) + + if group.isCancelled then throw new CancellationException() + + src + .poll() + .getOrElse: + val cancellable = CancelSuspension() + val res = acSupport.suspend[Try[U], Unit](k => + val listener = Listener.acceptingListener[U]: (x, _) => + val completedBefore = cancellable.complete() + if !completedBefore then acSupport.resumeAsync(k)(Success(x)) + cancellable.suspension = k + cancellable.listener = listener + cancellable.link(group) // may resume + remove listener immediately + src.onComplete(listener) + ) + cancellable.unlink() + res.get + + override def withGroup(group: CompletionGroup): Async = FutureAsync(group) + override def cancel(): Unit = if setCancelled() then this.innerGroup.cancel() link() @@ -178,8 +182,9 @@ object Future: /** Create a future that asynchronously executes `body` that wraps its execution in a [[scala.util.Try]]. The returned * future is linked to the given [[Async.Spawn]] scope by default, i.e. it is cancelled when this scope ends. */ - def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn) - (using async.type =:= spawnable.type): Future[T]^{body, spawnable} = + def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn)( + using async.type =:= spawnable.type + ): Future[T]^{body, spawnable} = RunnableFuture(body)(using spawnable) /** A future that is immediately completed with the given result. */ @@ -197,11 +202,11 @@ object Future: /** A future that immediately rejects with the given exception. Similar to `Future.now(Failure(exception))`. */ inline def rejected(exception: Throwable): Future[Nothing] = now(Failure(exception)) - extension [T](f1: Future[T]^) + extension [T](f1: Future[T]) /** Parallel composition of two futures. If both futures succeed, succeed with their values in a pair. Otherwise, * fail with the failure that was returned first. */ - def zip[U](f2: Future[U]^): Future[(T, U)]^{f1, f2} = + def zip[U](f2: Future[U]): Future[(T, U)] = Future.withResolver: r => Async .either(f1, f2) @@ -234,20 +239,20 @@ object Future: * @see * [[orWithCancel]] for an alternative version where the slower future is cancelled. */ - def or(f2: Future[T]^): Future[T]^{f1, f2} = orImpl(false)(f2) + def or(f2: Future[T]): Future[T] = orImpl(false)(f2) /** Like `or` but the slower future is cancelled. If either task succeeds, succeed with the success that was * returned first and the other is cancelled. Otherwise, fail with the failure that was returned last. */ - def orWithCancel(f2: Future[T]^): Future[T]^{f1, f2} = orImpl(true)(f2) + def orWithCancel(f2: Future[T]): Future[T] = orImpl(true)(f2) - inline def orImpl(inline withCancel: Boolean)(f2: Future[T]^): Future[T]^{f1, f2} = Future.withResolver: r => + inline def orImpl(inline withCancel: Boolean)(f2: Future[T]): Future[T] = Future.withResolver: r => Async .raceWithOrigin(f1, f2) .onComplete(Listener { case ((v, which), _) => v match case Success(value) => - inline if withCancel then (if which == f1.symbol then f2 else f1).cancel() + inline if withCancel then (if which == f1 then f2 else f1).cancel() r.resolve(value) case Failure(_) => (if which == f1.symbol then f2 else f1).onComplete(Listener((v, _) => r.complete(v))) @@ -300,7 +305,7 @@ object Future: * may be used. The handler should eventually complete the Future using one of complete/resolve/reject*. The * default handler is set up to [[rejectAsCancelled]] immediately. */ - def onCancel(handler: () -> Unit): Unit + def onCancel(handler: () => Unit): Unit end Resolver /** Create a promise that may be completed asynchronously using external means. @@ -310,16 +315,16 @@ object Future: * * If the external operation supports cancellation, the body can register one handler using [[Resolver.onCancel]]. */ - def withResolver[T](body: Resolver[T]^ => Unit): Future[T] = - val future = new CoreFuture[T] with Resolver[T] with Promise[T] { - @volatile var cancelHandle: (() -> Unit) = () => rejectAsCancelled() - override def onCancel(handler: () -> Unit): Unit = cancelHandle = handler + def withResolver[T](body: Resolver[T] => Unit): Future[T] = + val future = new CoreFuture[T] with Resolver[T] with Promise[T]: + @volatile var cancelHandle: () -> Unit = () => rejectAsCancelled() + override def onCancel(handler: () => Unit): Unit = cancelHandle = caps.unsafe.unsafeAssumePure(handler) override def complete(result: Try[T]): Unit = super.complete(result) override def cancel(): Unit = if setCancelled() then cancelHandle() - } - body(future) + end future + body(future: Resolver[T]) future end withResolver @@ -338,51 +343,46 @@ object Future: * [[Future.awaitAll]] and [[Future.awaitFirst]] for simple usage of the collectors to get all results or the first * succeeding one. */ - class Collector[T](val futures: (Future[T]^)*): + class Collector[T](futures: (Future[T]^)*): private val ch = UnboundedChannel[Future[T]^{futures*}]() - private val futureRefs = mutable.Map[Async.SourceSymbol[Try[T]], Future[T]^{futures*}]() + private val futMap = mutable.Map[SourceSymbol[Try[T]], Future[T]^{futures*}]() /** Output channels of all finished futures. */ final def results: ReadableChannel[Future[T]^{futures*}] = ch.asReadable - private val listener = Listener((_, futRef) => + private val listener = Listener((_, fut) => // safe, as we only attach this listener to Future[T] - val ref = futRef.asInstanceOf[Async.SourceSymbol[Try[T]]] - val fut = futureRefs.synchronized: - // futureRefs.remove(ref).get - futureRefs(ref) - ch.sendImmediately(futureRefs(fut.symbol)) + val future = futMap.synchronized: + futMap.remove(fut.asInstanceOf[SourceSymbol[Try[T]]]).get + ch.sendImmediately(future) ) protected final def addFuture(future: Future[T]^{futures*}) = - futureRefs.synchronized: - futureRefs += (future.symbol -> future) + futMap.synchronized { futMap += (future.symbol -> future) } future.onComplete(listener) futures.foreach(addFuture) end Collector /** Like [[Collector]], but exposes the ability to add futures after creation. */ - class MutableCollector[T](futures: (Future[T]^)*) extends Collector[T](futures*): + class MutableCollector[T](futures: Future[T]*) extends Collector[T](futures*): /** Add a new [[Future]] into the collector. */ - def add(future: Future[T]^{futures*}): Unit = addFuture(future) - def +=(future: Future[T]^{futures*}) = add(future) + inline def add(future: Future[T]^) = addFuture(future) + inline def +=(future: Future[T]^) = add(future) - extension [T](@caps.unbox fs: Seq[Future[T]^]) + extension [T](fs: Seq[Future[T]]) /** `.await` for all futures in the sequence, returns the results in a sequence, or throws if any futures fail. */ def awaitAll(using Async) = val collector = Collector(fs*) - for _ <- fs do - val fut: Future[T]^{fs*} = collector.results.read().right.get - fut.await + for _ <- fs do collector.results.read().right.get.await fs.map(_.await) /** Like [[awaitAll]], but cancels all futures as soon as one of them fails. */ def awaitAllOrCancel(using Async) = - val collector = Collector[T](fs*) + val collector = Collector(fs*) try - for _ <- fs do ??? // collector.results.read().right.get.await + for _ <- fs do collector.results.read().right.get.await fs.map(_.await) catch case NonFatal(e) => @@ -391,22 +391,20 @@ object Future: /** Race all futures, returning the first successful value. Throws the last exception received, if everything fails. */ - def awaitFirst(using Async): T = impl.awaitFirstImpl[T](fs, false) + def awaitFirst(using Async): T = awaitFirstImpl(false) /** Like [[awaitFirst]], but cancels all other futures as soon as the first future succeeds. */ - def awaitFirstWithCancel(using Async): T = impl.awaitFirstImpl[T](fs, true) + def awaitFirstWithCancel(using Async): T = awaitFirstImpl(true) - private object impl: - def awaitFirstImpl[T](@caps.unbox fs: Seq[Future[T]^], withCancel: Boolean)(using Async): T = - val collector = Collector[T](fs*) + private inline def awaitFirstImpl(withCancel: Boolean)(using Async): T = + val collector = Collector(fs*) @scala.annotation.tailrec def loop(attempt: Int): T = - val fut: Future[T]^{fs*} = collector.results.read().right.get - fut.awaitResult match + collector.results.read().right.get.awaitResult match case Failure(exception) => if attempt == fs.length then /* everything failed */ throw exception else loop(attempt + 1) case Success(value) => - if withCancel then fs.foreach(_.cancel()) + inline if withCancel then fs.foreach(_.cancel()) value loop(1) end Future @@ -432,11 +430,10 @@ class Task[+T](val body: (Async, AsyncOperations) ?=> T): def run()(using Async, AsyncOperations): T = body /** Start a future computed from the `body` of this task */ - def start()(using async: Async, spawn: Async.Spawn, asyncOps: AsyncOperations) - (using async.type =:= spawn.type): Future[T]^{this, spawn} = + def start()(using async: Async, spawn: Async.Spawn)(using asyncOps: AsyncOperations)(using async.type =:= spawn.type): Future[T]^{body, spawn} = Future(body)(using async, spawn) - def schedule(s: TaskSchedule): Task[T]^{this} = + def schedule(s: TaskSchedule): Task[T]^{body} = s match { case TaskSchedule.Every(millis, maxRepetitions) => assert(millis >= 1)