From d9422664ce16696b6947f2327f843926cebfdc4a Mon Sep 17 00:00:00 2001 From: Josiah Glosson Date: Sat, 27 Jul 2024 09:31:19 -0500 Subject: [PATCH] Clean up some networking code --- .../worldhost/protocol/ProtocolClient.java | 52 +++-- .../protocol/WorldHostC2SMessage.java | 78 +++++-- .../protocol/WorldHostS2CMessage.java | 203 +++++++++++++----- 3 files changed, 258 insertions(+), 75 deletions(-) diff --git a/src/main/java/io/github/gaming32/worldhost/protocol/ProtocolClient.java b/src/main/java/io/github/gaming32/worldhost/protocol/ProtocolClient.java index 66f4bc2..92d3298 100644 --- a/src/main/java/io/github/gaming32/worldhost/protocol/ProtocolClient.java +++ b/src/main/java/io/github/gaming32/worldhost/protocol/ProtocolClient.java @@ -14,7 +14,9 @@ import org.apache.commons.io.input.CountingInputStream; import org.jetbrains.annotations.Nullable; +import javax.crypto.Cipher; import javax.crypto.SecretKey; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; @@ -68,6 +70,8 @@ public ProtocolClient(String host, boolean successToast, boolean failureToast) { CONNECTION_THREAD_BUILDER.start(() -> { HostAndPort target = null; Socket socket = null; + Cipher decryptCipher = null; + Cipher encryptCipher = null; try { target = HostAndPort.fromString(host).withDefaultPort(9646); socket = new Socket(target.getHost(), target.getPort()); @@ -75,7 +79,9 @@ public ProtocolClient(String host, boolean successToast, boolean failureToast) { final User user = authUser.join(); authUser = null; - performHandshake(socket, user, connectionId); + final SecretKey secretKey = performHandshake(socket, user, connectionId); + decryptCipher = Crypt.getCipher(Cipher.DECRYPT_MODE, secretKey); + encryptCipher = Crypt.getCipher(Cipher.ENCRYPT_MODE, secretKey); } catch (Exception e) { WorldHost.LOGGER.error("Failed to connect to {} ({}).", originalHost, target, e); if (failureToast) { @@ -106,6 +112,8 @@ public ProtocolClient(String host, boolean successToast, boolean failureToast) { WHToast.builder("world-host.wh_connect.connected").show(); } final Socket fSocket = socket; + final Cipher fDecryptCipher = decryptCipher; + final Cipher fEncryptCipher = encryptCipher; final Thread sendThread = SEND_THREAD_BUILDER.start(() -> { try { @@ -113,12 +121,17 @@ public ProtocolClient(String host, boolean successToast, boolean failureToast) { final ByteArrayOutputStream baos = new ByteArrayOutputStream(); final DataOutputStream tempDos = new DataOutputStream(baos); while (!closed) { - final var message = sendQueue.take(); - if (message.isEmpty()) break; - message.get().encode(tempDos); - dos.writeInt(baos.size()); - dos.write(baos.toByteArray()); + final var optionalMessage = sendQueue.take(); + if (optionalMessage.isEmpty()) break; + final var message = optionalMessage.get(); + message.encode(tempDos); + final byte[] data = message.isEncrypted() + ? fEncryptCipher.update(baos.toByteArray()) + : baos.toByteArray(); baos.reset(); + dos.writeInt(data.length + 1); + dos.writeByte(message.typeId() & 0xff); + dos.write(data); dos.flush(); } } catch (IOException e) { @@ -133,18 +146,25 @@ public ProtocolClient(String host, boolean successToast, boolean failureToast) { try { final DataInputStream dis = new DataInputStream(fSocket.getInputStream()); while (!closed) { - final int length = dis.readInt(); - if (length < 1) { - WorldHost.LOGGER.warn("Received invalid short packet (under 1 byte) from WH server"); - dis.skipNBytes(length); + final int length = dis.readInt() - 1; + if (length < 0) { + WorldHost.LOGGER.warn("Received invalid empty packet from WH server"); continue; } - final BoundedInputStream bis = new BoundedInputStream(dis, length); - bis.setPropagateClose(false); - final CountingInputStream cis = new CountingInputStream(bis); + final int packetId = dis.readUnsignedByte(); + final CountingInputStream cis; + if (WorldHostS2CMessage.isEncrypted(packetId)) { + final byte[] data = dis.readNBytes(length); + final byte[] decrypted = fDecryptCipher.update(data); + cis = new CountingInputStream(new ByteArrayInputStream(decrypted)); + } else { + final BoundedInputStream bis = new BoundedInputStream(dis, length); + bis.setPropagateClose(false); + cis = new CountingInputStream(bis); + } WorldHostS2CMessage message = null; try { - message = WorldHostS2CMessage.decode(new DataInputStream(cis)); + message = WorldHostS2CMessage.decode(packetId, new DataInputStream(cis)); } catch (EOFException e) { WorldHost.LOGGER.error("Message decoder read past end (length {})!", length); } catch (Exception e) { @@ -192,7 +212,7 @@ public ProtocolClient(String host, boolean successToast, boolean failureToast) { }); } - private static void performHandshake( + private static SecretKey performHandshake( Socket socket, User user, long connectionId ) throws IOException, CryptException, AuthenticationException { final DataOutputStream dos = new DataOutputStream(socket.getOutputStream()); @@ -247,6 +267,8 @@ private static void performHandshake( WorldHostC2SMessage.writeString(dos, user.getName()); dos.writeLong(connectionId); dos.flush(); + + return secretKey; } public String getOriginalHost() { diff --git a/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostC2SMessage.java b/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostC2SMessage.java index 94ddd06..34abd78 100644 --- a/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostC2SMessage.java +++ b/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostC2SMessage.java @@ -12,9 +12,13 @@ // Mirrors https://github.com/Gaming32/world-host-server-kotlin/blob/main/src/main/kotlin/io/github/gaming32/worldhostserver/WorldHostC2SMessage.kt public sealed interface WorldHostC2SMessage { record ListOnline(Collection friends) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 0; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(0); dos.writeInt(friends.size()); for (final UUID friend : friends) { writeUuid(dos, friend); @@ -23,17 +27,25 @@ public void encode(DataOutputStream dos) throws IOException { } record FriendRequest(UUID toUser) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 1; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(1); writeUuid(dos, toUser); } } record PublishedWorld(Collection friends) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 2; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(2); dos.writeInt(friends.size()); for (final UUID friend : friends) { writeUuid(dos, friend); @@ -42,9 +54,13 @@ public void encode(DataOutputStream dos) throws IOException { } record ClosedWorld(Collection friends) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 3; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(3); dos.writeInt(friends.size()); for (final UUID friend : friends) { writeUuid(dos, friend); @@ -54,26 +70,38 @@ public void encode(DataOutputStream dos) throws IOException { @Deprecated record RequestJoin(UUID friend) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 4; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(4); writeUuid(dos, friend); } } record JoinGranted(long connectionId, JoinType joinType) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 5; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(5); dos.writeLong(connectionId); joinType.encode(dos); } } record QueryRequest(Collection friends) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 6; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(6); dos.writeInt(friends.size()); for (final UUID friend : friends) { writeUuid(dos, friend); @@ -83,9 +111,13 @@ public void encode(DataOutputStream dos) throws IOException { @Deprecated record QueryResponse(long connectionId, ServerStatus metadata) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 7; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(7); dos.writeLong(connectionId); final var buf = WorldHost.writeServerStatus(metadata); dos.writeInt(buf.readableBytes()); @@ -94,42 +126,64 @@ public void encode(DataOutputStream dos) throws IOException { } record ProxyS2CPacket(long connectionId, byte[] data) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 8; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(8); dos.writeLong(connectionId); dos.write(data); } } record ProxyDisconnect(long connectionId) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 9; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(9); dos.writeLong(connectionId); } } record RequestDirectJoin(long connectionId) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 10; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(10); dos.writeLong(connectionId); } } record NewQueryResponse(long connectionId, ServerStatus metadata) implements WorldHostC2SMessage { + @Override + public byte typeId() { + return 11; + } + @Override public void encode(DataOutputStream dos) throws IOException { - dos.writeByte(11); dos.writeLong(connectionId); final var buf = WorldHost.writeServerStatus(metadata); buf.readBytes(dos, buf.readableBytes()); } } + byte typeId(); + void encode(DataOutputStream dos) throws IOException; + default boolean isEncrypted() { + return false; + } + static void writeUuid(DataOutputStream dos, UUID uuid) throws IOException { dos.writeLong(uuid.getMostSignificantBits()); dos.writeLong(uuid.getLeastSignificantBits()); diff --git a/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostS2CMessage.java b/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostS2CMessage.java index 964be14..34330cc 100644 --- a/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostS2CMessage.java +++ b/src/main/java/io/github/gaming32/worldhost/protocol/WorldHostS2CMessage.java @@ -50,10 +50,16 @@ default boolean checkAndLogSecurity() { } record Error(String message, boolean critical) implements WorldHostS2CMessage { + public static final int ID = 0; + public Error(String message) { this(message, false); } + public static Error decode(DataInputStream dis) throws IOException { + return new Error(readString(dis), dis.read() > 0); // -1 means that there was no critical flag sent + } + @Override public void handle(ProtocolClient client) { if (critical) { @@ -68,6 +74,12 @@ public void handle(ProtocolClient client) { } record IsOnlineTo(UUID user) implements WorldHostS2CMessage { + public static final int ID = 1; + + public static IsOnlineTo decode(DataInputStream dis) throws IOException { + return new IsOnlineTo(readUuid(dis)); + } + @Override public void handle(ProtocolClient client) { if (WorldHost.isFriend(user)) { @@ -80,6 +92,12 @@ public void handle(ProtocolClient client) { } record OnlineGame(String host, int port, long ownerCid, boolean isPunchProtocol) implements WorldHostS2CMessage { + public static final int ID = 2; + + public static OnlineGame decode(DataInputStream dis) throws IOException { + return new OnlineGame(readString(dis), dis.readUnsignedShort(), dis.readLong(), dis.readBoolean()); + } + @Override public void handle(ProtocolClient client) { Minecraft.getInstance().execute(() -> { @@ -97,6 +115,12 @@ public void handle(ProtocolClient client) { } record FriendRequest(UUID fromUser, SecurityLevel security) implements WorldHostS2CMessage, SecurityCheckable { + public static final int ID = 3; + + public static FriendRequest decode(DataInputStream dis) throws IOException { + return new FriendRequest(readUuid(dis), SecurityLevel.byId(dis.readUnsignedByte())); + } + @Override public void handle(ProtocolClient client) { if (!WorldHost.CONFIG.isEnableFriends() || !checkAndLogSecurity()) return; @@ -131,6 +155,12 @@ public void handle(ProtocolClient client) { record PublishedWorld( UUID user, long connectionId, SecurityLevel security ) implements WorldHostS2CMessage, SecurityCheckable { + public static final int ID = 4; + + public static PublishedWorld decode(DataInputStream dis) throws IOException { + return new PublishedWorld(readUuid(dis), dis.readLong(), SecurityLevel.byId(dis.readUnsignedByte())); + } + @Override public void handle(ProtocolClient client) { if (!checkAndLogSecurity() || !WorldHost.isFriend(user)) return; @@ -141,6 +171,12 @@ public void handle(ProtocolClient client) { } record ClosedWorld(UUID user) implements WorldHostS2CMessage { + public static final int ID = 5; + + public static ClosedWorld decode(DataInputStream dis) throws IOException { + return new ClosedWorld(readUuid(dis)); + } + @Override public void handle(ProtocolClient client) { WorldHost.ONLINE_FRIENDS.remove(user); @@ -152,6 +188,12 @@ public void handle(ProtocolClient client) { record RequestJoin( UUID user, long connectionId, SecurityLevel security ) implements WorldHostS2CMessage, SecurityCheckable { + public static final int ID = 6; + + public static RequestJoin decode(DataInputStream dis) throws IOException { + return new RequestJoin(readUuid(dis), dis.readLong(), SecurityLevel.byId(dis.readUnsignedByte())); + } + @Override public void handle(ProtocolClient client) { if (!checkAndLogSecurity()) return; @@ -179,6 +221,12 @@ public void handle(ProtocolClient client) { record QueryRequest( UUID friend, long connectionId, SecurityLevel security ) implements WorldHostS2CMessage, SecurityCheckable { + public static final int ID = 7; + + public static QueryRequest decode(DataInputStream dis) throws IOException { + return new QueryRequest(readUuid(dis), dis.readLong(), SecurityLevel.byId(dis.readUnsignedByte())); + } + @Override public void handle(ProtocolClient client) { if (!checkAndLogSecurity() || !WorldHost.isFriend(friend)) return; @@ -190,6 +238,22 @@ public void handle(ProtocolClient client) { } record QueryResponse(UUID friend, ServerStatus metadata) implements WorldHostS2CMessage { + public static final int ID = 8; + + public static QueryResponse decode(DataInputStream dis) throws IOException { + final UUID friend = readUuid(dis); + final var buf = WorldHost.createByteBuf(); + buf.writeBytes(dis, dis.readInt()); + ServerStatus serverStatus; + try { + serverStatus = WorldHost.parseServerStatus(buf); + } catch (Exception e) { + WorldHost.LOGGER.error("Failed to parse server status", e); + serverStatus = WorldHost.createEmptyServerStatus(); + } + return new QueryResponse(friend, serverStatus); + } + @Override public void handle(ProtocolClient client) { if (WorldHost.isFriend(friend)) { @@ -199,6 +263,12 @@ public void handle(ProtocolClient client) { } record ProxyC2SPacket(long connectionId, byte[] data) implements WorldHostS2CMessage { + public static final int ID = 9; + + public static ProxyC2SPacket decode(DataInputStream dis) throws IOException { + return new ProxyC2SPacket(dis.readLong(), dis.readAllBytes()); + } + @Override public void handle(ProtocolClient client) { WorldHost.proxyPacket(connectionId, data); @@ -206,6 +276,12 @@ public void handle(ProtocolClient client) { } record ProxyConnect(long connectionId, InetAddress remoteAddr) implements WorldHostS2CMessage { + public static final int ID = 10; + + public static ProxyConnect decode(DataInputStream dis) throws IOException { + return new ProxyConnect(dis.readLong(), InetAddress.getByAddress(dis.readNBytes(dis.readUnsignedByte()))); + } + @Override public void handle(ProtocolClient client) { WorldHost.proxyConnect(connectionId, remoteAddr, () -> WorldHost.protoClient); @@ -213,6 +289,12 @@ public void handle(ProtocolClient client) { } record ProxyDisconnect(long connectionId) implements WorldHostS2CMessage { + public static final int ID = 11; + + public static ProxyDisconnect decode(DataInputStream dis) throws IOException { + return new ProxyDisconnect(dis.readLong()); + } + @Override public void handle(ProtocolClient client) { WorldHost.proxyDisconnect(connectionId); @@ -222,6 +304,14 @@ public void handle(ProtocolClient client) { record ConnectionInfo( long connectionId, String baseIp, int basePort, String userIp, int protocolVersion, int punchPort ) implements WorldHostS2CMessage { + public static final int ID = 12; + + public static ConnectionInfo decode(DataInputStream dis) throws IOException { + return new ConnectionInfo( + dis.readLong(), readString(dis), dis.readUnsignedShort(), readString(dis), dis.readInt(), dis.readUnsignedShort() + ); + } + @Override public void handle(ProtocolClient client) { WorldHost.LOGGER.info("Received {}", this); @@ -241,6 +331,14 @@ public void handle(ProtocolClient client) { } record ExternalProxyServer(String host, int port, String baseAddr, int mcPort) implements WorldHostS2CMessage { + public static final int ID = 13; + + public static ExternalProxyServer decode(DataInputStream dis) throws IOException { + return new ExternalProxyServer( + readString(dis), dis.readUnsignedShort(), readString(dis), dis.readUnsignedShort() + ); + } + @Override public void handle(ProtocolClient client) { WorldHost.LOGGER.info("Attempting to connect to WHEP server at {}, {}", host, port); @@ -252,6 +350,12 @@ public void handle(ProtocolClient client) { } record OutdatedWorldHost(String recommendedVersion) implements WorldHostS2CMessage { + public static final int ID = 14; + + public static OutdatedWorldHost decode(DataInputStream dis) throws IOException { + return new OutdatedWorldHost(readString(dis)); + } + @Override public void handle(ProtocolClient client) { final String currentVersion = WorldHost.getModVersion(WorldHost.MOD_ID); @@ -276,6 +380,12 @@ public void handle(ProtocolClient client) { } record ConnectionNotFound(long connectionId) implements WorldHostS2CMessage { + public static final int ID = 15; + + public static ConnectionNotFound decode(DataInputStream dis) throws IOException { + return new ConnectionNotFound(dis.readLong()); + } + @Override public void handle(ProtocolClient client) { Minecraft.getInstance().execute(() -> { @@ -296,6 +406,22 @@ public void handle(ProtocolClient client) { } record NewQueryResponse(UUID friend, ServerStatus metadata) implements WorldHostS2CMessage { + public static final int ID = 16; + + public static NewQueryResponse decode(DataInputStream dis) throws IOException { + final UUID friend = readUuid(dis); + final var buf = WorldHost.createByteBuf(); + buf.writeBytes(dis.readAllBytes()); + ServerStatus serverStatus; + try { + serverStatus = WorldHost.parseServerStatus(buf); + } catch (Exception e) { + WorldHost.LOGGER.error("Failed to parse server status", e); + serverStatus = WorldHost.createEmptyServerStatus(); + } + return new NewQueryResponse(friend, serverStatus); + } + @Override public void handle(ProtocolClient client) { if (WorldHost.isFriend(friend)) { @@ -305,6 +431,12 @@ public void handle(ProtocolClient client) { } record Warning(String message, boolean important) implements WorldHostS2CMessage { + public static final int ID = 17; + + public static Warning decode(DataInputStream dis) throws IOException { + return new Warning(readString(dis), dis.readBoolean()); + } + @Override public void handle(ProtocolClient client) { WorldHost.LOGGER.warn("Warning from WH server (important: {}): {}", important, message); @@ -325,55 +457,30 @@ public void handle(ProtocolClient client) { */ void handle(ProtocolClient client); - static WorldHostS2CMessage decode(DataInputStream dis) throws IOException { - final int typeId = dis.readUnsignedByte(); + static boolean isEncrypted(int typeId) { + return false; + } + + static WorldHostS2CMessage decode(int typeId, DataInputStream dis) throws IOException { return switch (typeId) { - case 0 -> new Error(readString(dis), dis.read() > 0); // -1 means that there was no critical flag sent - case 1 -> new IsOnlineTo(readUuid(dis)); - case 2 -> new OnlineGame(readString(dis), dis.readUnsignedShort(), dis.readLong(), dis.readBoolean()); - case 3 -> new FriendRequest(readUuid(dis), SecurityLevel.byId(dis.readUnsignedByte())); - case 4 -> new PublishedWorld(readUuid(dis), dis.readLong(), SecurityLevel.byId(dis.readUnsignedByte())); - case 5 -> new ClosedWorld(readUuid(dis)); - case 6 -> new RequestJoin(readUuid(dis), dis.readLong(), SecurityLevel.byId(dis.readUnsignedByte())); - case 7 -> new QueryRequest(readUuid(dis), dis.readLong(), SecurityLevel.byId(dis.readUnsignedByte())); - case 8 -> { - final UUID friend = readUuid(dis); - final var buf = WorldHost.createByteBuf(); - buf.writeBytes(dis, dis.readInt()); - ServerStatus serverStatus; - try { - serverStatus = WorldHost.parseServerStatus(buf); - } catch (Exception e) { - WorldHost.LOGGER.error("Failed to parse server status", e); - serverStatus = WorldHost.createEmptyServerStatus(); - } - yield new QueryResponse(friend, serverStatus); - } - case 9 -> new ProxyC2SPacket(dis.readLong(), dis.readAllBytes()); - case 10 -> new ProxyConnect(dis.readLong(), InetAddress.getByAddress(dis.readNBytes(dis.readUnsignedByte()))); - case 11 -> new ProxyDisconnect(dis.readLong()); - case 12 -> new ConnectionInfo( - dis.readLong(), readString(dis), dis.readUnsignedShort(), readString(dis), dis.readInt(), dis.readUnsignedShort() - ); - case 13 -> new ExternalProxyServer( - readString(dis), dis.readUnsignedShort(), readString(dis), dis.readUnsignedShort() - ); - case 14 -> new OutdatedWorldHost(readString(dis)); - case 15 -> new ConnectionNotFound(dis.readLong()); - case 16 -> { - final UUID friend = readUuid(dis); - final var buf = WorldHost.createByteBuf(); - buf.writeBytes(dis.readAllBytes()); - ServerStatus serverStatus; - try { - serverStatus = WorldHost.parseServerStatus(buf); - } catch (Exception e) { - WorldHost.LOGGER.error("Failed to parse server status", e); - serverStatus = WorldHost.createEmptyServerStatus(); - } - yield new NewQueryResponse(friend, serverStatus); - } - case 17 -> new Warning(readString(dis), dis.readBoolean()); + case Error.ID -> Error.decode(dis); + case IsOnlineTo.ID -> IsOnlineTo.decode(dis); + case OnlineGame.ID -> OnlineGame.decode(dis); + case FriendRequest.ID -> FriendRequest.decode(dis); + case PublishedWorld.ID -> PublishedWorld.decode(dis); + case ClosedWorld.ID -> ClosedWorld.decode(dis); + case RequestJoin.ID -> RequestJoin.decode(dis); + case QueryRequest.ID -> QueryRequest.decode(dis); + case QueryResponse.ID -> QueryResponse.decode(dis); + case ProxyC2SPacket.ID -> ProxyC2SPacket.decode(dis); + case ProxyConnect.ID -> ProxyConnect.decode(dis); + case ProxyDisconnect.ID -> ProxyDisconnect.decode(dis); + case ConnectionInfo.ID -> ConnectionInfo.decode(dis); + case ExternalProxyServer.ID -> ExternalProxyServer.decode(dis); + case OutdatedWorldHost.ID -> OutdatedWorldHost.decode(dis); + case ConnectionNotFound.ID -> ConnectionNotFound.decode(dis); + case NewQueryResponse.ID -> NewQueryResponse.decode(dis); + case Warning.ID -> Warning.decode(dis); default -> new Error("Received packet with unknown typeId from server (outdated client?): " + typeId); }; }