Skip to content

Commit

Permalink
Update to Ox 0.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Oct 7, 2024
1 parent 54c1d8a commit 854eb3e
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,38 @@ package sttp.tapir.examples.websocket

import ox.*
import ox.channels.*
import ox.flow.Flow
import sttp.capabilities.WebSockets
import sttp.tapir.*
import sttp.tapir.server.netty.sync.OxStreams
import sttp.tapir.server.netty.sync.OxStreams.Pipe
import sttp.tapir.server.netty.sync.NettySyncServer
import sttp.ws.WebSocketFrame

import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.duration.*

object WebSocketNettySyncServer:
// Web socket endpoint
val wsEndpoint =
endpoint.get
.in("ws")
.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](OxStreams)
.concatenateFragmentedFrames(false) // All these options are supported by tapir-netty
.ignorePong(true)
.autoPongOnPing(true)
.decodeCloseRequests(false)
.decodeCloseResponses(false)
.autoPing(Some((10.seconds, WebSocketFrame.Ping("ping-content".getBytes))))
)
.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](OxStreams))

// Your processor transforming a stream of requests into a stream of responses
val wsPipe: Pipe[String, String] = requestStream => requestStream.map(_.toUpperCase)

// Alternative logic (not used here): requests and responses can be treated separately, for example to emit frames
// to the client from another source.
val wsPipe2: Pipe[String, String] = { in =>
val running = new AtomicBoolean(true) // TODO use https://github.com/softwaremill/ox/issues/209 once available
fork {
in.drain() // read and ignore requests
running.set(false) // stopping the responses
}
val flowLeft: Flow[Either[String, String]] = Flow.fromSource(in).map(Left(_))
// emit periodic responses
Source.tick(1.second).takeWhile(_ => running.get()).map(_ => System.currentTimeMillis()).map(_.toString)
val flowRight: Flow[Either[String, String]] = Flow.tick(1.second).map(_ => System.currentTimeMillis()).map(_.toString).map(Right(_))

// ignore whatever is sent by the client (represented as `Left`)
flowLeft.merge(flowRight, propagateDoneLeft = true).collect { case Right(s) => s }.runToChannel()
}

// The WebSocket endpoint, builds the pipeline in serverLogicSuccess
val wsServerEndpoint = wsEndpoint.handleSuccess(_ => wsPipe2)
val wsServerEndpoint = wsEndpoint.handleSuccess(_ => wsPipe)

// A regular /GET endpoint
val helloWorldEndpoint =
Expand Down
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object Versions {
val json4s = "4.0.7"
val metrics4Scala = "4.3.2"
val nettyReactiveStreams = "3.0.2"
val ox = "0.4.0"
val ox = "0.5.0"
val reactiveStreams = "1.0.4"
val sprayJson = "1.3.6"
val scalaCheck = "1.18.1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private[sync] class OxProcessor[A, B](
if (pipelineForkFuture != null) try {
val pipelineFork = Await.result(pipelineForkFuture, pipelineCancelationTimeout)
oxDispatcher.runAsync {
race(
raceSuccess(
{
ox.sleep(pipelineCancelationTimeout)
logger.error(s"Pipeline fork cancelation did not complete in time ($pipelineCancelationTimeout).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import java.io.IOException
import java.util.concurrent.Semaphore

import scala.concurrent.duration.*
import ox.flow.Flow

private[sync] object OxSourceWebSocketProcessor:
private val logger = LoggerFactory.getLogger(getClass.getName)
Expand All @@ -37,15 +38,17 @@ private[sync] object OxSourceWebSocketProcessor:
val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] = ox ?=>
val closeSignal = new Semaphore(0)
(incoming: Source[NettyWebSocketFrame]) =>
val outgoing = incoming
.mapAsView { f =>
val outgoing = Flow
.fromSource(incoming)
.map { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
}
.pipe(takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests, closeSignal))
.pipe(optionallyConcatenateFrames(o.concatenateFragmentedFrames))
.mapAsView(decodeFrame)
.map(decodeFrame)
.runToChannel()
.pipe(processingPipe)
.mapAsView(r => frameToNettyFrame(o.responses.encode(r)))

Expand Down Expand Up @@ -76,14 +79,12 @@ private[sync] object OxSourceWebSocketProcessor:
new OxProcessor(oxDispatcher, frame2FramePipe, wrapSubscriberWithNettyCallback)
end apply

private def optionallyConcatenateFrames(doConcatenate: Boolean)(s: Source[WebSocketFrame])(using Ox): Source[WebSocketFrame] =
if doConcatenate then s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f }
else s
private def optionallyConcatenateFrames(doConcatenate: Boolean)(f: Flow[WebSocketFrame]): Flow[WebSocketFrame] =
if doConcatenate then f.mapStateful(() => None: Accumulator)(accumulateFrameState).collect { case Some(f: WebSocketFrame) => f }
else f

private def takeUntilCloseFrame(passAlongCloseFrame: Boolean, closeSignal: Semaphore)(
s: Source[WebSocketFrame]
)(using Ox): Source[WebSocketFrame] =
s.takeWhile(
private def takeUntilCloseFrame(passAlongCloseFrame: Boolean, closeSignal: Semaphore)(f: Flow[WebSocketFrame]): Flow[WebSocketFrame] =
f.takeWhile(
{
case _: WebSocketFrame.Close => closeSignal.release(); false
case _ => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sttp.tapir.server.netty.sync.perf

import ox.*
import ox.channels.*
import ox.flow.Flow
import sttp.shared.Identity
import sttp.tapir.server.netty.sync.NettySyncServerOptions
import sttp.tapir.server.netty.sync.NettySyncServerBinding
Expand Down Expand Up @@ -61,7 +62,7 @@ object NettySyncServerRunner {
fork {
in.drain()
}
Source.tick(WebSocketSingleResponseLag).map(_ => System.currentTimeMillis())
Flow.tick(WebSocketSingleResponseLag).map(_ => System.currentTimeMillis()).runToChannel()
}

val wsEndpoint: Endpoint[Unit, Unit, Unit, OxStreams.Pipe[Long, Long], OxStreams with WebSockets] = wsBaseEndpoint
Expand Down

0 comments on commit 854eb3e

Please sign in to comment.