diff --git a/modules/core/src/main/scala/doobie/syntax/stream.scala b/modules/core/src/main/scala/doobie/syntax/stream.scala index da0866c17..b204e47b1 100644 --- a/modules/core/src/main/scala/doobie/syntax/stream.scala +++ b/modules/core/src/main/scala/doobie/syntax/stream.scala @@ -8,6 +8,7 @@ import doobie.util.compat.=:= import doobie.util.transactor.Transactor import doobie.free.connection.ConnectionIO import cats.data.Kleisli +import cats.effect.Concurrent import cats.effect.kernel.{Async, MonadCancelThrow} import fs2.{Pipe, Stream} @@ -15,6 +16,10 @@ class StreamOps[F[_], A](fa: Stream[F, A]) { def transact[M[_]: MonadCancelThrow](xa: Transactor[M])(implicit ev: Stream[F, A] =:= Stream[ConnectionIO, A] ): Stream[M, A] = xa.transP.apply(fa) + + def transactBuffer[M[_]: MonadCancelThrow: Concurrent](xa: Transactor[M], bufferSize: Int)(implicit + ev: Stream[F, A] =:= Stream[ConnectionIO, A], + ): Stream[M, A] = xa.transBuffer(bufferSize).apply(fa) } class KleisliStreamOps[A, B](fa: Stream[Kleisli[ConnectionIO, A, *], B]) { def transact[M[_]: MonadCancelThrow](xa: Transactor[M]): Stream[Kleisli[M, A, *], B] = xa.transPK[A].apply(fa) diff --git a/modules/core/src/main/scala/doobie/util/transactor.scala b/modules/core/src/main/scala/doobie/util/transactor.scala index 809540101..a0b5a133f 100644 --- a/modules/core/src/main/scala/doobie/util/transactor.scala +++ b/modules/core/src/main/scala/doobie/util/transactor.scala @@ -11,11 +11,13 @@ import doobie.implicits.* import doobie.util.lens.* import doobie.util.log.LogHandler import doobie.util.yolo.Yolo -import doobie.free.{connection as IFC} +import doobie.free.connection as IFC import cats.{Monad, ~>} import cats.data.Kleisli +import cats.effect.Concurrent import cats.effect.kernel.{Async, MonadCancelThrow, Resource} import cats.effect.kernel.Resource.ExitCase +import cats.effect.std.Queue import fs2.{Pipe, Stream} import java.sql.{Connection, DriverManager} @@ -186,6 +188,19 @@ object transactor { }.scope } + def transBuffer(bufferSize: Int)(implicit ev: Concurrent[M]): Stream[ConnectionIO, *] ~> Stream[M, *] = + new (Stream[ConnectionIO, *] ~> Stream[M, *]) { + def apply[T](s: Stream[ConnectionIO, T]) = { + fs2.Stream.eval(Queue.bounded[M, Option[T]](bufferSize)).flatMap { buffer => + val res = Stream.resource(connect(kernel)).flatMap { c => + Stream.resource(strategy.resource).flatMap(_ => s).translate(run(c)) + .evalMap(x => buffer.offer(Some(x))) ++ Stream.emit(None).evalMap(buffer.offer) + } + Stream.fromQueueNoneTerminated(buffer).concurrently(res) + } + } + } + def rawTransPK[I](implicit ev: MonadCancelThrow[M] ): Stream[Kleisli[ConnectionIO, I, *], *] ~> Stream[Kleisli[M, I, *], *] = diff --git a/modules/hikari/src/test/scala/doobie/postgres/PGConcurrentSuite.scala b/modules/hikari/src/test/scala/doobie/postgres/PGConcurrentSuite.scala index 2ce4f462f..263d5cfe3 100644 --- a/modules/hikari/src/test/scala/doobie/postgres/PGConcurrentSuite.scala +++ b/modules/hikari/src/test/scala/doobie/postgres/PGConcurrentSuite.scala @@ -4,13 +4,13 @@ package doobie.postgres -import java.util.concurrent.Executors - import cats.effect.IO import com.zaxxer.hikari.HikariDataSource import doobie.* import doobie.implicits.* +import fs2.Stream +import java.util.concurrent.Executors import scala.concurrent.ExecutionContext import scala.concurrent.duration.* @@ -43,7 +43,6 @@ class PGConcurrentSuite extends munit.FunSuite { super.afterAll() } - /* test("Not leak connections with recursive query streams") { val xa = Transactor.fromDataSource[IO]( @@ -62,20 +61,35 @@ class PGConcurrentSuite extends munit.FunSuite { assertEquals(pollingStream.unsafeRunSync(), ()) } - */ test("Connection returned before stream is drained") { - val xa = Transactor.fromDataSource[IO]( dataSource, ExecutionContext.fromExecutor(Executors.newFixedThreadPool(32)) ) - val stream = fr"select 1".query[Int].stream.transact(xa) + val count = 100 + val insert = for { + _ <- Stream.eval(sql"CREATE TABLE if not exists stream_cancel_test (i text)".update.run.transact(xa)) + _ <- Stream.eval(sql"truncate table stream_cancel_test".update.run.transact(xa)) + _ <- Stream.eval(sql"INSERT INTO stream_cancel_test values ('1')".update.run.transact(xa)).repeatN(count) + } yield () + + insert.compile.drain.unsafeRunSync() + + val streamLargerBuffer = fr"select * from stream_cancel_test".query[Int].stream.transactBuffer(xa, 1024) .evalMap(_ => fr"select 1".query[Int].unique.transact(xa)) - .compile.drain + .compile.count + + assertEquals(streamLargerBuffer.unsafeRunSync(), count.toLong) + + // if buffer is less than result set, it will be still block new connection since the result set is not drained + // use sleep to test the result set can be drained + val streamSmallerBuffer = fr"select * from stream_cancel_test".query[Int].stream.transactBuffer(xa, 50) + .evalMap(_ => IO.sleep(10.milliseconds)) + .compile.count - assertEquals(stream.unsafeRunSync(), ()) + assertEquals(streamSmallerBuffer.unsafeRunSync(), count.toLong) } }