Skip to content

Commit

Permalink
prototype stream with buffer to avoid block on transaction commit
Browse files Browse the repository at this point in the history
  • Loading branch information
wb14123 committed Nov 13, 2024
1 parent 410e034 commit 326269a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
5 changes: 5 additions & 0 deletions modules/core/src/main/scala/doobie/syntax/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,18 @@ import doobie.util.compat.=:=
import doobie.util.transactor.Transactor
import doobie.free.connection.ConnectionIO
import cats.data.Kleisli
import cats.effect.Concurrent
import cats.effect.kernel.{Async, MonadCancelThrow}
import fs2.{Pipe, Stream}

class StreamOps[F[_], A](fa: Stream[F, A]) {
def transact[M[_]: MonadCancelThrow](xa: Transactor[M])(implicit
ev: Stream[F, A] =:= Stream[ConnectionIO, A]
): Stream[M, A] = xa.transP.apply(fa)

def transactBuffer[M[_]: MonadCancelThrow: Concurrent](xa: Transactor[M], bufferSize: Int)(implicit
ev: Stream[F, A] =:= Stream[ConnectionIO, A],
): Stream[M, A] = xa.transBuffer(bufferSize).apply(fa)
}
class KleisliStreamOps[A, B](fa: Stream[Kleisli[ConnectionIO, A, *], B]) {
def transact[M[_]: MonadCancelThrow](xa: Transactor[M]): Stream[Kleisli[M, A, *], B] = xa.transPK[A].apply(fa)
Expand Down
17 changes: 16 additions & 1 deletion modules/core/src/main/scala/doobie/util/transactor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import doobie.implicits.*
import doobie.util.lens.*
import doobie.util.log.LogHandler
import doobie.util.yolo.Yolo
import doobie.free.{connection as IFC}
import doobie.free.connection as IFC
import cats.{Monad, ~>}
import cats.data.Kleisli
import cats.effect.Concurrent
import cats.effect.kernel.{Async, MonadCancelThrow, Resource}
import cats.effect.kernel.Resource.ExitCase
import cats.effect.std.Queue
import fs2.{Pipe, Stream}

import java.sql.{Connection, DriverManager}
Expand Down Expand Up @@ -186,6 +188,19 @@ object transactor {
}.scope
}

def transBuffer(bufferSize: Int)(implicit ev: Concurrent[M]): Stream[ConnectionIO, *] ~> Stream[M, *] =
new (Stream[ConnectionIO, *] ~> Stream[M, *]) {
def apply[T](s: Stream[ConnectionIO, T]) = {
fs2.Stream.eval(Queue.bounded[M, Option[T]](bufferSize)).flatMap { buffer =>
val res = Stream.resource(connect(kernel)).flatMap { c =>
Stream.resource(strategy.resource).flatMap(_ => s).translate(run(c))
.evalMap(x => buffer.offer(Some(x))) ++ Stream.emit(None).evalMap(buffer.offer)
}
Stream.fromQueueNoneTerminated(buffer).concurrently(res)
}
}
}

def rawTransPK[I](implicit
ev: MonadCancelThrow[M]
): Stream[Kleisli[ConnectionIO, I, *], *] ~> Stream[Kleisli[M, I, *], *] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

package doobie.postgres

import java.util.concurrent.Executors

import cats.effect.IO
import com.zaxxer.hikari.HikariDataSource
import doobie.*
import doobie.implicits.*
import fs2.Stream

import java.util.concurrent.Executors
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.*

Expand Down Expand Up @@ -43,7 +43,6 @@ class PGConcurrentSuite extends munit.FunSuite {
super.afterAll()
}

/*
test("Not leak connections with recursive query streams") {

val xa = Transactor.fromDataSource[IO](
Expand All @@ -62,20 +61,35 @@ class PGConcurrentSuite extends munit.FunSuite {

assertEquals(pollingStream.unsafeRunSync(), ())
}
*/

test("Connection returned before stream is drained") {

val xa = Transactor.fromDataSource[IO](
dataSource,
ExecutionContext.fromExecutor(Executors.newFixedThreadPool(32))
)

val stream = fr"select 1".query[Int].stream.transact(xa)
val count = 100
val insert = for {
_ <- Stream.eval(sql"CREATE TABLE if not exists stream_cancel_test (i text)".update.run.transact(xa))
_ <- Stream.eval(sql"truncate table stream_cancel_test".update.run.transact(xa))
_ <- Stream.eval(sql"INSERT INTO stream_cancel_test values ('1')".update.run.transact(xa)).repeatN(count)
} yield ()

insert.compile.drain.unsafeRunSync()

val streamLargerBuffer = fr"select * from stream_cancel_test".query[Int].stream.transactBuffer(xa, 1024)
.evalMap(_ => fr"select 1".query[Int].unique.transact(xa))
.compile.drain
.compile.count

assertEquals(streamLargerBuffer.unsafeRunSync(), count.toLong)

// if buffer is less than result set, it will be still block new connection since the result set is not drained
// use sleep to test the result set can be drained
val streamSmallerBuffer = fr"select * from stream_cancel_test".query[Int].stream.transactBuffer(xa, 50)
.evalMap(_ => IO.sleep(10.milliseconds))
.compile.count

assertEquals(stream.unsafeRunSync(), ())
assertEquals(streamSmallerBuffer.unsafeRunSync(), count.toLong)
}

}

0 comments on commit 326269a

Please sign in to comment.