Skip to content

Commit

Permalink
Fix WebSocket frame concatenation for Netty (#3801)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored May 29, 2024
1 parent 6298a5f commit cf82f2f
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ private[http4s] object Http4sWebSockets {
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (None, f: WebSocketFrame.Text) => (Some(Right(f.payload)), None)
case (None, f: WebSocketFrame.Binary) => (Some(Left(f.payload)), None)
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,7 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](

private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
if (doConcatenate) {
type Accumulator = Option[Either[Array[Byte], String]]

s.mapAccumulate(None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}.collect { case (_, Some(f)) => f }
s.mapAccumulate(None: Accumulator)(accumulateFrameState).collect { case (_, Some(f)) => f }
} else s
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,23 @@ object WebSocketFrameConverters {
case WebSocketFrame.Binary(payload, finalFragment, rsvOpt) =>
new BinaryWebSocketFrame(finalFragment, rsvOpt.getOrElse(0), Unpooled.wrappedBuffer(payload))
}

type Accumulator = Option[Either[Array[Byte], String]]
val accumulateFrameState: (Accumulator, WebSocketFrame) => (Accumulator, Option[WebSocketFrame]) = {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (None, f: WebSocketFrame.Text) => (Some(Right(f.payload)), None)
case (None, f: WebSocketFrame.Binary) => (Some(Left(f.payload)), None)
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Binary) if f.finalFragment =>
// Netty's ContinuationFrame is translated to Binary, so we need to handle a Binary frame received after accumulating Text
(None, Some(WebSocketFrame.Text(payload = acc + new String(f.payload), finalFragment = true, rsv = f.rsv)))
case (Some(Right(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Right(acc + new String(f.payload))), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,5 @@ private[sync] object OxSourceWebSocketProcessor:

private def optionallyConcatenateFrames(s: Source[WebSocketFrame], doConcatenate: Boolean)(using Ox): Source[WebSocketFrame] =
if doConcatenate then
type Accumulator = Option[Either[Array[Byte], String]]
s.mapStateful(() => None: Accumulator) {
case (None, f: WebSocketFrame.Ping) => (None, Some(f))
case (None, f: WebSocketFrame.Pong) => (None, Some(f))
case (None, f: WebSocketFrame.Close) => (None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: WebSocketFrame.Text) if f.finalFragment => (None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: WebSocketFrame.Text) if !f.finalFragment => (Some(Right(acc + f.payload)), None)
case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
}.collectAsView { case Some(f: WebSocketFrame) => f }
s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f }
else s
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
failingPipe: Boolean,
handlePong: Boolean,
// Disabled for eaxmple for vert.x, which sometimes drops connection without returning Close
expectCloseResponse: Boolean = true
expectCloseResponse: Boolean = true,
frameConcatenation: Boolean = true
)(implicit
m: MonadError[F]
) extends EitherValues {
Expand Down Expand Up @@ -244,7 +245,7 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
response2.body shouldBe Right("echo: testOk")
}
}
) ++ autoPingTests ++ failingPipeTests ++ handlePongTests
) ++ autoPingTests ++ failingPipeTests ++ handlePongTests ++ frameConcatenationTests

val autoPingTests =
if (autoPing)
Expand Down Expand Up @@ -314,6 +315,53 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
)
else List.empty

val frameConcatenationTests = if (frameConcatenation) List(
testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented text frames"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Text("f1", finalFragment = false, None))
_ <- ws.sendText("f2")
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: f1f2")) }
},
testServer(
endpoint.out(
webSocketBody[Array[Byte], CodecFormat.OctetStream, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented binary frames"
)((_: Unit) => pureResult(functionToPipe((bs: Array[Byte]) => s"echo: ${new String(bs)}").asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Binary("frame1-bytes;".getBytes(), finalFragment = false, None))
_ <- ws.sendBinary("frame2-bytes".getBytes())
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: frame1-bytes;frame2-bytes")) }
}
) else Nil

val handlePongTests =
if (handlePong)
List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class CatsVertxServerTest extends TestSuite {
autoPing = false,
failingPipe = false,
handlePong = true,
expectCloseResponse = false
expectCloseResponse = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => Stream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ object websocket {
(None, Some(f))
case (None, f: WebSocketFrame.Data[_]) if f.finalFragment =>
(None, Some(f))
case (None, f: WebSocketFrame.Text) =>
(Some(Right(f.payload)), None)
case (None, f: WebSocketFrame.Binary) =>
(Some(Left(f.payload)), None)
case (Some(Left(acc)), f: WebSocketFrame.Binary) if f.finalFragment =>
(None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: WebSocketFrame.Binary) if !f.finalFragment =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class VertxServerTest extends TestSuite {
autoPing = false,
failingPipe = false,
handlePong = false,
expectCloseResponse = false
expectCloseResponse = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): VertxStreams.Pipe[A, B] = in => new ReadStreamMapping(in, f)
override def emptyPipe[A, B]: VertxStreams.Pipe[A, B] = _ => new EmptyReadStream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class ZioVertxServerTest extends TestSuite with OptionValues {
autoPing = true,
failingPipe = false,
handlePong = false,
expectCloseResponse = false
expectCloseResponse = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: streams.Pipe[A, B] = _ => ZStream.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,13 @@ object ZioWebSockets {
case (None, f: SttpWebSocketFrame.Pong) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Close) => (None, Some(f))
case (None, f: SttpWebSocketFrame.Data[_]) if f.finalFragment => (None, Some(f))
case (None, f: SttpWebSocketFrame.Text) => (Some(Right(f.payload)), None)
case (None, f: SttpWebSocketFrame.Binary) => (Some(Left(f.payload)), None)
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if f.finalFragment => (None, Some(f.copy(payload = acc ++ f.payload)))
case (Some(Left(acc)), f: SttpWebSocketFrame.Binary) if !f.finalFragment => (Some(Left(acc ++ f.payload)), None)
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(None, Some(f.copy(payload = acc + f.payload)))
case (Some(Right(acc)), f: SttpWebSocketFrame.Text) if !f.finalFragment =>
println(s"final fragment: $f")
println(s"acc: $acc")
(Some(Right(acc + f.payload)), None)

case (acc, f) => throw new IllegalStateException(s"Cannot accumulate web socket frames. Accumulator: $acc, frame: $f.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ class ZioHttpServerTest extends TestSuite {
ZioStreams,
autoPing = true,
failingPipe = false,
handlePong = false
handlePong = false,
frameConcatenation = false
) {
override def functionToPipe[A, B](f: A => B): ZioStreams.Pipe[A, B] = in => in.map(f)
override def emptyPipe[A, B]: ZioStreams.Pipe[A, B] = _ => ZStream.empty
Expand Down

0 comments on commit cf82f2f

Please sign in to comment.