Skip to content

Commit

Permalink
Expose lavalink sessions for plugins (#1042)
Browse files Browse the repository at this point in the history
  • Loading branch information
topi314 authored Apr 28, 2024
1 parent 0736cc4 commit 78c090c
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
17 changes: 9 additions & 8 deletions LavalinkServer/src/main/java/lavalink/server/io/SocketServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package lavalink.server.io

import com.sedmelluq.discord.lavaplayer.player.AudioPlayerManager
import dev.arbjerg.lavalink.api.AudioPluginInfoModifier
import dev.arbjerg.lavalink.api.ISocketServer
import dev.arbjerg.lavalink.api.PluginEventHandler
import dev.arbjerg.lavalink.protocol.v4.Message
import dev.arbjerg.lavalink.protocol.v4.PlayerState
Expand All @@ -46,11 +47,11 @@ final class SocketServer(
koeOptions: KoeOptions,
private val eventHandlers: List<PluginEventHandler>,
private val pluginInfoModifiers: List<AudioPluginInfoModifier>
) : TextWebSocketHandler() {
) : TextWebSocketHandler(), ISocketServer {

// sessionID <-> Session
val contextMap = ConcurrentHashMap<String, SocketContext>()
private val resumableSessions = mutableMapOf<String, SocketContext>()
override val sessions = ConcurrentHashMap<String, SocketContext>()
override val resumableSessions = mutableMapOf<String, SocketContext>()
private val koe = Koe.koe(koeOptions)
private val statsCollector = StatsCollector(this)
private val charPool = ('a'..'z') + ('0'..'9')
Expand Down Expand Up @@ -81,12 +82,12 @@ final class SocketServer(
var sessionId: String
do {
sessionId = List(16) { charPool.random() }.joinToString("")
} while (contextMap[sessionId] != null)
} while (sessions[sessionId] != null)
return sessionId
}

val contexts: Collection<SocketContext>
get() = contextMap.values
get() = sessions.values

@Suppress("UastIncorrectHttpHeaderInspection")
override fun afterConnectionEstablished(session: WebSocketSession) {
Expand All @@ -100,7 +101,7 @@ final class SocketServer(

if (resumable != null) {
session.attributes["sessionId"] = resumable.sessionId
contextMap[resumable.sessionId] = resumable
sessions[resumable.sessionId] = resumable
resumable.resume(session)
log.info("Resumed session with id $sessionId")
resumable.eventEmitter.onWebSocketOpen(true)
Expand All @@ -123,7 +124,7 @@ final class SocketServer(
eventHandlers,
pluginInfoModifiers
)
contextMap[sessionId] = socketContext
sessions[sessionId] = socketContext
socketContext.sendMessage(Message.Serializer, Message.ReadyEvent(false, sessionId))
socketContext.eventEmitter.onWebSocketOpen(false)
if (clientName != null) {
Expand All @@ -140,7 +141,7 @@ final class SocketServer(
}

override fun afterConnectionClosed(session: WebSocketSession, status: CloseStatus) {
val context = contextMap.remove(session.attributes["sessionId"]) ?: return
val context = sessions.remove(session.attributes["sessionId"]) ?: return
if (context.resumable) {
resumableSessions.remove(context.sessionId)?.let { removed ->
log.warn(
Expand Down
2 changes: 1 addition & 1 deletion LavalinkServer/src/main/java/lavalink/server/util/util.kt
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ fun getRootCause(throwable: Throwable?): Throwable {
}

fun socketContext(socketServer: SocketServer, sessionId: String) =
socketServer.contextMap[sessionId] ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Session not found")
socketServer.sessions[sessionId] ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Session not found")

fun existingPlayer(socketContext: SocketContext, guildId: Long) =
socketContext.players[guildId] ?: throw ResponseStatusException(HttpStatus.NOT_FOUND, "Player not found")
Expand Down
17 changes: 17 additions & 0 deletions plugin-api/src/main/java/dev/arbjerg/lavalink/api/ISocketServer.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dev.arbjerg.lavalink.api

/**
* Represents a Lavalink server which handles WebSocket connections.
*/
interface ISocketServer {
/**
* A map of all active sessions by their session id.
*/
val sessions: Map<String, ISocketContext>

/**
* A map of all resumable sessions by their session id.
* A session is resumable if the client configured resuming and has disconnected.
*/
val resumableSessions: Map<String, ISocketContext>
}

0 comments on commit 78c090c

Please sign in to comment.