Skip to content

Commit

Permalink
Introduce an UnsupervisedFork trait, make .joinEither available only …
Browse files Browse the repository at this point in the history
…for unsupervised forks (#178)
  • Loading branch information
adamw authored Jul 11, 2024
1 parent c61beae commit 6328a87
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 41 deletions.
42 changes: 22 additions & 20 deletions core/src/main/scala/ox/fork.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def forkError[E, F[_], T](using OxError[E, F])(f: => F[T]): Fork[T] =
// completing the result; any joins will end up being interrupted
if !supervisor.forkException(e) then result.completeExceptionally(e).discard
}
newForkUsingResult(result)
new ForkUsingResult(result) {}

/** Starts a fork (logical thread of execution), which is guaranteed to complete before the enclosing [[supervised]] or [[supervisedError]]
* block completes.
Expand Down Expand Up @@ -94,7 +94,7 @@ def forkUserError[E, F[_], T](using OxError[E, F])(f: => F[T]): Fork[T] =
case e: Throwable =>
if !supervisor.forkException(e) then result.completeExceptionally(e).discard
}
newForkUsingResult(result)
new ForkUsingResult(result) {}

/** Starts a fork (logical thread of execution), which is guaranteed to complete before the enclosing [[supervised]], [[supervisedError]] or
* [[unsupervised]] block completes.
Expand All @@ -105,13 +105,13 @@ def forkUserError[E, F[_], T](using OxError[E, F])(f: => F[T]): Fork[T] =
*
* For alternate behaviors, see [[fork]], [[forkUser]] and [[forkCancellable]].
*/
def forkUnsupervised[T](f: => T)(using OxUnsupervised): Fork[T] =
def forkUnsupervised[T](f: => T)(using OxUnsupervised): UnsupervisedFork[T] =
val result = new CompletableFuture[T]()
summon[OxUnsupervised].scope.fork { () =>
try result.complete(f)
catch case e: Throwable => result.completeExceptionally(e)
}
newForkUsingResult(result)
new ForkUsingResult(result) with UnsupervisedFork[T] {}

/** For each thunk in the given sequence, starts a fork using [[fork]]. All forks are guaranteed to complete before the enclosing
* [[supervised]] or [[unsupervised]] block completes.
Expand Down Expand Up @@ -159,9 +159,7 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] =
done.acquire()
}
}
new CancellableFork[T]:
override def join(): T = unwrapExecutionException(result.get())

new ForkUsingResult(result) with CancellableFork[T]:
override def cancel(): Either[Throwable, T] =
cancelNow()
try Right(result.get())
Expand All @@ -178,10 +176,7 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] =
if !started.getAndSet(true)
then result.completeExceptionally(new InterruptedException("fork was cancelled before it started")).discard

override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)

private def newForkUsingResult[T](result: CompletableFuture[T]): Fork[T] = new Fork[T]:
private trait ForkUsingResult[T](result: CompletableFuture[T]) extends Fork[T]:
override def join(): T = unwrapExecutionException(result.get())
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)
Expand All @@ -207,9 +202,18 @@ trait Fork[T]:
*/
def join(): T

/** Blocks until the fork completes with a result. Only makes sense in for unsupervised forks, that is when the fork is started using
* [[forkUnsupervised]] or [[forkCancellable]]; otherwise a thrown exception causes the scope to end, and is re-thrown by the
* [[supervised]] block.
private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean

object Fork:
/** A dummy pretending to represent a fork which successfully completed with the given value. */
def successful[T](value: T): Fork[T] = UnsupervisedFork.successful(value)

/** A dummy pretending to represent a fork which failed with the given exception. */
def failed[T](e: Throwable): Fork[T] = UnsupervisedFork.failed(e)

trait UnsupervisedFork[T] extends Fork[T]:
/** Blocks until the fork completes with a result. If the fork failed with an exception, this exception is not thrown, but returned as a
* `Left`.
*
* @throws InterruptedException
* If the join is interrupted.
Expand All @@ -222,21 +226,19 @@ trait Fork[T]:
case e: InterruptedException => if wasInterruptedWith(e) then Left(e) else throw e
case NonFatal(e) => Left(e)

private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean

object Fork:
object UnsupervisedFork:
/** A dummy pretending to represent a fork which successfully completed with the given value. */
def successful[T](value: T): Fork[T] = new Fork[T]:
def successful[T](value: T): UnsupervisedFork[T] = new UnsupervisedFork[T]:
override def join(): T = value
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = false

/** A dummy pretending to represent a fork which failed with the given exception. */
def failed[T](e: Throwable): Fork[T] = new Fork[T]:
def failed[T](e: Throwable): UnsupervisedFork[T] = new UnsupervisedFork[T]:
override def join(): T = throw e
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = e eq ie

/** A fork started using [[forkCancellable]], backed by a (virtual) thread. */
trait CancellableFork[T] extends Fork[T]:
trait CancellableFork[T] extends UnsupervisedFork[T]:
/** Interrupts the fork, and blocks until it completes with a result. */
def cancel(): Either[Throwable, T]

Expand Down
38 changes: 20 additions & 18 deletions core/src/test/scala/ox/SupervisedTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,27 +127,29 @@ class SupervisedTest extends AnyFlatSpec with Matchers {
it should "handle interruption of multiple forks with `joinEither` correctly" in {
val e = intercept[Exception] {
supervised {
def computation(withException: Option[String]): Int = {
withException match
case None => 1
case Some(value) =>
throw new Exception(value)
}

val fork1 = fork:
computation(withException = None)
val fork2 = fork:
computation(withException = Some("Oh no!"))
val fork3 = fork:
computation(withException = Some("Oh well.."))

fork1.joinEither() // 1
fork2.joinEither() // 2
fork3.joinEither() // 3
// first, starting a fork which will sleep in the background, and which is unsupervised, so that we can .joinEither()
val f1 = forkUnsupervised:
sleep(1.second)
10

// forking a supervised fork, which throws an exception and causes the scope to end
val f2 = fork:
throw new Exception("oh no!")

// this joinEither() might be interrupted because the scope ends, or it might obtain the interrupted exception,
// because `f` is interrupted as well
f1.joinEither()

// either the previous operation should throw an IE, or the thread should become interrupted while joining
f2.join()

// getting here means that we managed to catch the IE for the main body
fail("scope body should be interrupted")
}
}

e.getMessage should startWith("Oh")
// the exception that caused the scope to end should be re-thrown
e.getMessage shouldBe "oh no!"
}

}
2 changes: 1 addition & 1 deletion core/src/test/scala/ox/resilience/BackoffRetryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.scalatest.matchers.should.Matchers
import org.scalatest.{EitherValues, TryValues}
import ox.ElapsedTime
import ox.resilience.*
import ox.scheduling.{Jitter, Schedule}
import ox.scheduling.Jitter

import scala.concurrent.duration.*

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package ox.scheduling
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.{EitherValues, TryValues}
import ox.{ElapsedTime, sleep}
import ox.ElapsedTime

import scala.concurrent.duration.*

Expand Down
1 change: 0 additions & 1 deletion core/src/test/scala/ox/scheduling/JitterTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package ox.scheduling
import org.scalatest.Inspectors
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.scheduling.{Jitter, Schedule}

import scala.concurrent.duration.*

Expand Down

0 comments on commit 6328a87

Please sign in to comment.