Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flatten to Source[T] #198

Merged
merged 7 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,39 @@ trait SourceOps[+T] { outer: Source[T] =>
}
c

/** Flattens a source of sources: pipes the elements of child sources into the output source. If the parent source or any of the child
* sources emit an error, the pulling stops and the output source emits the error.
*
* The output source is completed once the parent source and every child source received so far are done, that is, once no source is
* left to pull from.
*
* @tparam U
*   the type of elements emitted by the child sources; evidence that `T` is a `Source[U]` must be available.
* @return
*   a new source, created using `StageCapacity.newChannel`, to which the elements of the child sources are sent.
*/
def flatten[U](using Ox, StageCapacity, T <:< Source[U]): Source[U] = {
// the output channel, returned to the caller at the end of the method
val c2 = StageCapacity.newChannel[U]

forkPropagate(c2) {
// the sources currently being selected from: initially only the parent; child sources are prepended as they are received
var pool = List[Source[T] | Source[U]](this)
// every branch below evaluates to a Boolean: `false` after the output has been completed/failed, `true` to keep pulling
// (presumably `repeatWhile` stops looping on the first `false` - confirm against its definition)
repeatWhile {
selectOrClosed(pool) match {
Copy link
Contributor Author

@nimatrueway nimatrueway Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact that we can not know the result of select stems from which select statement is limiting. In Golang there's a hacky way to do it.
https://stackoverflow.com/a/19992525/1556045

I think both (v: V) or (c: ChannelClosed) results if they hint that which select/source they stem from would make the API more robust.

Update: created a ticket for it as Adam suggested #201

case ChannelClosed.Done =>
// The result does not say which source is done, so the whole pool is scanned and every source reporting `Done` is dropped.
// TODO: best to remove the specific channel that signalled to be Done
pool = pool.filterNot(_.isClosedForReceiveDetail.contains(ChannelClosed.Done))
// nothing left to pull from: complete the output and stop; otherwise keep pulling from the remaining sources
if pool.isEmpty then
c2.doneOrClosed()
false
else true
case ChannelClosed.Error(e) =>
// an error from the parent or from any child is forwarded to the output, and pulling stops
c2.errorOrClosed(e)
false
// NOTE(review): the type test below is unchecked, so any received value that is a `Source` at runtime is added to the pool,
// including an element of a child source when `U` itself is a source type.
// TODO: we might go too deep and pull from non immediate children of the parent source
case t: Source[U] @unchecked =>
// a new child source received from the parent: start selecting from it as well
pool = t :: pool
true
case r: U @unchecked =>
// an element of a child source: forward it; keep looping only if the send result `isValue`
// (presumably not the case once the output channel is closed - confirm)
c2.sendOrClosed(r).isValue
}
}
}

c2
}

/** Concatenates this source with the `other` source. The resulting source will emit elements from this source first, and then from the
* `other` source.
*
Expand Down
165 changes: 165 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsFlattenTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package ox.channels

import org.scalatest.OptionValues
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

import java.util.concurrent.CountDownLatch
import scala.collection.mutable.ListBuffer

/** Tests for `flatten` on a source of sources: element piping (including empty and singleton parents), piping while the parent and child
* sources are still being produced, error propagation from a child and from the parent, and behavior when the consumer stops receiving.
*/
class SourceOpsFlattenTest extends AnyFlatSpec with Matchers with OptionValues {

// Order across child sources is not asserted here (`theSameElementsAs`), only the set of emitted elements.
"flatten" should "pipe all elements of the child sources into the output source" in {
supervised {
val source = Source.fromValues(
Source.fromValues(10),
Source.fromValues(20, 30),
Source.fromValues(40, 50, 60)
)
source.flatten.toList should contain theSameElementsAs List(10, 20, 30, 40, 50, 60)
}
}

// A parent source without any child sources should produce an empty output.
it should "handle empty source" in {
supervised {
val source = Source.empty[Source[Int]]
source.flatten.toList should contain theSameElementsAs Nil
}
}

// A parent source with exactly one single-element child source.
it should "handle singleton source" in {
supervised {
val source = Source.fromValues(Source.fromValues(10))
source.flatten.toList should contain theSameElementsAs List(10)
}
}

// Elements and child sources are produced only after earlier elements have been consumed from the flattened output; the latches
// enforce this ordering between the consumer and the two producers.
it should "pipe elements realtime" in {
supervised {
val source = Channel.bufferedDefault[Source[Int]]
// released by the consumer once 20 has been received
val lockA = CountDownLatch(1)
// released by the consumer once 30 has been received
val lockB = CountDownLatch(1)
source.send(Source.fromValues(10))
source.send {
val subSource = Channel.bufferedDefault[Int]
subSource.send(20)
forkUnsupervised {
lockA.await() // 30 is not sent until lockA is released, which happens after 20 has been consumed
subSource.send(30)
subSource.done()
}
subSource
}
forkUnsupervised {
lockB.await() // 40 is not sent until lockB is released, which happens after 30 has been consumed
source.send(Source.fromValues(40))
source.done()
}

val collected = ListBuffer[Int]()
source.flatten.foreachOrError { e =>
collected += e
if e == 20 then lockA.countDown()
else if e == 30 then lockB.countDown()
}
collected should contain theSameElementsAs List(10, 20, 30, 40)
}
}

// child2 fails after emitting 20; the flattened output should then fail with the same error, and 30 should remain unconsumed in child1.
it should "propagate error of any of the child sources and stop piping" in {
supervised {
val child1 = Channel.rendezvous[Int]
val lock = CountDownLatch(1)
fork {
child1.send(10)
// wait for child2 to emit an error
lock.await()
// `flatten` will not receive this, as it will be short-circuited by the error
child1.sendOrClosed(30)
}
val child2 = Channel.rendezvous[Int]
fork {
child2.send(20)
child2.error(new Exception("intentional failure"))
lock.countDown()
}
val source = Source.fromValues(child1, child2)

val (collectedElems, collectedError) = source.flatten.toPartialList()
collectedError.value.getMessage shouldBe "intentional failure"
collectedElems should contain theSameElementsAs List(10, 20)
// 30 was not pulled by `flatten`, so it can still be received directly from child1
child1.receive() shouldBe 30
}
}

// The parent source fails after the first element of its only child has been sent; the flattened output should fail with that error.
it should "propagate error of the parent source and stop piping" in {
supervised {
val child1 = Channel.rendezvous[Int]
val lock = CountDownLatch(1)
fork {
child1.send(10)
lock.countDown()
// depending on how quickly `flatten` picks up the error from the parent,
// it may or may not receive this element
child1.send(20)
child1.done()
}
val source = Channel.rendezvous[Source[Int]]
fork {
source.send(child1)
// make sure the first element of child1 is consumed before emitting error
lock.await()
source.error(new Exception("intentional failure"))
}

val (collectedElems, collectedError) = source.flatten.toPartialList()
collectedError.value.getMessage shouldBe "intentional failure"
// whether 20 is seen depends on the race described above, hence the relaxed assertion
collectedElems should contain atLeastOneElementOf List(10, 20)
}
}

// The consumer receives a single element and then leaves the `supervised` scope; the remaining elements of child1 should not all be
// drained by `flatten`.
it should "stop pulling from the sources when the receiver is closed" in {
val child1 = Channel.rendezvous[Int]

// the producer runs on a plain virtual thread, outside of the `supervised` scope below
Thread.startVirtualThread(() => {
child1.send(10)
// at this point the `flatten` output channel is closed,
// so although the `flatten` thread receives the element 20,
// it cannot push it to its output channel, and the element is lost
child1.send(20)
child1.send(30)
child1.done()
})

supervised {
val source = Source.fromValues(child1)
val flattenSource = {
// capacity 0 is passed to `flatten` only; presumably this makes its output channel unbuffered - confirm
implicit val capacity: StageCapacity = StageCapacity(0)
source.flatten
}
flattenSource.receive() shouldBe 10
}

// NOTE(review): this expects that 20 has already been taken from child1 by `flatten` before the scope ended - confirm that this
// cannot race with the scope shutdown
child1.receiveOrClosed() shouldBe 30
child1.receiveOrClosed() shouldBe ChannelClosed.Done
}

extension [T](source: Source[T]) {
/** Drains `source`, collecting the received elements until the source completes or a `ChannelClosedException.Error` is thrown.
*
* @param cb
*   callback invoked with each received element, and with the error, if any; does nothing by default.
* @return
*   the elements received (before the error, if one occurred), and the error, if one was caught.
*/
def toPartialList(cb: T | Throwable => Unit = (_: Any) => ()): (List[T], Option[Throwable]) = {
val elementCapture = ListBuffer[T]()
var errorCapture = Option.empty[Throwable]
try {
for (t <- source) {
cb(t)
elementCapture += t
}
} catch {
// only the error case is handled; any other exception propagates to the test
case ChannelClosedException.Error(e) =>
cb(e)
errorCapture = Some(e)
}
(elementCapture.toList, errorCapture)
}
}
}