Skip to content

Commit

Permalink
Support for filter/collect
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Jul 22, 2023
1 parent e992c64 commit 25533a2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
14 changes: 8 additions & 6 deletions core/src/main/scala/ox/channels/Cell.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ private[ox] class Cell[T] extends CellCompleter[T]:
def take(): SelectResult[T] | Cell[T] | ChannelState.Closed = cell.take()
def isAlreadyOwned: Boolean = isOwned.get()

/** Linked cells are created when creating MappedSources. */
private[ox] class LinkedCell[T, U](linkedTo: CellCompleter[U], f: T => U, createReceived: U => Source[U]#Received)
/** Linked cells are created when creating CollectSources. */
private[ox] class LinkedCell[T, U](linkedTo: CellCompleter[U], f: T => Option[U], createReceived: U => Source[U]#Received)
extends CellCompleter[T] {
override def complete(t: SelectResult[T]): Unit =
val u: SelectResult[U] = t match
case r: Source[T]#Received => createReceived(f(r.value)) // TODO exceptions
case _ => throw new IllegalStateException() // linked cells can only be created from sources
linkedTo.complete(u)
t match
case r: Source[T]#Received =>
f(r.value) match // TODO exceptions
case Some(u) => linkedTo.complete(createReceived(u))
case None => linkedTo.completeWithNewCell() // ignoring the received value
case _ => throw new IllegalStateException() // linked cells can only be created from sources
override def completeWithNewCell(): Unit = linkedTo.completeWithNewCell()
override def completeWithClosed(s: ChannelState.Closed): Unit = linkedTo.completeWithClosed(s)
override def tryOwn(): Boolean = linkedTo.tryOwn()
Expand Down
7 changes: 5 additions & 2 deletions core/src/main/scala/ox/channels/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,11 @@ class BufferedChannel[T](capacity: Int = 1) extends Channel[T]:
else c.complete(Received(w)) // sending the element
drainWaitingReceivesWhenDone()

class MappedSource[T, U](s: Source[T], f: T => U) extends Source[U]:
override def receive(): U | ChannelClosed = select(List(s.receiveClause)).map(_.value).map(f)
class CollectSource[T, U](s: Source[T], f: T => Option[U]) extends Source[U]:
@tailrec final override def receive(): U | ChannelClosed = select(List(s.receiveClause)).map(_.value).map(f) match
case Some(u) => u
case None => receive()
case c: ChannelClosed => c
override private[ox] def receiveCellOffer(c: CellCompleter[U]): Unit = s.receiveCellOffer(createLinkedCell(c))
override private[ox] def receiveCellCleanup(c: CellCompleter[U]): Unit = s.receiveCellCleanup(createLinkedCell(c))
override private[ox] def trySatisfyWaiting(): Unit | ChannelClosed = s.trySatisfyWaiting()
Expand Down
10 changes: 8 additions & 2 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ import ox.*
import scala.concurrent.duration.FiniteDuration

trait SourceOps[+T] { this: Source[T] =>
// view ops (lazy)

def mapAsView[U](f: T => U): Source[U] = CollectSource(this, t => Some(f(t)))
def filterAsView(f: T => Boolean): Source[T] = CollectSource(this, t => if f(t) then Some(t) else None)
def collectAsView[U](f: PartialFunction[T, U]): Source[U] = CollectSource(this, f.lift)

// run ops (eager)

def map[U](f: T => U)(using Ox, StageCapacity): Source[U] =
val c2 = Channel[U](summon[StageCapacity].toInt)
fork {
Expand All @@ -24,8 +32,6 @@ trait SourceOps[+T] { this: Source[T] =>
}
c2

def mapView[U](f: T => U): Source[U] = MappedSource(this, f)

def take(n: Int)(using Ox, StageCapacity): Source[T] = transform(_.take(n))

def filter(f: T => Boolean)(using Ox, StageCapacity): Source[T] = transform(_.filter(f))
Expand Down

0 comments on commit 25533a2

Please sign in to comment.