diff --git a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala index 3a53cbc292..ae81873df2 100644 --- a/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala +++ b/server/http4s-server/src/main/scala/sttp/tapir/server/http4s/Http4sWebSockets.scala @@ -101,6 +101,8 @@ private[http4s] object Http4sWebSockets { case (None, f: WebSocketFrame.Pong) => (None, Some(f)) case (None, f: WebSocketFrame.Close) => (None, Some(f)) case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (None, f: WebSocketFrame.Text) => (Some(Right(f.payload)), None) + case (None, f: WebSocketFrame.Binary) => (Some(Left(f.payload)), None) case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 26e50651de..3612925382 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -100,19 +100,7 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = if (doConcatenate) { - type Accumulator = Option[Either[Array[Byte], String]] - - s.mapAccumulate(None: Accumulator) { - case (None, f: WebSocketFrame.Ping) => (None, Some(f)) - case (None, f: WebSocketFrame.Pong) => (None, Some(f)) - case (None, f: WebSocketFrame.Close) => (None, Some(f)) - case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) - case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) - case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) - case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) - case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None) - case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") - }.collect { case (_, Some(f)) => f } + s.mapAccumulate(None: Accumulator)(accumulateFrameState).collect { case (_, Some(f)) => f } } else s } diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala index d9e7c75cd1..748dba94ac 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketFrameConverters.scala @@ -35,4 +35,23 @@ object WebSocketFrameConverters { case WebSocketFrame.Binary(payload, finalFragment, rsvOpt) => new BinaryWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), Unpooled.wrappedBuffer(payload)) } + + type Accumulator = Option[Either[Array[Byte], String]] + val accumulateFrameState: (Accumulator, WebSocketFrame) => (Accumulator, Option[WebSocketFrame]) = { + case (None, f: WebSocketFrame.Ping) => (None, Some(f)) + case (None, f: WebSocketFrame.Pong) => (None, Some(f)) + case (None, f: WebSocketFrame.Close) => (None, Some(f)) + case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (None, f: WebSocketFrame.Text) => (Some(Right(f.payload)), None) + case (None, f: WebSocketFrame.Binary) => (Some(Left(f.payload)), None) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) + case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) + case (Some(Right(acc)), f: WebSocketFrame.Binary) if f.finalFragment => + // Netty's ContinuationFrame is translated to Binary, so we need to handle a Binary frame received after accumulating Text + (None, Some(WebSocketFrame.Text(payload = acc + new String(f.payload), finalFragment = true, rsv = f.rsv))) + case (Some(Right(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Right(acc + new String(f.payload))), None) + case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) + case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None) + case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") + } } diff --git a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala index 17140d1fb8..fb7ecc5d18 100644 --- a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala +++ b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala @@ -66,16 +66,5 @@ private[sync] object OxSourceWebSocketProcessor: private def optionallyConcatenateFrames(s: Source[WebSocketFrame], doConcatenate: Boolean)(using Ox): Source[WebSocketFrame] = if doConcatenate then - type Accumulator = Option[Either[Array[Byte], String]] - s.mapStateful(() => None: Accumulator) { - case (None, f: WebSocketFrame.Ping) => (None, Some(f)) - case (None, f: WebSocketFrame.Pong) => (None, Some(f)) - case (None, f: WebSocketFrame.Close) => (None, Some(f)) - case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) - case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) - case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) - case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload))) - case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None) - case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") - }.collectAsView { case Some(f: WebSocketFrame) => f } + s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f } else s diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index c33c8253bb..b56ca42041 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -29,7 +29,8 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( failingPipe: Boolean, handlePong: Boolean, // Disabled for eaxmple for vert.x, which sometimes drops connection without returning Close - expectCloseResponse: Boolean = true + expectCloseResponse: Boolean = true, + frameConcatenation: Boolean = true )(implicit m: MonadError[F] ) extends EitherValues { @@ -244,7 +245,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( response2.body shouldBe Right("echo: testOk") } } - ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests + ) ++ autoPingTests ++ failingPipeTests ++ handlePongTests ++ frameConcatenationTests val autoPingTests = if (autoPing) @@ -314,6 +315,53 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) else List.empty + val frameConcatenationTests = if (frameConcatenation) List( + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(false) + .concatenateFragmentedFrames(true) + ), + "concatenate fragmented text frames" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.send(WebSocketFrame.Text("f1", finalFragment = false, None)) + _ <- ws.sendText("f2") + r <- ws.receiveText() + _ <- ws.close() + } yield r + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { _.body shouldBe (Right("echo: f1f2")) } + }, + testServer( + endpoint.out( + webSocketBody[Array[Byte], CodecFormat.OctetStream, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(false) + .concatenateFragmentedFrames(true) + ), + "concatenate fragmented binary frames" + )((_: Unit) => pureResult(functionToPipe((bs: Array[Byte]) => s"echo: ${new String(bs)}").asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.send(WebSocketFrame.Binary("frame1-bytes;".getBytes(), finalFragment = false, None)) + _ <- ws.sendBinary("frame2-bytes".getBytes()) + r <- ws.receiveText() + _ <- ws.close() + } yield r + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { _.body shouldBe (Right("echo: frame1-bytes;frame2-bytes")) } + } + ) else Nil + val handlePongTests = if (handlePong) List( diff --git a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala index 6aa70b3a2c..4044072ac4 100644 --- a/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala +++ b/server/vertx-server/cats/src/test/scala/sttp/tapir/server/vertx/cats/CatsVertxServerTest.scala @@ -43,7 +43,8 @@ class CatsVertxServerTest extends TestSuite { autoPing = false, failingPipe = false, handlePong = true, - expectCloseResponse = false + expectCloseResponse = false, + frameConcatenation = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty diff --git a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/websocket.scala b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/websocket.scala index d87ef689a9..9bcf8cd7a3 100644 --- a/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/websocket.scala +++ b/server/vertx-server/src/main/scala/sttp/tapir/server/vertx/streams/websocket.scala @@ -15,6 +15,10 @@ object websocket { (None, Some(f)) case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (None, f: WebSocketFrame.Text) => + (Some(Right(f.payload)), None) + case (None, f: WebSocketFrame.Binary) => + (Some(Left(f.payload)), None) case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => diff --git a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala index 862a7ac524..fca79709f5 100644 --- a/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala +++ b/server/vertx-server/src/test/scala/sttp/tapir/server/vertx/VertxServerTest.scala @@ -59,7 +59,8 @@ class VertxServerTest extends TestSuite { autoPing = false, failingPipe = false, handlePong = false, - expectCloseResponse = false + expectCloseResponse = false, + frameConcatenation = false ) { override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f) override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream() diff --git a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala index e2a53b57d1..66c06f9c29 100644 --- a/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala +++ b/server/vertx-server/zio/src/test/scala/sttp/tapir/server/vertx/zio/ZioVertxServerTest.scala @@ -49,7 +49,8 @@ class ZioVertxServerTest extends TestSuite with OptionValues { autoPing = true, failingPipe = false, handlePong = false, - expectCloseResponse = false + expectCloseResponse = false, + frameConcatenation = false ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty diff --git a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala index 1814d859b7..ee46017620 100644 --- a/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala +++ b/server/zio-http-server/src/main/scala/sttp/tapir/server/ziohttp/ZioWebSockets.scala @@ -124,15 +124,13 @@ object ZioWebSockets { case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f)) case (None, f: SttpWebSocketFrame.Close) => (None, Some(f)) case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f)) + case (None, f: SttpWebSocketFrame.Text) => (Some(Right(f.payload)), None) + case (None, f: SttpWebSocketFrame.Binary) => (Some(Left(f.payload)), None) case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload))) case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None) case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment => - println(s"final fragment: $f") - println(s"acc: $acc") (None, Some(f.copy(payload = acc + f.payload))) case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment => - println(s"final fragment: $f") - println(s"acc: $acc") (Some(Right(acc + f.payload)), None) case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.") diff --git a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala index 417e16a5af..e462a81630 100644 --- a/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala +++ b/server/zio-http-server/src/test/scala/sttp/tapir/server/ziohttp/ZioHttpServerTest.scala @@ -271,7 +271,8 @@ class ZioHttpServerTest extends TestSuite { ZioStreams, autoPing = true, failingPipe = false, - handlePong = false + handlePong = false, + frameConcatenation = false ) { override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty