Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for mergePreferred #3229

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 218 additions & 0 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,224 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
): Stream[F2, O2] =
that.mergeHaltL(this)

/** Merges two streams with priority given to the first stream.
*
* Internally, this uses two bounded queues (of one element).
* Queue has tryTake() which allows to try a non-blocking read.
* This is used to check for elements on the prioritized queue,
* before a blocking read through racePair() is tried on both
* queues, if no data is available on the prioritized queue.
*/
def mergePreferred[F2[x] >: F[x], O2 >: O](
that: Stream[F2, O2]
)(implicit F: Concurrent[F2]): Stream[F2, O2] = {
val fstream: F2[Stream[F2, O2]] =
for {
interrupt <- F.deferred[Unit]
resultL <- F.deferred[Either[Throwable, Unit]]
resultR <- F.deferred[Either[Throwable, Unit]]
resultQL <- Queue.bounded[F2, Option[Stream[F2, O2]]](1)
resultQR <- Queue.bounded[F2, Option[Stream[F2, O2]]](1)
} yield {

def watchInterrupted(str: Stream[F2, O2]): Stream[F2, O2] =
str.interruptWhen(interrupt.get.attempt)

// action to signal that one stream is finished (by putting a None in it)
def doneAndClose(q: Queue[F2, Option[Stream[F2, O2]]]): F2[Unit] = q.offer(None).void

// action to interrupt the processing of both streams by completing interrupt
val signalInterruption: F2[Unit] = interrupt.complete(()).void

// Read from a stream and (possibly blocking) write to the bounded queue for that stream
def go(s: Stream[F2, O2], q: Queue[F2, Option[Stream[F2, O2]]]): Pull[F2, Nothing, Unit] =
s.pull.uncons
.flatMap {
case Some((hd, tl)) =>
val send = q.offer(Some(Stream.chunk(hd)))
Pull.eval(send) >> go(tl, q)
case None =>
Pull.done
}

def runStream(
s: Stream[F2, O2],
whenDone: Deferred[F2, Either[Throwable, Unit]],
q: Queue[F2, Option[Stream[F2, O2]]]
): F2[Unit] = {
val str = watchInterrupted(go(s, q).stream)
str.compile.drain.attempt
.flatMap {
// signal completion of our side before we will signal interruption,
// to make sure our result is always available to others
case r @ Left(_) =>
whenDone.complete(r) >> signalInterruption
case r @ Right(_) =>
whenDone.complete(r) >> doneAndClose(q)
}
}

// Typedef for the fibres that read from the queues.
// That's contained in the Either returned by racePair()
type FBR = Fiber[F2, Throwable, Option[Stream[F2, O2]]]

// An ADT for tracking state of the two queues.
// The types describe the state, starting with BothActive.
// Next state is either LeftDone or RightDone.
// Final state is BothDone.
// The members of those states store the loosing fibre
// of a racePair()-call, which will be reused during the
// next read.
sealed trait QueuesState
final case class BothActive(v: Option[Either[FBR, FBR]]) extends QueuesState
final case class LeftDone(rFbr: Option[FBR]) extends QueuesState
final case class RightDone(lFbr: Option[FBR]) extends QueuesState
case object BothDone extends QueuesState

// Race the given effects, returning the result of the winner
// plus the still active fibre of the looser
def raceQueues(
lq: F2[Option[Stream[F2, O2]]],
rq: F2[Option[Stream[F2, O2]]]
): F2[(Option[Stream[F2, O2]], Either[FBR, FBR])] =
F.racePair(lq, rq)
.flatMap {
case Left((result, fiber)) =>
result.embedError.map(_ -> fiber.asRight[FBR])
case Right((fiber, result)) =>
result.embedError.map(_ -> fiber.asLeft[FBR])
}

// stream that is generated from pumping out the elements of the queue.
val pumpFromQueue: Stream[F2, O2] =
Stream
.unfoldEval[F2, QueuesState, Stream[F2, O2]](BothActive(None)) { s =>
// Returning None from unfoldEval will stop the stream. If we read a None
// from any queue, we cannot return that but must continue reading on the
// other queue. Thus, we need a method which can be called recursively to
// continue reading in case of None.
def readNext(s: QueuesState): F2[(Option[Stream[F2, O2]], QueuesState)] =
s match {
// The initial state, both queues are active and there are no fibres left over
case BothActive(None) =>
// check available data on left, which would be prioritized
resultQL.tryTake
.flatMap {
_.fold(
// no data available on prioritized queue, race both queues
raceQueues(resultQL.take, resultQR.take)
.flatMap[(Option[Stream[F2, O2]], QueuesState)] {
case (None, Left(fbr)) =>
readNext(RightDone(fbr.some))
case (None, Right(fbr)) =>
readNext(LeftDone(fbr.some))
case (Some(s), fbr) =>
F.pure(s.some -> BothActive(fbr.some))
}
)(os =>
// we read data from the prioritized queue, however, this sill could be a None,
// signalling that queue is done. Handle that:
os.fold(readNext(LeftDone(None)))(ls =>
F.pure(ls.some -> BothActive(None))
)
)
}

// right was looser during the last run
case BothActive(Some(Right(fbr))) =>
// anyway, check for available data on left first, ignoring the incoming fibre for right
resultQL.tryTake
.flatMap(
_.fold(
// use the incoming fibre to read from right queue
raceQueues(resultQL.take, fbr.joinWithNever)
.flatMap[(Option[Stream[F2, O2]], QueuesState)] {
case (None, Left(fbr)) =>
readNext(RightDone(fbr.some))
case (None, Right(fbr)) =>
readNext(LeftDone(fbr.some))
case (Some(s), fbr) =>
F.pure(s.some -> BothActive(fbr.some))
}
)(os =>
// important to reuse the incoming fibre here!
os.fold(readNext(LeftDone(fbr.some)))(ls =>
F.pure(ls.some -> BothActive(fbr.asRight[FBR].some))
)
)
)

// left was looser during the last run
case BothActive(Some(Left(fbr))) =>
// Can't check for available data on left this time,
// because there's an active fibre reading from the left queue.
// Start a race and reuse that fibre for left.
raceQueues(fbr.joinWithNever, resultQR.take)
.flatMap[(Option[Stream[F2, O2]], QueuesState)] {
case (None, Left(fbr)) =>
readNext(RightDone(fbr.some))
case (None, Right(fbr)) =>
readNext(LeftDone(fbr.some))
case (Some(s), fbr) =>
F.pure(s.some -> BothActive(fbr.some))
}

// Left queue is done, but, it's possible we retrieve an active fibre for right.
case LeftDone(fbr) =>
fbr
.map(_.joinWithNever) // join the incoming fibre if given
.getOrElse(resultQR.take) // ordinary take() if no fibre has been given
.map {
case None =>
None -> BothDone
case os =>
os -> LeftDone(None)
}

// mirror case of above
case RightDone(fbr) =>
fbr
.map(_.joinWithNever)
.getOrElse(resultQL.take)
.map {
case None =>
None -> BothDone
case os =>
os -> RightDone(None)
}

// this should never happen, but we need to make the compiler happy
case BothDone =>
F.pure(None -> BothDone)
}

// readNext() returns None in _1 if and only if both queues are done
readNext(s).map {
case (None, _) =>
None // finish the stream (unfoldEval)
case (Some(s), st) =>
(s -> st).some // emit element and new state (unfoldEval)
}
}
.flatten // we have Stream[F2, Stream[F2, O2]] and flatten that to Stream[F2, O2]

val atRunEnd: F2[Unit] =
for {
_ <- signalInterruption // interrupt so the upstreams have chance to complete
left <- resultL.get
right <- resultR.get
r <- F.fromEither(CompositeFailure.fromResults(left, right))
} yield r

val runStreams =
runStream(this, resultL, resultQL).start >> runStream(that, resultR, resultQR).start

Stream.bracket(runStreams)(_ => atRunEnd) >> watchInterrupted(pumpFromQueue)
}
Stream.eval(fstream).flatten

}

/** Given two sorted streams emits a single sorted stream, like in merge-sort.
* For entries that are considered equal by the Order, left stream element is emitted first.
* Note: both this and another streams MUST BE ORDERED already
Expand Down
48 changes: 48 additions & 0 deletions core/shared/src/test/scala/fs2/StreamMergeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -242,4 +242,52 @@ class StreamMergeSuite extends Fs2Suite {
}
}
}

test("mergePreferred prefers this over that") {

val units = Stream.unit.covary[IO].repeat
val left = units.map(Left(_))
val right = units.map(Right(_))

val stream = left.mergePreferred(right)

stream
.take(10000)
.fold((0L, 0L)) {
case ((left, right), Left(_)) => (left + 1, right)
case ((left, right), Right(_)) => (left, right + 1)
}
.compile
.lastOrError
.map { case (left, right) =>
val relLeft = left.toDouble / (left + right).toDouble
// Tolerate up to 2% elements of the non preferred stream.
// Increase the value, if the test (ocassionally) reports false positives.
val delta = 0.02d
assertEqualsDouble(relLeft, 1.0d, delta)
}
}

test("mergePreferred fully consumes this") {
forAllF { (stream: Stream[Pure, Int]) =>
stream.covary[IO].mergePreferred(Stream.empty.covary[IO]).assertEmitsSameAs(stream)
}
}

test("mergePreferred fully consumes that") {
forAllF { (stream: Stream[Pure, Int]) =>
Stream.empty.covary[IO].mergePreferred(stream.covary[IO]).assertEmitsSameAs(stream)
}
}

test("mergePreferred fully consumes both") {
forAllF { (leftStream: Stream[Pure, Int], rightStream: Stream[Pure, Int]) =>
val leftTagged = leftStream.covary[IO]
val rightTagged = rightStream.covary[IO]
leftTagged
.mergePreferred(rightTagged)
.assertEmitsUnorderedSameAs(leftStream ++ rightStream)
}
}

}
6 changes: 6 additions & 0 deletions core/shared/src/test/scala/fs2/StreamSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,12 @@ class StreamSuite extends Fs2Suite {
}
}

test("mergePreferred") {
testCancelation {
constantStream.mergePreferred(constantStream)
}
}

test("parJoin") {
testCancelation {
Stream(constantStream, constantStream).parJoin(2)
Expand Down