Skip to content

Commit

Permalink
Merge pull request #2137 from wb14123/stream-leak
Browse files Browse the repository at this point in the history
Prefetch a chunk of result for stream operation
  • Loading branch information
jatcwang authored Nov 21, 2024
2 parents 876c29b + 8d39231 commit 941f66d
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 4 deletions.
12 changes: 10 additions & 2 deletions modules/core/src/main/scala/doobie/syntax/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@ import doobie.util.compat.=:=
import doobie.util.transactor.Transactor
import doobie.free.connection.ConnectionIO
import cats.data.Kleisli
import cats.effect.Concurrent
import cats.effect.kernel.{Async, MonadCancelThrow}
import fs2.{Pipe, Stream}

class StreamOps[F[_], A](fa: Stream[F, A]) {
def transact[M[_]: MonadCancelThrow](xa: Transactor[M])(implicit
def transactNoPrefetch[M[_]: MonadCancelThrow](xa: Transactor[M])(implicit
ev: Stream[F, A] =:= Stream[ConnectionIO, A]
): Stream[M, A] = xa.transP.apply(fa)

def transact[M[_]: Concurrent](xa: Transactor[M])(implicit
ev: Stream[F, A] =:= Stream[ConnectionIO, A]
): Stream[M, A] = transactNoPrefetch(xa).prefetchN(1)

}
class KleisliStreamOps[A, B](fa: Stream[Kleisli[ConnectionIO, A, *], B]) {
def transact[M[_]: MonadCancelThrow](xa: Transactor[M]): Stream[Kleisli[M, A, *], B] = xa.transPK[A].apply(fa)
def transactNoPrefetch[M[_]: MonadCancelThrow](xa: Transactor[M]): Stream[Kleisli[M, A, *], B] =
xa.transPK[A].apply(fa)
def transact[M[_]: Concurrent](xa: Transactor[M]): Stream[Kleisli[M, A, *], B] = transactNoPrefetch(xa).prefetchN(1)
}
class PipeOps[F[_], A, B](inner: Pipe[F, A, B]) {
def transact[M[_]: Async](xa: Transactor[M])(implicit ev: Pipe[F, A, B] =:= Pipe[ConnectionIO, A, B]): Pipe[M, A, B] =
Expand Down
4 changes: 2 additions & 2 deletions modules/example/src/main/scala/example/StreamingCopy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object StreamingCopy extends IOApp.Simple {
sourceXA: Transactor[F],
sinkXA: Transactor[F]
)(
implicit ev: MonadCancelThrow[F]
implicit ev: Concurrent[F]
): Stream[F, B] =
fuseMapGeneric(source, identity[A], sink)(sourceXA, sinkXA)

Expand All @@ -44,7 +44,7 @@ object StreamingCopy extends IOApp.Simple {
sourceXA: Transactor[F],
sinkXA: Transactor[F]
)(
implicit ev: MonadCancelThrow[F]
implicit ev: Concurrent[F]
): Stream[F, C] = {

// Interpret a ConnectionIO into a Kleisli arrow for F via the sink interpreter.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) 2013-2020 Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package doobie.postgres

import cats.effect.IO
import com.zaxxer.hikari.HikariDataSource
import doobie.*
import doobie.implicits.*
import doobie.util.transactor.Strategy
import fs2.Stream

import java.sql.SQLTransientConnectionException
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.*

class StreamPrefetchSuite extends munit.FunSuite {

import cats.effect.unsafe.implicits.global

private var dataSource: HikariDataSource = null
private var xa: Transactor[IO] = null
private val count = 100

private def createDataSource() = {

Class.forName("org.postgresql.Driver")
val dataSource = new HikariDataSource

dataSource `setJdbcUrl` "jdbc:postgresql://localhost:5432/postgres"
dataSource `setUsername` "postgres"
dataSource `setPassword` "password"
dataSource `setMaximumPoolSize` 1
dataSource `setConnectionTimeout` 2000
dataSource
}

override def beforeAll(): Unit = {
super.beforeAll()
dataSource = createDataSource()
xa = Transactor.fromDataSource[IO](
dataSource,
ExecutionContext.fromExecutor(Executors.newFixedThreadPool(32))
)
val insert = for {
_ <- Stream.eval(sql"CREATE TABLE if not exists stream_cancel_test (i text)".update.run.transact(xa))
_ <- Stream.eval(sql"truncate table stream_cancel_test".update.run.transact(xa))
_ <- Stream.eval(
sql"INSERT INTO stream_cancel_test select 1 from generate_series(1, $count)".update.run.transact(xa))
} yield ()

insert.compile.drain.unsafeRunSync()
}

override def afterAll(): Unit = {
dataSource.close()
super.afterAll()
}

test("Connection returned before stream is drained, if chunk size is larger than result count") {
val xa = Transactor.fromDataSource[IO](
dataSource,
ExecutionContext.fromExecutor(Executors.newFixedThreadPool(32))
)

val streamLargerBuffer = fr"select * from stream_cancel_test".query[Int].streamWithChunkSize(200).transact(xa)
.evalMap(_ => fr"select 1".query[Int].unique.transact(xa))
.compile.count

assertEquals(streamLargerBuffer.unsafeRunSync(), count.toLong)
}

test("Connection is not returned after consuming only 1 chunk, if chunk size is smaller than result count") {
val streamSmallerBuffer = fr"select * from stream_cancel_test".query[Int].streamWithChunkSize(10).transact(xa)
.evalMap(_ => fr"select 1".query[Int].unique.transact(xa))
.compile.count

intercept[SQLTransientConnectionException](streamSmallerBuffer.unsafeRunSync())
}

test("Connection returned before stream is drained, if chunk size is smaller than result count") {
val hasClosed = new AtomicBoolean(false)
val xaCopy = xa.copy(
strategy0 = Strategy.default.copy(
always = Strategy.default.always.flatMap(_ => FC.delay(hasClosed.set(true)))
)
)

val earlyClose = new AtomicBoolean(false)

val streamSmallerBufferValid =
fr"select * from stream_cancel_test".query[Int].streamWithChunkSize(10).transact(xaCopy)
.evalMap { _ => IO { if (hasClosed.get()) earlyClose.set(true) } >> IO.sleep(10.milliseconds) }
.compile.count

assertEquals(streamSmallerBufferValid.unsafeRunSync(), count.toLong)
assertEquals(earlyClose.get(), true)
}

}

0 comments on commit 941f66d

Please sign in to comment.