diff --git a/build.sbt b/build.sbt index 67e1ef6d0d..9da0f0a524 100644 --- a/build.sbt +++ b/build.sbt @@ -1539,7 +1539,10 @@ lazy val zioHttpServer: ProjectMatrix = (projectMatrix in file("server/zio-http- .settings(commonJvmSettings) .settings( name := "tapir-zio-http-server", - libraryDependencies ++= Seq("dev.zio" %% "zio-interop-cats" % Versions.zioInteropCats % Test, "dev.zio" %% "zio-http" % "3.0.0-RC4") + libraryDependencies ++= Seq( + "dev.zio" %% "zio-interop-cats" % Versions.zioInteropCats % Test, + "dev.zio" %% "zio-http" % Versions.zioHttp + ) ) .jvmPlatform(scalaVersions = scala2And3Versions) .dependsOn(serverCore, zio, serverTests % Test) diff --git a/project/Versions.scala b/project/Versions.scala index ee4dc3d037..38594bf6d0 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -34,6 +34,7 @@ object Versions { val iron = "2.5.0" val enumeratum = "1.7.3" val zio = "2.0.21" + val zioHttp = "3.0.0-RC6" val zioInteropCats = "23.0.0.8" val zioInteropReactiveStreams = "2.0.2" val zioJson = "0.6.2" diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala index 3289e5cc23..52d3abd4c7 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpBodyListener.scala @@ -1,7 +1,7 @@ package sttp.tapir.server.ziohttp import sttp.tapir.server.interpreter.BodyListener -import zio.{RIO, ZIO} +import zio.{Cause, RIO, ZIO} import zio.stream.ZStream import scala.util.{Failure, Success, Try} @@ -11,19 +11,18 @@ private[ziohttp] class ZioHttpBodyListener[R] extends BodyListener[RIO[R, *], Zi ZIO .environmentWithZIO[R] .apply { r => + def succeed = cb(Success(())).provideEnvironment(r) + def failed(cause: Cause[Throwable]) = cb(Failure(cause.squash)).orDie.provideEnvironment(r) + body match { case Right(ZioStreamHttpResponseBody(stream, contentLength)) => ZIO.right( ZioStreamHttpResponseBody( - stream.onError(cause => cb(Failure(cause.squash)).orDie.provideEnvironment(r)) ++ ZStream - .fromZIO(cb(Success(()))) - .provideEnvironment(r) - .drain, + stream.onError(failed) ++ ZStream.fromZIO(succeed).drain, contentLength ) ) - case raw @ Right(_: ZioRawHttpResponseBody) => cb(Success(())).provideEnvironment(r).map(_ => raw) - case ws @ Left(_) => cb(Success(())).provideEnvironment(r).map(_ => ws) + case rawOrWs => succeed.as(rawOrWs) } } } diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala index e427b932f7..d2aac7cbde 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioHttpInterpreter.scala @@ -88,7 +88,7 @@ trait ZioHttpInterpreter[R] { resp: ServerResponse[ZioResponseBody], body: Option[ZioHttpResponseBody] ): UIO[Response] = { - val baseHeaders = resp.headers.groupBy(_.name).flatMap(sttpToZioHttpHeader).toList + val baseHeaders = resp.headers.groupBy(_.name).map(sttpToZioHttpHeader).toList val allHeaders = body.flatMap(_.contentLength) match { case Some(contentLength) if resp.contentLength.isEmpty => ZioHttpHeader.ContentLength(contentLength) :: baseHeaders case _ => baseHeaders @@ -97,20 +97,21 @@ trait ZioHttpInterpreter[R] { ZIO.succeed( Response( - status = Status.fromInt(statusCode).getOrElse(Status.Custom(statusCode)), + status = Status.fromInt(statusCode), headers = ZioHttpHeaders(allHeaders), body = body .map { - case ZioStreamHttpResponseBody(stream, _) => Body.fromStream(stream) - case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk) + case ZioStreamHttpResponseBody(stream, Some(contentLength)) => Body.fromStream(stream, contentLength) + case ZioStreamHttpResponseBody(stream, None) => Body.fromStreamChunked(stream) + case ZioRawHttpResponseBody(chunk, _) => Body.fromChunk(chunk) } .getOrElse(Body.empty) ) ) } - private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): List[ZioHttpHeader] = - List(ZioHttpHeader.Custom(hl._1, hl._2.map(_.value).mkString(", "))) + private def sttpToZioHttpHeader(hl: (String, Seq[SttpHeader])): ZioHttpHeader = + ZioHttpHeader.Custom(hl._1, hl._2.map(_.value).mkString(", ")) } object ZioHttpInterpreter {