Skip to content

Commit

Permalink
Yamux
Browse files Browse the repository at this point in the history
  • Loading branch information
erwin-kok committed Jan 29, 2024
1 parent 66b0d91 commit ec7a912
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ import org.erwinkok.result.Result
import org.erwinkok.result.errorMessage
import org.erwinkok.result.flatMap
import org.erwinkok.result.getOrElse
import org.erwinkok.result.map
import org.erwinkok.result.onFailure
import org.erwinkok.result.onSuccess
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicLong
import kotlin.coroutines.cancellation.CancellationException
import kotlin.experimental.and
import kotlin.system.measureNanoTime
import kotlin.time.Duration
import kotlin.time.Duration.Companion.nanoseconds
Expand Down Expand Up @@ -272,7 +272,7 @@ class Session(
}


private fun goAway(goAwayProtoErr: Int): YamuxHeader {
private fun goAway(goAwayProtoErr: GoAwayType): YamuxHeader {
TODO("Not yet implemented")
}

Expand Down Expand Up @@ -328,10 +328,9 @@ class Session(
.flatMap { header ->
extendKeepalive()
when (header.type) {
YamuxConst.typeData, YamuxConst.typeWindowUpdate -> handleStreamMessage(header)
YamuxConst.typePing -> handlePing(header)
YamuxConst.typeGoAway -> handleGoAway(header)
else -> Err(YamuxConst.errInvalidMsgType)
FrameType.TypeData, FrameType.TypeWindowUpdate -> handleStreamMessage(header)
FrameType.TypePing -> handlePing(header)
FrameType.TypeGoAway -> handleGoAway(header)
}
}
.onFailure {
Expand All @@ -354,29 +353,29 @@ class Session(
private suspend fun handleStreamMessage(header: YamuxHeader): Result<Unit> {
val id = header.streamId
val flags = header.flags
if ((flags and YamuxConst.flagSyn) == YamuxConst.flagSyn) {
if (flags.hasFlag(Flag.flagSyn)) {
incomingStream(id)
.onFailure { return Err(it) }
}
val stream = streamLock.withLock {
streams[id]
}
if (stream == null) {
if (header.type == YamuxConst.typeData && header.length > 0) {
if (header.type == FrameType.TypeData && header.length > 0) {
logger.warn { "[WARN] yamux: Discarding data for stream: $id" }
connection.input.readPacket(header.length).close()
} else {
logger.warn { "[WARN] yamux: frame for missing stream: $id" }
}
return Ok(Unit)
}
if (header.type == YamuxConst.typeWindowUpdate) {
if (header.type == FrameType.TypeWindowUpdate) {
stream.increaseSendWindow(header, flags)
return Ok(Unit)
}
stream.readData(header, flags, connection.input)
.onFailure {
sendMessage(goAway(YamuxConst.goAwayProtoErr))
sendMessage(goAway(GoAwayType.GoAwayProtoError))
return Err(it)
}
return Ok(Unit)
Expand All @@ -385,7 +384,7 @@ class Session(
private suspend fun handlePing(header: YamuxHeader): Result<Unit> {
val flags = header.flags
val pingId = header.length
if ((flags and YamuxConst.flagSyn) == YamuxConst.flagSyn) {
if (flags.hasFlag(Flag.flagSyn)) {
pongChannel.trySend(pingId)
.onFailure {
logger.warn { "[WARN] yamux: dropped ping reply" }
Expand All @@ -402,28 +401,25 @@ class Session(
}

private fun handleGoAway(header: YamuxHeader): Result<Unit> {
val code = header.length
when (code) {
YamuxConst.goAwayNormal -> {
remoteGoAway.set(true)
}

YamuxConst.goAwayProtoErr -> {
logger.error { "[ERR] yamux: received protocol error go away" }
return Err("yamux protocol error")
}
return GoAwayType.fromInt(header.length)
.map { code ->
when (code) {
GoAwayType.GoAwayNormal -> {
remoteGoAway.set(true)
Ok(Unit)
}

YamuxConst.goAwayInternalErr -> {
logger.error { "[ERR] yamux: received internal error go away" }
return Err("remote yamux internal error")
}
GoAwayType.GoAwayProtoError -> {
logger.error { "[ERR] yamux: received protocol error go away" }
Err("yamux protocol error")
}

else -> {
logger.error { "[ERR] yamux: received unexpected go away" }
return Err("unexpected go away received")
GoAwayType.GoAwayInternalError -> {
logger.error { "[ERR] yamux: received internal error go away" }
Err("remote yamux internal error")
}
}
}
}
return Ok(Unit)
}

private suspend fun incomingStream(id: Int): Result<Unit> {
Expand All @@ -433,7 +429,7 @@ class Session(
// }
// Reject immediately if we are doing a go away
if (localGoAway.get()) {
return sendMessage(YamuxHeader(YamuxConst.typeWindowUpdate, YamuxConst.flagRst, id, 0))
return sendMessage(YamuxHeader(FrameType.TypeWindowUpdate, Flags.of(Flag.flagRst), id, 0))
}
// Allocate a new stream
val mm = memoryManager
Expand All @@ -454,7 +450,7 @@ class Session(
streamLock.withLock {
if (streams.contains(id)) {
logger.error { "[ERR] yamux: duplicate stream declared" }
sendMessage(goAway(YamuxConst.goAwayProtoErr))
sendMessage(goAway(GoAwayType.GoAwayProtoError))
.onFailure {
logger.warn { "[WARN] yamux: failed to send go away: ${errorMessage(it)}" }
}
Expand All @@ -465,7 +461,7 @@ class Session(
if (numIncomingStreams >= config.maxIncomingStreams) {
// too many active streams at the same time
logger.warn { "[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset" }
val result = sendMessage(YamuxHeader(YamuxConst.typeWindowUpdate, YamuxConst.flagRst, id, 0))
val result = sendMessage(YamuxHeader(FrameType.TypeWindowUpdate, Flags.of(Flag.flagRst), id, 0))
span?.done()
return result
}
Expand All @@ -477,7 +473,7 @@ class Session(
// Backlog exceeded! RST the stream
logger.warn { "[WARN] yamux: backlog exceeded, forcing stream reset" }
deleteStream(id)
val result = sendMessage(YamuxHeader(YamuxConst.typeWindowUpdate, YamuxConst.flagRst, id, 0))
val result = sendMessage(YamuxHeader(FrameType.TypeWindowUpdate, Flags.of(Flag.flagRst), id, 0))
span?.done()
return result
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
// Copyright (c) 2024 Erwin Kok. BSD-3-Clause license. See LICENSE file for more details.
package org.erwinkok.libp2p.muxer.yamux

import org.erwinkok.libp2p.muxer.yamux.YamuxConst.errInvalidFrameType
import org.erwinkok.libp2p.muxer.yamux.YamuxConst.errInvalidGoAwayType
import org.erwinkok.result.Err
import org.erwinkok.result.Error
import org.erwinkok.result.Ok
import org.erwinkok.result.Result
import kotlin.experimental.and
import kotlin.experimental.or

object YamuxConst {
val errInvalidVersion = Error("Invalid protocol version")
val errInvalidMsgType = Error("Invalid message type")
val errInvalidFrameType = Error("Invalid frame type")
val errInvalidGoAwayType = Error("Invalid GoAway type")
val errSessionShutdown = Error("Session shutdown")
val errStreamsExhausted = Error("streams exhausted")
val errDuplicateStream = Error("duplicate stream initiated")
Expand All @@ -20,27 +28,67 @@ object YamuxConst {

const val protoVersion = 0.toByte()

const val typeData = 0.toByte() // Data is used for data frames. They are followed by length bytes worth of payload.
const val typeWindowUpdate = 1.toByte() // WindowUpdate is used to change the window of a given stream. The length indicates the delta update to the window.
const val typePing = 2.toByte() // Ping is sent as a keep-alive or to measure the RTT. The StreamID and Length value are echoed back in the response.
const val typeGoAway = 3.toByte() // GoAway is sent to terminate a session. The StreamID should be 0 and the length is an error code.

const val flagSyn = 1.toShort() // SYN is sent to signal a new stream. May be sent with a data payload
const val flagAck = 2.toShort() // ACK is sent to acknowledge a new stream. May be sent with a data payload
const val flagFin = 4.toShort() // FIN is sent to half-close the given stream. May be sent with a data payload.
const val flagRst = 8.toShort() // RST is used to hard close a given stream.

// initialStreamWindow is the initial stream window size.
// It's not an implementation choice, the value defined in the specification.
const val initialStreamWindow: Int = 256 * 1024
const val maxStreamWindow: Int = 16 * 1024 * 1024
}

enum class FrameType(val code: Byte) {
TypeData(0), // Data is used for data frames. They are followed by length bytes worth of payload.
TypeWindowUpdate(1), // WindowUpdate is used to change the window of a given stream. The length indicates the delta update to the window.
TypePing(2), // Ping is sent as a keep-alive or to measure the RTT. The StreamID and Length value are echoed back in the response.
TypeGoAway(3) // GoAway is sent to terminate a session. The StreamID should be 0 and the length is an error code.
;

companion object {
private val intToTypeMap = entries.associateBy { it.code }

fun fromInt(value: Byte): Result<FrameType> {
val type = intToTypeMap[value] ?: return Err(errInvalidFrameType)
return Ok(type)
}
}
}

enum class GoAwayType(val code: Int) {
GoAwayNormal(0), // goAwayNormal is sent on a normal termination
GoAwayProtoError(1), // goAwayProtoErr sent on a protocol error
GoAwayInternalError(2), // goAwayInternalErr sent on an internal error
;

companion object {
private val intToTypeMap = GoAwayType.entries.associateBy { it.code }

fun fromInt(value: Int): Result<GoAwayType> {
val type = intToTypeMap[value] ?: return Err(errInvalidGoAwayType)
return Ok(type)
}
}
}

enum class Flag(val code: Short) {
flagSyn(1), // SYN is sent to signal a new stream. May be sent with a data payload
flagAck(2), // ACK is sent to acknowledge a new stream. May be sent with a data payload
flagFin(4), // FIN is sent to half-close the given stream. May be sent with a data payload.
flagRst(8), // RST is used to hard close a given stream.
}

class Flags(val code: Short = 0) {
fun hasFlag(flag: Flag): Boolean {
return ((code and flag.code) == flag.code)
}

// goAwayNormal is sent on a normal termination
const val goAwayNormal = 0
companion object {
private val intToTypeMap = GoAwayType.entries.associateBy { it.code }

// goAwayProtoErr sent on a protocol error
const val goAwayProtoErr = 1
fun fromShort(value: Short): Result<Flags> {
return Ok(Flags(value))
}

// goAwayInternalErr sent on an internal error
const val goAwayInternalErr = 2
fun of(vararg values: Flag): Flags {
val code = values.fold(0.toShort()) { code, flag -> code or flag.code }
return Flags(code)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import org.erwinkok.result.Result

private val logger = KotlinLogging.logger {}

class YamuxHeader(val type: Byte, val flags: Short, val streamId: Int, val length: Int) {
class YamuxHeader(val type: FrameType, val flags: Flags, val streamId: Int, val length: Int) {
override fun toString(): String {
return "type=$type, flags: $flags, id=$streamId, length=$length"
}
Expand All @@ -28,16 +28,18 @@ internal suspend fun ByteReadChannel.readYamuxHeader(): Result<YamuxHeader> {
logger.error { "yamux: Invalid protocol version $version" }
return Err(YamuxConst.errInvalidVersion)
}
if (type < YamuxConst.typeData || type > YamuxConst.typeGoAway) {
return Err(YamuxConst.errInvalidMsgType)
return Result.zip(
{ FrameType.fromInt(type) },
{ Flags.fromShort(flags) },
) { t, f ->
Ok(YamuxHeader(t, f, streamId, length))
}
return Ok(YamuxHeader(type, flags, streamId, length))
}

internal fun BytePacketBuilder.writeYamuxHeader(header: YamuxHeader) {
writeByte(YamuxConst.protoVersion)
writeByte(header.type)
writeShort(header.flags)
writeByte(header.type.code)
writeShort(header.flags.code)
writeInt(header.streamId)
writeInt(header.length)
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import org.erwinkok.result.errorMessage
import java.io.IOException
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import kotlin.experimental.and

enum class StreamState {
StreamInit,
Expand Down Expand Up @@ -226,10 +225,10 @@ class YamuxMuxedStream(
session.closeStream(streamId)
}

private suspend fun processFlags(flags: Short) {
private suspend fun processFlags(flags: Flags) {
var closeStream = false

if ((flags and YamuxConst.flagAck) == YamuxConst.flagAck) {
if (flags.hasFlag(Flag.flagAck)) {
stateLock.withLock {
if (state == StreamState.StreamSYNSent) {
state = StreamState.StreamEstablished
Expand All @@ -238,7 +237,7 @@ class YamuxMuxedStream(
session.establishStream(streamId)
}

if ((flags and YamuxConst.flagFin) == YamuxConst.flagFin) {
if (flags.hasFlag(Flag.flagFin)) {
var notify = false
stateLock.withLock {
if (readState == HalfStreamState.HalfOpen) {
Expand All @@ -256,7 +255,7 @@ class YamuxMuxedStream(
}
}

if ((flags and YamuxConst.flagRst) == YamuxConst.flagRst) {
if (flags.hasFlag(Flag.flagRst)) {
stateLock.withLock {
if (readState == HalfStreamState.HalfOpen) {
readState = HalfStreamState.HalfReset
Expand All @@ -280,14 +279,14 @@ class YamuxMuxedStream(
// asyncNotify(s.sendNotifyCh)
}

internal suspend fun increaseSendWindow(header: YamuxHeader, flags: Short) {
internal suspend fun increaseSendWindow(header: YamuxHeader, flags: Flags) {
processFlags(flags)
// Increase window, unblock a sender
sendWindow.addAndGet(header.length)
// asyncNotify(s.sendNotifyCh)
}

internal suspend fun readData(header: YamuxHeader, flags: Short, input: ByteReadChannel): Result<Unit> {
internal suspend fun readData(header: YamuxHeader, flags: Flags, input: ByteReadChannel): Result<Unit> {
processFlags(flags)

// Check that our recv window is not exceeded
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ internal class YamuxMultiplexerTest : TestWithLeakCheck {
// }
//

private fun createPacket(type: Byte, flags: Short, streamId: Int, body: ByteArray? = null): ByteReadPacket {
private fun createPacket(type: FrameType, flags: Flags, streamId: Int, body: ByteArray? = null): ByteReadPacket {
return buildPacket(pool) {
val size = body?.size ?: 0
writeYamuxHeader(YamuxHeader(type, flags, streamId, size))
Expand Down

0 comments on commit ec7a912

Please sign in to comment.