Skip to content

Commit

Permalink
Simplify forks
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Jul 24, 2023
1 parent 71e8615 commit 233fd7e
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 104 deletions.
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,8 @@ Scopes can be arbitrarily nested.

### Error handling

Any unhandled exceptions that are thrown in a `fork` block are propagated to the scope's main thread, by interrupting
it and re-throwing the exception there. Hence, any failed fork will cause the entire scope's computation to be
interrupted, and unless interruptions are intercepted, all other forks will get interrupted as well.

On the other hand, `forkHold` doesn't propagate any exceptions but retains them. The result of the fork must be
explicitly inspected to discover, if the computation failed or succeeded, e.g. using the `Fork.join` method.
If a fork fails with an exception, the `Fork.join` method will throw that exception. If there's no join and the fork
fails, the exception might go unnoticed.

## Scoped values

Expand Down
1 change: 0 additions & 1 deletion core/src/main/scala/ox/Ox.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import scala.util.control.NoStackTrace
case class Ox(
scope: StructuredTaskScope[Any],
scopeThread: Thread,
forkFailureToPropagate: AtomicReference[Throwable],
finalizers: AtomicReference[List[() => Unit]]
):
private[ox] def addFinalizer(f: () => Unit): Unit = finalizers.updateAndGet(f :: _)
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/ox/control.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def retry[T](times: Int, sleep: FiniteDuration)(f: => T): T =

def uninterruptible[T](f: => T): T =
scoped {
val t = forkHold(f)
val t = fork(f)

def joinDespiteInterrupted: T =
try t.join()
Expand Down
27 changes: 5 additions & 22 deletions core/src/main/scala/ox/fork.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,9 @@ import scala.util.Try

/** Starts a thread, which is guaranteed to complete before the enclosing [[scoped]] block exits.
*
* Exceptions are propagated. In case an exception is thrown while evaluating `t`, the enclosing scope's main thread is interrupted and the
* exception is re-thrown there.
* In case an exception is thrown while evaluating `t`, it will be thrown when calling the returned [[Fork]]'s `.join()` method.
*/
def fork[T](f: => T)(using Ox): Fork[T] = forkHold {
try f
catch
// not propagating interrupts, as these are not failures coming from evaluating `f` itself
case e: InterruptedException => throw e
case e: Throwable =>
val old = summon[Ox].forkFailureToPropagate.getAndSet(e) // TODO: only the last failure is propagated
if (old == null) summon[Ox].scopeThread.interrupt()
throw e
}

/** Starts a thread, which is guaranteed to complete before the enclosing [[scoped]] block exits.
*
* Exceptions are held. In case an exception is thrown while evaluating `t`, it will be thrown when calling the returned [[Fork]]'s
* `.join()` method. The exception is **not** propagated to the enclosing scope's main thread, like in the case of [[fork]].
*/
def forkHold[T](f: => T)(using Ox): Fork[T] =
def fork[T](f: => T)(using Ox): Fork[T] =
val result = new CompletableFuture[T]()
val forkFuture = summon[Ox].scope.fork { () =>
try result.complete(f)
Expand All @@ -43,8 +26,8 @@ def forkHold[T](f: => T)(using Ox): Fork[T] =
case e: ExecutionException => Left(e.getCause)
case e: Throwable => Left(e)

def forkAllHold[T](fs: Seq[() => T])(using Ox): Fork[Seq[T]] =
val forks = fs.map(f => forkHold(f()))
def forkAll[T](fs: Seq[() => T])(using Ox): Fork[Seq[T]] =
val forks = fs.map(f => fork(f()))
new Fork[Seq[T]]:
override def join(): Seq[T] = forks.map(_.join())
override def cancel(): Either[Throwable, Seq[T]] =
Expand All @@ -53,7 +36,7 @@ def forkAllHold[T](fs: Seq[() => T])(using Ox): Fork[Seq[T]] =
then Left(results.collectFirst { case Left(e) => e }.get)
else Right(results.collect { case Right(t) => t })

/** A running fork, started using [[fork]] or [[forkHold]], backend by a thread. */
/** A running fork, started using [[fork]] or [[fork]], backend by a thread. */
trait Fork[T]:
/** Blocks until the fork completes with a result. Throws an exception, if the fork completed with an exception. */
def join(): T
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/ox/race.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def timeout[T](duration: FiniteDuration)(t: => T): T =
def raceSuccess[T](fs: Seq[() => T]): T =
scoped {
val result = new ArrayBlockingQueue[Try[T]](fs.size)
fs.foreach(f => forkHold(result.put(Try(f()))))
fs.foreach(f => fork(result.put(Try(f()))))

@tailrec
def takeUntilSuccess(firstException: Option[Throwable], left: Int): T =
Expand Down
17 changes: 3 additions & 14 deletions core/src/main/scala/ox/scoped.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,8 @@ import java.util.concurrent.atomic.AtomicReference

private class DoNothingScope[T] extends StructuredTaskScope[T](null, Thread.ofVirtual().factory()) {}

/** Any child forks are interrupted after `f` completes. */
/** Any child forks are interrupted after `f` completes. The method only completes when all child forks have completed. */
def scoped[T](f: Ox ?=> T): T =
val forkFailure = new AtomicReference[Throwable]()

// only propagating if the main scope thread was interrupted (presumably because of a supervised child fork failing)
def handleInterrupted(e: InterruptedException) = forkFailure.get() match
case null => throw e
case t =>
t.addSuppressed(e)
throw t

def throwWithSuppressed(es: List[Throwable]): Nothing =
val e = es.head
es.tail.foreach(e.addSuppressed)
Expand All @@ -43,13 +34,11 @@ def scoped[T](f: Ox ?=> T): T =
try
val t =
try
try f(using Ox(scope, Thread.currentThread(), forkFailure, finalizers))
catch case e: InterruptedException => handleInterrupted(e)
try f(using Ox(scope, Thread.currentThread(), finalizers))
finally
scope.shutdown()
scope.join()
// .join might have been interrupted, because of a fork failing after f completes, including shutdown
catch case e: InterruptedException => handleInterrupted(e)
// join might have been interrupted
finally scope.close()

// running the finalizers only once we are sure that all child threads have been terminated, so that no new
Expand Down
1 change: 0 additions & 1 deletion core/src/main/scala/ox/syntax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ object syntax:
def retry(times: Int, sleep: FiniteDuration): T = ox.retry(times, sleep)(f)

extension [T](f: => T)(using Ox)
def forkHold: Fork[T] = ox.forkHold(f)
def fork: Fork[T] = ox.fork(f)
def timeout(duration: FiniteDuration): T = ox.timeout(duration)(f)
def scopedWhere[U](fl: ForkLocal[U], u: U): T = fl.scopedWhere(u)(f)
Expand Down
61 changes: 10 additions & 51 deletions core/src/test/scala/ox/ForkTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ import scala.concurrent.duration.DurationInt
class ForkTest extends AnyFlatSpec with Matchers {
class CustomException extends RuntimeException

"forkHold" should "run two forks concurrently" in {
"fork" should "run two forks concurrently" in {
val trail = Trail()
scoped {
val f1 = forkHold {
val f1 = fork {
Thread.sleep(500)
trail.add("f1 complete")
5
}
val f2 = forkHold {
val f2 = fork {
Thread.sleep(1000)
trail.add("f2 complete")
6
Expand All @@ -37,8 +37,8 @@ class ForkTest extends AnyFlatSpec with Matchers {
it should "allow nested forks" in {
val trail = Trail()
scoped {
val f1 = forkHold {
val f2 = forkHold {
val f1 = fork {
val f2 = fork {
Thread.sleep(1000)
trail.add("f2 complete")
6
Expand All @@ -65,12 +65,12 @@ class ForkTest extends AnyFlatSpec with Matchers {
Thread.sleep(1000)
trail.add("f2 complete")
6
}.forkHold
}.fork

Thread.sleep(500)
trail.add("f1 complete")
5 + f2.join()
}.forkHold
}.fork

trail.add("main mid")
trail.add(s"result = ${f1.join()}")
Expand All @@ -82,8 +82,8 @@ class ForkTest extends AnyFlatSpec with Matchers {
it should "interrupt child forks when parents complete" in {
val trail = Trail()
scoped {
val f1 = forkHold {
forkHold {
val f1 = fork {
fork {
try
Thread.sleep(1000)
trail.add("f2 complete")
Expand All @@ -108,50 +108,9 @@ class ForkTest extends AnyFlatSpec with Matchers {

it should "throw the exception thrown by a joined fork" in {
val trail = Trail()
try scoped(forkHold(throw new CustomException()).join())
catch case e: Exception => trail.add(e.getClass.getSimpleName)

trail.trail shouldBe Vector("CustomException")
}

"fork" should "propagate failures to the scope thread" in {
val trail = Trail()
try
scoped {
val f1 = fork {
Thread.sleep(2000)
trail.add("f1 done")
}

val f2 = fork {
Thread.sleep(1000)
throw new CustomException
}

f1.join()
f2.join()
}
try scoped(fork(throw new CustomException()).join())
catch case e: Exception => trail.add(e.getClass.getSimpleName)

trail.trail shouldBe Vector("CustomException")
}

it should "not propagate interrupt exceptions" in {
val trail = Trail()
try
scoped {
fork {
Thread.sleep(2000)
trail.add("f1 done")
}

trail.add("main done")
}
catch case e: Exception => trail.add(e.getClass.getSimpleName)

// child should be interrupted, but the error shouldn't propagate
trail.trail shouldBe Vector("main done")
}
}


10 changes: 5 additions & 5 deletions core/src/test/scala/ox/LocalTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ class LocalTest extends AnyFlatSpec with Matchers {
val trail = Trail()
val v = ForkLocal("a")
scoped {
val f1 = forkHold {
val f1 = fork {
v.scopedWhere("x") {
Thread.sleep(100L)
trail.add(s"In f1 = ${v.get()}")
}
v.get()
}

val f3 = forkHold {
val f3 = fork {
v.scopedWhere("z") {
Thread.sleep(100L)
forkHold {
fork {
Thread.sleep(100L)
trail.add(s"In f3 = ${v.get()}")
}.join()
Expand All @@ -47,12 +47,12 @@ class LocalTest extends AnyFlatSpec with Matchers {
val trail = Trail()
val v = ForkLocal("a")
scoped {
forkHold {
fork {
v.scopedWhere("x") {
trail.add(s"nested1 = ${v.get()}")

scoped {
forkHold {
fork {
trail.add(s"nested2 = ${v.get()}")
}.join()
}
Expand Down
34 changes: 34 additions & 0 deletions core/src/test/scala/ox/RaceTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,38 @@ class RaceTest extends AnyFlatSpec with Matchers {

trail.trail shouldBe Vector("no timeout", "done")
}

it should "race a slower and faster computation" in {
val trail = Trail()
val start = System.currentTimeMillis()
raceSuccess {
Thread.sleep(1000L)
trail.add("slow")
} {
Thread.sleep(500L)
trail.add("fast")
}
val end = System.currentTimeMillis()

Thread.sleep(1000L)
trail.trail shouldBe Vector("fast")
end - start should be < 1000L
}

it should "race a faster and slower computation" in {
val trail = Trail()
val start = System.currentTimeMillis()
raceSuccess {
Thread.sleep(500L)
trail.add("fast")
} {
Thread.sleep(1000L)
trail.add("slow")
}
val end = System.currentTimeMillis()

Thread.sleep(1000L)
trail.trail shouldBe Vector("fast")
end - start should be < 1000L
}
}
4 changes: 2 additions & 2 deletions examples/src/test/scala/ox/main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import org.slf4j.LoggerFactory
@main def test1 =
val log = LoggerFactory.getLogger("test1")
val r = scoped {
val f1 = forkHold {
val f1 = fork {
Thread.sleep(1000L)
log.info("f1 done")
5
}
val f2 = forkHold {
val f2 = fork {
Thread.sleep(2000L)
log.info("f2 done")
6
Expand Down

0 comments on commit 233fd7e

Please sign in to comment.