Skip to content

Commit

Permalink
Merge pull request #4 from softwaremill/mapView
Browse files Browse the repository at this point in the history
Map/filter/collect view on sources
  • Loading branch information
adamw authored Jul 22, 2023
2 parents 07d8b38 + 3afdfdf commit d7e914d
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 8 deletions.
34 changes: 30 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ Source.tick(1.second, "x")
Source.iterate(0)(_ + 1) // natural numbers
```

## Transforming sources
## Transforming sources (eagerly)

Sources can be transformed by receiving values, manipulating them and sending to other channels - this provides the
highest flexibility and allows creating arbitrary channel topologies.
Expand All @@ -355,9 +355,13 @@ scoped {
}
```

The `.map` needs to be run within a scope, as it starts a new virtual thread (using `fork`), which received values from
the given source, applies the given function and sends the result to the new channel, which is then returned to the
user.
The `.map` needs to be run within a scope, as it starts a new virtual thread (using `fork`), which:

* immediately starts receiving values from the given source
* applies the given function
* sends the result to the new channel

The new channel is returned to the user as the return value of `.map`.

Some other available combinators include `.filter`, `.take`, `.zip(otherSource)`, `.merge(otherSource)` etc.

Expand Down Expand Up @@ -386,6 +390,28 @@ these channels by default is 0 (unbuffered). This can be overridden by providing
(v: Source[Int]).map(_ + 1)(using StageCapacity(10))
```

## Transforming sources (lazily)

A limited number of transformations can be applied to a source without creating a new channel and a new fork, which
computes the transformation. These include: `.mapAsView`, `.filterAsView` and `.collectAsView`.

For example:

```scala
import ox.scoped
import ox.channels.{Channel, Source}

val c = Channel[String]()
val c2: Source[Int] = c.mapAsView(s => s.length())
```

The mapping function (`s => s.length()`) will only be invoked when the source is consumed (using `.receive()`
or `select`), on the calling thread. This is in contrast to `.map`, where the mapping function is invoked on a separate
fork.

Hence, creating views doesn't need to be run within a scope, and creating the view itself doesn't consume any elements
from the source on which it is run.

## Discharging channels

Values of a source can be terminated using methods such as `.foreach`, `.toList`, `.pipeTo` or `.drain`. These methods
Expand Down
38 changes: 35 additions & 3 deletions core/src/main/scala/ox/channels/Cell.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@ package ox.channels

import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.control.NonFatal

// a lazily-created, optional result - exceptions might be throw when the function is called, hence it should be called
// only on the thread where the value should be received
private[ox] type MaybeCreateResult[T] = () => Option[SelectResult[T]]

private[ox] trait CellCompleter[-T]:
/** Complete the cell with a value. Should only be called if this cell is owned by the calling thread. */
/** Complete the cell with a result. Should only be called if this cell is owned by the calling thread. */
def complete(t: SelectResult[T]): Unit

/** Complete the cell with a lazily-created, optional result. Should only be called if this cell is owned by the calling thread. */
def complete(t: MaybeCreateResult[T]): Unit

/** Complete the cell with a new completer. Should only be called if this cell is owned by the calling thread. */
def completeWithNewCell(): Unit

Expand All @@ -20,14 +28,38 @@ private[ox] trait CellCompleter[-T]:

private[ox] class Cell[T] extends CellCompleter[T]:
private val isOwned = new AtomicBoolean(false)
private val cell = new ArrayBlockingQueue[SelectResult[T] | Cell[T] | ChannelState.Closed](1)
private val cell = new ArrayBlockingQueue[SelectResult[T] | MaybeCreateResult[T] | Cell[T] | ChannelState.Closed](1)

// each cell should be completed exactly once, so we are not using the blocking capabilities of `cell`;
// using `cell.put` might throw an interrupted exception, which might cause a deadlock (as there's a thread awaiting a
// cell's completion on its own interrupt - see cellTakeInterrupted); hence, using `.add`.
override def complete(t: SelectResult[T]): Unit = cell.add(t)
override def complete(t: MaybeCreateResult[T]): Unit = cell.add(t)
override def completeWithNewCell(): Unit = cell.add(Cell[T])
override def completeWithClosed(s: ChannelState.Closed): Unit = cell.add(s)
override def tryOwn(): Boolean = isOwned.compareAndSet(false, true)
def take(): SelectResult[T] | Cell[T] | ChannelState.Closed = cell.take()
def take(): SelectResult[T] | MaybeCreateResult[T] | Cell[T] | ChannelState.Closed = cell.take()
def isAlreadyOwned: Boolean = isOwned.get()

/** Linked cells are created when creating CollectSources. */
private[ox] class LinkedCell[T, U](linkedTo: CellCompleter[U], f: T => Option[U], createReceived: U => Source[U]#Received)
extends CellCompleter[T] {
override def complete(t: SelectResult[T]): Unit =
t match
case r: Source[T]#Received => linkedTo.complete(() => f(r.value).map(createReceived)) // f might throw exceptions, making lazy
case _ => throw new IllegalStateException() // linked cells can only be created from sources
override def complete(t: MaybeCreateResult[T]): Unit =
linkedTo.complete { () =>
t() match
case Some(r: Source[T]#Received) => f(r.value).map(createReceived)
case Some(_) => throw new IllegalStateException() // linked cells can only be created from sources
case _ => None
}
override def completeWithNewCell(): Unit = linkedTo.completeWithNewCell()
override def completeWithClosed(s: ChannelState.Closed): Unit = linkedTo.completeWithClosed(s)
override def tryOwn(): Boolean = linkedTo.tryOwn()

// for Source/Sink cell cleanup
override def equals(obj: Any): Boolean = linkedTo.equals(obj)
override def hashCode(): Int = linkedTo.hashCode()
}
10 changes: 10 additions & 0 deletions core/src/main/scala/ox/channels/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,16 @@ class BufferedChannel[T](capacity: Int = 1) extends Channel[T]:
else c.complete(Received(w)) // sending the element
drainWaitingReceivesWhenDone()

class CollectSource[T, U](s: Source[T], f: T => Option[U]) extends Source[U]:
@tailrec final override def receive(): U | ChannelClosed = select(List(s.receiveClause)).map(_.value).map(f) match
case Some(u) => u
case None => receive()
case c: ChannelClosed => c
override private[ox] def receiveCellOffer(c: CellCompleter[U]): Unit = s.receiveCellOffer(createLinkedCell(c))
override private[ox] def receiveCellCleanup(c: CellCompleter[U]): Unit = s.receiveCellCleanup(createLinkedCell(c))
override private[ox] def trySatisfyWaiting(): Unit | ChannelClosed = s.trySatisfyWaiting()
private def createLinkedCell(c: CellCompleter[U]): CellCompleter[T] = LinkedCell(c, f, u => Received(u))

object Channel:
def apply[T](capacity: Int = 0): Channel[T] = if capacity == 0 then DirectChannel() else BufferedChannel(capacity)

Expand Down
8 changes: 8 additions & 0 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ import ox.*
import scala.concurrent.duration.FiniteDuration

trait SourceOps[+T] { this: Source[T] =>
// view ops (lazy)

def mapAsView[U](f: T => U): Source[U] = CollectSource(this, t => Some(f(t)))
def filterAsView(f: T => Boolean): Source[T] = CollectSource(this, t => if f(t) then Some(t) else None)
def collectAsView[U](f: PartialFunction[T, U]): Source[U] = CollectSource(this, f.lift)

// run ops (eager)

def map[U](f: T => U)(using Ox, StageCapacity): Source[U] =
val c2 = Channel[U](summon[StageCapacity].toInt)
fork {
Expand Down
13 changes: 12 additions & 1 deletion core/src/main/scala/ox/channels/select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,25 @@ private def doSelect[T](clauses: List[SelectClause[T]]): SelectResult[T] | Chann
// completed with a value; interrupting self and returning it
try t
finally Thread.currentThread().interrupt()
case t: MaybeCreateResult[T] @unchecked =>
try
t() match
case Some(r) => r
case None => throw e
finally Thread.currentThread().interrupt()

def takeFromCellInterruptSafe(c: Cell[T]): SelectResult[T] | ChannelClosed =
try
c.take() match
case c2: Cell[T] @unchecked => offerCellAndTake(c2) // we got a new cell on which we should be waiting, add it to the channels
case s: ChannelState.Error => ChannelClosed.Error(s.reason)
case ChannelState.Done => doSelect(clauses)
case t: SelectResult[T] @unchecked => t
case t: SelectResult[T] @unchecked => t
case t: MaybeCreateResult[T] @unchecked =>
// this might throw exceptions, but this is fine - we're on the thread that called select
t() match
case Some(r) => r
case None => doSelect(clauses)
catch case e: InterruptedException => cellTakeInterrupted(c, e)
// now that the cell has been filled, it is owned, and should be removed from the waiting lists of the other channels
finally cleanupCell(c, alsoWhenSingleClause = false)
Expand Down
137 changes: 137 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsAsViewTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package ox.channels

import org.scalatest.concurrent.Eventually
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters.*
import scala.util.{Failure, Try}

class SourceOpsAsViewTest extends AnyFlatSpec with Matchers with Eventually {
it should "map over a source as a view" in {
val c: Channel[Int] = Channel()

scoped {
fork {
c.send(10)
c.send(20)
c.send(30)
c.done()
}

val s2 = c.mapAsView(_ + 1)
s2.receive() shouldBe 11
s2.receive() shouldBe 21
s2.receive() shouldBe 31
s2.receive() shouldBe ChannelClosed.Done
}
}

it should "select from sources mapped as view" in {
val c1: Channel[Int] = Channel()
val c2: Channel[Int] = Channel()

scoped {
fork {
c1.send(10)
c1.send(20)
c1.send(30)
c1.done()
}

fork {
c2.send(100)
c2.send(200)
c2.send(300)
c2.done()
}

val s1 = c1.mapAsView(_ + 1)
val s2 = c2.mapAsView(_ + 1)

(for (_ <- 1 to 7) yield select(s1.receiveClause, s2.receiveClause).map(_.value)).toSet shouldBe Set(
101,
201,
301,
11,
21,
31,
ChannelClosed.Done
)
}
}

it should "filter over a source as a view" in {
val c: Channel[Int] = Channel()

scoped {
fork {
c.send(1)
c.send(2)
c.send(3)
c.send(4)
c.done()
}

val s2 = c.filterAsView(_ % 2 == 0)
s2.receive() shouldBe 2
s2.receive() shouldBe 4
s2.receive() shouldBe ChannelClosed.Done
}
}

it should "select from sources filtered as a view" in {
val c1: Channel[Int] = Channel()
val c2: Channel[Int] = Channel()

scoped {
fork {
c1.send(1)
c1.send(2)
c1.send(3)
c1.send(4)
c1.done()
}

fork {
c2.send(11)
c2.send(12)
c2.send(13)
c2.send(14)
c2.done()
}

val s1 = c1.filterAsView(_ % 2 == 0)
val s2 = c2.filterAsView(_ % 2 == 0)

(for (_ <- 1 to 5) yield select(s1.receiveClause, s2.receiveClause).map(_.value)).toSet shouldBe Set(2, 4, 12, 14, ChannelClosed.Done)
}
}

it should "propagate exceptions to the calling select" in {
val c: Channel[Int] = Channel()

scoped {
fork {
c.send(1)
c.send(2)
c.send(3)
c.send(4)
c.done()
}

val c1 = Channel(); c1.done()
val s2 = c.filterAsView(v => if v % 2 == 0 then true else throw new RuntimeException("test"))

Try(select(c1.receiveClause, s2.receiveClause)) should matchPattern { case Failure(e) if e.getMessage == "test" => }
select(c1.receiveClause, s2.receiveClause).map(_.value) shouldBe 2
Try(select(c1.receiveClause, s2.receiveClause)) should matchPattern { case Failure(e) if e.getMessage == "test" => }
select(c1.receiveClause, s2.receiveClause).map(_.value) shouldBe 4
select(c1.receiveClause, s2.receiveClause) shouldBe ChannelClosed.Done
}
}
}

0 comments on commit d7e914d

Please sign in to comment.