Skip to content

Commit

Permalink
Use flow instead fo source in OxPipe
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Oct 8, 2024
1 parent 854eb3e commit fc45215
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

//> using dep com.softwaremill.sttp.tapir::tapir-core:1.11.5
//> using dep com.softwaremill.sttp.tapir::tapir-netty-server-sync:1.11.5
//> using dep com.softwaremill.ox::core:0.4.0
//> using dep com.softwaremill.ox::core:0.5.0
// the explicit ox dependency is only needed until tapir is updated

package sttp.tapir.examples.websocket

import ox.channels.{Actor, ActorRef, Channel, ChannelClosed, Default, DefaultResult, selectOrClosed}
import ox.{ExitCode, Ox, OxApp, fork, never, releaseAfterScope}
import ox.{ExitCode, Ox, OxApp, fork, never, releaseAfterScope, supervised}
import sttp.tapir.*
import sttp.tapir.CodecFormat.*
import sttp.tapir.server.netty.sync.{NettySyncServer, OxStreams}

import java.util.UUID
import ox.flow.Flow
import ox.flow.FlowEmit

type ChatMemberId = UUID

Expand All @@ -33,7 +36,7 @@ class ChatRoom:

def incoming(message: Message): Unit =
println(s"Broadcasting: ${message.v}")
members = members.flatMap { (id, member) =>
members = members.flatMap: (id, member) =>
selectOrClosed(member.channel.sendClause(message), Default(())) match
case member.channel.Sent() => Some((id, member))
case _: ChannelClosed =>
Expand All @@ -42,7 +45,6 @@ class ChatRoom:
case DefaultResult(_) =>
println(s"Buffer for member $id full, not sending message")
Some((id, member))
}

//

Expand All @@ -53,27 +55,25 @@ val chatEndpoint = endpoint.get
.in("chat")
.out(webSocketBody[Message, TextPlain, Message, TextPlain](OxStreams))

def chatProcessor(a: ActorRef[ChatRoom]): OxStreams.Pipe[Message, Message] =
incoming => {
val member = ChatMember.create
def chatProcessor(a: ActorRef[ChatRoom]): OxStreams.Pipe[Message, Message] = incoming =>
// returning a flow which, when run, creates a scope to handle the incoming & outgoing messages
Flow.usingEmit: emit =>
supervised:
val member = ChatMember.create

a.tell(_.connected(member))
a.tell(_.connected(member))

fork {
incoming.foreach { msg =>
a.tell(_.incoming(msg))
}
// all incoming messages are processed (= client closed), completing the outgoing channel as well
member.channel.done()
}
fork:
incoming.runForeach: msg =>
a.tell(_.incoming(msg))
// all incoming messages are processed (= client closed), completing the outgoing channel as well
member.channel.done()

// however the scope ends (client close or error), we need to notify the chat room
releaseAfterScope {
a.tell(_.disconnected(member))
}
// however the scope ends (client close or error), we need to notify the chat room
releaseAfterScope:
a.tell(_.disconnected(member))

member.channel
}
FlowEmit.channelToEmit(member.channel, emit)

object WebSocketChatNettySyncServer extends OxApp:
override def run(args: Vector[String])(using Ox): ExitCode =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ object WebSocketNettySyncServer:
// Alternative logic (not used here): requests and responses can be treated separately, for example to emit frames
// to the client from another source.
val wsPipe2: Pipe[String, String] = { in =>
val flowLeft: Flow[Either[String, String]] = Flow.fromSource(in).map(Left(_))
val flowLeft: Flow[Either[String, String]] = in.map(Left(_))
// emit periodic responses
val flowRight: Flow[Either[String, String]] = Flow.tick(1.second).map(_ => System.currentTimeMillis()).map(_.toString).map(Right(_))

// ignore whatever is sent by the client (represented as `Left`)
flowLeft.merge(flowRight, propagateDoneLeft = true).collect { case Right(s) => s }.runToChannel()
flowLeft.merge(flowRight, propagateDoneLeft = true).collect { case Right(s) => s }
}

// The WebSocket endpoint, builds the pipeline in serverLogicSuccess
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package sttp.tapir.server.netty.sync

import ox.Ox
import ox.channels.Source
import ox.flow.Flow
import sttp.capabilities.Streams

trait OxStreams extends Streams[OxStreams]:
override type BinaryStream = Nothing
override type Pipe[A, B] = Ox ?=> Source[A] => Source[B]
override type Pipe[A, B] = Flow[A] => Flow[B]

object OxStreams extends OxStreams
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package sttp.tapir.server.netty.sync.internal

import _root_.ox.*
import io.netty.channel.ChannelHandlerContext
import sttp.capabilities
import sttp.model.HasHeaders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import sttp.tapir.server.netty.sync.internal.ox.OxDispatcher
import scala.concurrent.duration.*
import scala.concurrent.{Await, Future}
import scala.util.control.NonFatal
import ox.flow.Flow

/** A reactive Processor, which is both a Publisher and a Subscriber
*
Expand Down Expand Up @@ -65,10 +66,7 @@ private[sync] class OxProcessor[A, B](
if subscriber == null then throw new NullPointerException("Subscriber cannot be null")
val wrappedSubscriber = wrapSubscriber(subscriber)
pipelineForkFuture = oxDispatcher.runAsync {
val outgoingResponses: Source[B] = pipeline((channel: Source[A]).mapAsView { e =>
requestsSubscription.request(1)
e
})
val outgoingResponses: Source[B] = pipeline(Flow.fromSource(channel).tap(_ => requestsSubscription.request(1))).runToChannel()
val channelSubscription = new ChannelSubscription(wrappedSubscriber, outgoingResponses)
subscriber.onSubscribe(channelSubscription)
channelSubscription.runBlocking() // run the main loop which reads from the channel if there's demand
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, WebSocketClo
import org.reactivestreams.{Processor, Subscriber, Subscription}
import org.slf4j.LoggerFactory
import ox.*
import ox.channels.{ChannelClosedException, Source}
import ox.channels.ChannelClosedException
import sttp.tapir.model.WebSocketFrameDecodeFailure
import sttp.tapir.server.netty.internal.ws.WebSocketFrameConverters.*
import sttp.tapir.server.netty.sync.OxStreams
Expand Down Expand Up @@ -35,29 +35,20 @@ private[sync] object OxSourceWebSocketProcessor:
case x: DecodeResult.Value[REQ] @unchecked => x.v
}

val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] = ox ?=>
val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] = incoming =>
val closeSignal = new Semaphore(0)
(incoming: Source[NettyWebSocketFrame]) =>
val outgoing = Flow
.fromSource(incoming)
.map { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
}
.pipe(takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests, closeSignal))
.pipe(optionallyConcatenateFrames(o.concatenateFragmentedFrames))
.map(decodeFrame)
.runToChannel()
.pipe(processingPipe)
.mapAsView(r => frameToNettyFrame(o.responses.encode(r)))

// when the client closes the connection, we need to close the outgoing channel as well - this needs to be
// done in the client's pipeline code; monitoring that this happens within a timeout after the close happens
monitorOutgoingClosedAfterClientClose(closeSignal, outgoing)

outgoing
end frame2FramePipe
incoming
.map { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
}
.pipe(takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests, closeSignal))
.pipe(optionallyConcatenateFrames(o.concatenateFragmentedFrames))
.map(decodeFrame)
.pipe(processingPipe)
.pipe(monitorOutgoingClosedAfterClientClose(closeSignal))
.map(r => frameToNettyFrame(o.responses.encode(r)))

// We need this kind of interceptor to make Netty reply correctly to closed channel or error
def wrapSubscriberWithNettyCallback[B](sub: Subscriber[? >: B]): Subscriber[? >: B] = new Subscriber[B] {
Expand Down Expand Up @@ -87,21 +78,29 @@ private[sync] object OxSourceWebSocketProcessor:
f.takeWhile(
{
case _: WebSocketFrame.Close => closeSignal.release(); false
case _ => true
case f => true
},
includeFirstFailing = passAlongCloseFrame
)

private def monitorOutgoingClosedAfterClientClose(closeSignal: Semaphore, outgoing: Source[_])(using Ox): Unit =
// will be interrupted when outgoing is completed
fork {
closeSignal.acquire()
sleep(outgoingCloseAfterCloseTimeout)
if !outgoing.isClosedForReceive then
logger.error(
s"WebSocket outgoing messages channel either not drained, or not closed, " +
s"$outgoingCloseAfterCloseTimeout after receiving a close frame from the client! " +
s"Make sure to complete the outgoing channel in your pipeline, once the incoming " +
s"channel is done!"
)
}.discard
private def monitorOutgoingClosedAfterClientClose[T](closeSignal: Semaphore)(outgoing: Flow[T]): Flow[T] =
// when the client closes the connection, the outgoing flow has to be completed as well, in the client's pipeline
// code; monitoring that this happens within a timeout after the close happens
Flow.usingEmit { emit =>
unsupervised {
forkUnsupervised {
// after the close frame is received from the client, waiting for the given grace period for the flow to
// complete. This will end this scope, and interrupt the `sleep`. If this doesn't happen, logging an error.
closeSignal.acquire()
sleep(outgoingCloseAfterCloseTimeout)
logger.error(
s"WebSocket outgoing messages flow either not drained, or not closed, " +
s"$outgoingCloseAfterCloseTimeout after receiving a close frame from the client! " +
s"Make sure to complete the outgoing flow in your pipeline, once the incoming " +
s"flow is done!"
)
}

outgoing.runToEmit(emit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import org.scalatest.funsuite.AsyncFunSuite
import org.scalatest.matchers.should.Matchers.*
import org.slf4j.LoggerFactory
import ox.*
import ox.channels.*
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.client3.*
Expand All @@ -21,11 +20,11 @@ import sttp.tapir.*
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.tests.*
import sttp.tapir.tests.*
import sttp.ws.{WebSocket, WebSocketFrame}

import java.util.concurrent.{CompletableFuture, TimeUnit}
import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import ox.flow.Flow
import scala.annotation.nowarn

class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {

Expand All @@ -43,46 +42,13 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {
.tests() ++
new ServerGracefulShutdownTests(createServerTest, sleeper).tests() ++
new ServerWebSocketTests(createServerTest, OxStreams, autoPing = true, failingPipe = true, handlePong = true) {
override def functionToPipe[A, B](f: A => B): OxStreams.Pipe[A, B] = ox ?=> in => in.map(f)
override def emptyPipe[A, B]: OxStreams.Pipe[A, B] = _ => Source.empty

import createServerTest._
override def tests(): List[Test] = super.tests() ++ List({
val released: CompletableFuture[Boolean] = new CompletableFuture[Boolean]()
testServer(
endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain].apply(streams)),
"closes supervision scope when client closes Web Socket"
)((_: Unit) =>
val pipe: OxStreams.Pipe[String, String] = in => {
releaseAfterScope {
released.complete(true).discard
}
in
}
Right(pipe)
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.close()
_ <- ws.receiveText()
closeResponse <- ws.eitherClose(ws.receiveText())
} yield closeResponse
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { r =>
r.body.value shouldBe Left(WebSocketFrame.Close(1000, "normal closure"))
released.get(15, TimeUnit.SECONDS) shouldBe true
}
}
})
override def functionToPipe[A, B](f: A => B): OxStreams.Pipe[A, B] = _.map(f)
override def emptyPipe[A, B]: OxStreams.Pipe[A, B] = _ => Flow.empty
}.tests()

tests.foreach { t =>
if (testNameFilter.forall(filter => t.name.contains(filter))) {
implicit val pos: Position = t.pos
@nowarn implicit val pos: Position = t.pos // used by test macro

this.test(t.name)(t.f())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ object NettySyncServerRunner {
val wsBaseEndpoint = endpoint.get.in("ws" / "ts")

val wsPipe: OxStreams.Pipe[Long, Long] = { in =>
fork {
in.drain()
}
Flow.tick(WebSocketSingleResponseLag).map(_ => System.currentTimeMillis()).runToChannel()
in.map(_ => -1L)
.merge(
Flow.tick(WebSocketSingleResponseLag).map(_ => System.currentTimeMillis()),
propagateDoneLeft = true
)
.collect { case n if n > 0 => n }
}

val wsEndpoint: Endpoint[Unit, Unit, Unit, OxStreams.Pipe[Long, Long], OxStreams with WebSockets] = wsBaseEndpoint
Expand Down

0 comments on commit fc45215

Please sign in to comment.