Skip to content

Commit

Permalink
Network addon session refactor (FabricMC#3394)
Browse files Browse the repository at this point in the history
* refactor network addon session handling

* Check payload size

* Fix ClientLoginNetworkAddon does not handle unsuccessful query responses
Closes FabricMC#3384

* Adjust some logging.

---------

Co-authored-by: deirn <[email protected]>
(cherry picked from commit bff13c8)
  • Loading branch information
modmuss committed Nov 2, 2023
1 parent 5aa6d9a commit 8ed13ef
Show file tree
Hide file tree
Showing 20 changed files with 130 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.util.Collections;
import java.util.List;
import java.util.Map;

import net.minecraft.client.MinecraftClient;
import net.minecraft.client.network.ClientConfigurationNetworkHandler;
Expand Down Expand Up @@ -51,17 +50,10 @@ public ClientConfigurationNetworkAddon(ClientConfigurationNetworkHandler handler

// Must register pending channels via lateinit
this.registerPendingChannels((ChannelInfoHolder) this.connection, NetworkState.CONFIGURATION);

// Register global receivers and attach to session
this.receiver.startSession(this);
}

@Override
public void lateInit() {
for (Map.Entry<Identifier, ClientConfigurationNetworking.ConfigurationChannelHandler> entry : this.receiver.getHandlers().entrySet()) {
this.registerChannel(entry.getKey(), entry.getValue());
}

protected void invokeInitEvent() {
ClientConfigurationConnectionEvents.INIT.invoker().onConfigurationInit(this.handler, this.client);
}

Expand Down Expand Up @@ -153,7 +145,6 @@ public void handleReady() {
@Override
protected void invokeDisconnectEvent() {
ClientConfigurationConnectionEvents.DISCONNECT.invoker().onConfigurationDisconnect(this.handler, this.client);
this.receiver.endSession(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ public ClientLoginNetworkAddon(ClientLoginNetworkHandler handler, MinecraftClien
super(ClientNetworkingImpl.LOGIN, "ClientLoginNetworkAddon for Client");
this.handler = handler;
this.client = client;
}

@Override
protected void invokeInitEvent() {
ClientLoginConnectionEvents.INIT.invoker().onLoginStart(this.handler, this.client);
this.receiver.startSession(this);
}

public boolean handlePacket(LoginQueryRequestS2CPacket packet) {
Expand Down Expand Up @@ -86,7 +88,7 @@ private boolean handlePacket(int queryId, Identifier channelName, PacketByteBuf
try {
CompletableFuture<@Nullable PacketByteBuf> future = handler.receive(this.client, this.handler, buf, futureListeners::add);
future.thenAccept(result -> {
LoginQueryResponseC2SPacket packet = new LoginQueryResponseC2SPacket(queryId, new PacketByteBufLoginQueryResponse(result));
LoginQueryResponseC2SPacket packet = new LoginQueryResponseC2SPacket(queryId, result == null ? null : new PacketByteBufLoginQueryResponse(result));
GenericFutureListener<? extends Future<? super Void>> listener = null;

for (GenericFutureListener<? extends Future<? super Void>> each : futureListeners) {
Expand Down Expand Up @@ -114,11 +116,6 @@ protected void handleUnregistration(Identifier channelName) {
@Override
protected void invokeDisconnectEvent() {
ClientLoginConnectionEvents.DISCONNECT.invoker().onLoginDisconnect(this.handler, this.client);
this.receiver.endSession(this);
}

public void handleConfigurationTransition() {
this.receiver.endSession(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.mojang.logging.LogUtils;
import org.slf4j.Logger;
Expand Down Expand Up @@ -53,17 +52,10 @@ public ClientPlayNetworkAddon(ClientPlayNetworkHandler handler, MinecraftClient

// Must register pending channels via lateinit
this.registerPendingChannels((ChannelInfoHolder) this.connection, NetworkState.PLAY);

// Register global receivers and attach to session
this.receiver.startSession(this);
}

@Override
public void lateInit() {
for (Map.Entry<Identifier, ClientPlayNetworking.PlayChannelHandler> entry : this.receiver.getHandlers().entrySet()) {
this.registerChannel(entry.getKey(), entry.getValue());
}

protected void invokeInitEvent() {
ClientPlayConnectionEvents.INIT.invoker().onPlayInit(this.handler, this.client);
}

Expand Down Expand Up @@ -148,7 +140,6 @@ protected void handleUnregistration(Identifier channelName) {
@Override
protected void invokeDisconnectEvent() {
ClientPlayConnectionEvents.DISCONNECT.invoker().onPlayDisconnect(this.handler, this.client);
this.receiver.endSession(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import net.minecraft.client.network.ClientCommonNetworkHandler;
import net.minecraft.network.packet.s2c.common.CustomPayloadS2CPacket;
import net.minecraft.text.Text;

import net.fabricmc.fabric.impl.networking.NetworkHandlerExtensions;
import net.fabricmc.fabric.impl.networking.client.ClientConfigurationNetworkAddon;
Expand All @@ -32,11 +31,6 @@

@Mixin(ClientCommonNetworkHandler.class)
public abstract class ClientCommonNetworkHandlerMixin implements NetworkHandlerExtensions {
@Inject(method = "onDisconnected", at = @At("HEAD"))
private void handleDisconnection(Text reason, CallbackInfo ci) {
this.getAddon().handleDisconnect();
}

@Inject(method = "onCustomPayload(Lnet/minecraft/network/packet/s2c/common/CustomPayloadS2CPacket;)V", at = @At("HEAD"), cancellable = true)
public void onCustomPayload(CustomPayloadS2CPacket packet, CallbackInfo ci) {
if (packet.payload() instanceof PacketByteBufPayload payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import net.minecraft.client.network.ClientLoginNetworkHandler;
import net.minecraft.network.ClientConnection;
import net.minecraft.network.packet.s2c.login.LoginQueryRequestS2CPacket;
import net.minecraft.text.Text;

import net.fabricmc.fabric.impl.networking.NetworkHandlerExtensions;
import net.fabricmc.fabric.impl.networking.client.ClientConfigurationNetworkAddon;
Expand Down Expand Up @@ -64,16 +63,6 @@ private void handleQueryRequest(LoginQueryRequestS2CPacket packet, CallbackInfo
}
}

@Inject(method = "onDisconnected", at = @At("HEAD"))
private void invokeLoginDisconnectEvent(Text reason, CallbackInfo ci) {
this.addon.handleDisconnect();
}

@Inject(method = "onSuccess", at = @At("HEAD"))
private void handleConfigurationTransition(CallbackInfo ci) {
addon.handleConfigurationTransition();
}

@Inject(method = "onSuccess", at = @At("TAIL"))
private void handleConfigurationReady(CallbackInfo ci) {
NetworkHandlerExtensions networkHandlerExtensions = (NetworkHandlerExtensions) connection.getPacketListener();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ protected AbstractChanneledNetworkAddon(GlobalReceiverRegistry<H> receiver, Clie
this.sendableChannels = Collections.synchronizedSet(new HashSet<>());
}

public abstract void lateInit();

protected void registerPendingChannels(ChannelInfoHolder holder, NetworkState state) {
final Collection<Identifier> pending = holder.getPendingChannelsNames(state);

Expand Down Expand Up @@ -211,7 +209,7 @@ public void onCommonVersionPacket(int negotiatedVersion) {
assert negotiatedVersion == 1; // We only support version 1 for now

commonVersion = negotiatedVersion;
this.logger.info("Negotiated common packet version {}", commonVersion);
this.logger.debug("Negotiated common packet version {}", commonVersion);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ protected AbstractNetworkAddon(GlobalReceiverRegistry<H> receiver, String descri
this.logger = LoggerFactory.getLogger(description);
}

public final void lateInit() {
this.receiver.startSession(this);
invokeInitEvent();
}

protected abstract void invokeInitEvent();

public final void endSession() {
this.receiver.endSession(this);
}

@Nullable
public H getHandler(Identifier channel) {
Lock lock = this.lock.readLock();
Expand All @@ -64,13 +75,32 @@ public H getHandler(Identifier channel) {
}
}

private void assertNotReserved(Identifier channel) {
if (this.isReservedChannel(channel)) {
throw new IllegalArgumentException(String.format("Cannot (un)register handler for reserved channel with name \"%s\"", channel));
}
}

public void registerChannels(Map<Identifier, H> map) {
Lock lock = this.lock.writeLock();
lock.lock();

try {
for (Map.Entry<Identifier, H> entry : map.entrySet()) {
assertNotReserved(entry.getKey());

boolean unique = this.handlers.putIfAbsent(entry.getKey(), entry.getValue()) == null;
if (unique) handleRegistration(entry.getKey());
}
} finally {
lock.unlock();
}
}

public boolean registerChannel(Identifier channelName, H handler) {
Objects.requireNonNull(channelName, "Channel name cannot be null");
Objects.requireNonNull(handler, "Packet handler cannot be null");

if (this.isReservedChannel(channelName)) {
throw new IllegalArgumentException(String.format("Cannot register handler for reserved channel with name \"%s\"", channelName));
}
assertNotReserved(channelName);

Lock lock = this.lock.writeLock();
lock.lock();
Expand All @@ -90,10 +120,7 @@ public boolean registerChannel(Identifier channelName, H handler) {

public H unregisterChannel(Identifier channelName) {
Objects.requireNonNull(channelName, "Channel name cannot be null");

if (this.isReservedChannel(channelName)) {
throw new IllegalArgumentException(String.format("Cannot register handler for reserved channel with name \"%s\"", channelName));
}
assertNotReserved(channelName);

Lock lock = this.lock.writeLock();
lock.lock();
Expand Down Expand Up @@ -129,6 +156,7 @@ public Set<Identifier> getReceivableChannels() {
public final void handleDisconnect() {
if (disconnected.compareAndSet(false, true)) {
invokeDisconnectEvent();
endSession();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.minecraft.network.NetworkState;
import net.minecraft.util.Identifier;

public final class GlobalReceiverRegistry<H> {
private static final Logger LOGGER = LoggerFactory.getLogger(GlobalReceiverRegistry.class);

private final NetworkState state;

private final ReadWriteLock lock = new ReentrantReadWriteLock();
Expand Down Expand Up @@ -134,7 +138,11 @@ public void startSession(AbstractNetworkAddon<H> addon) {
lock.lock();

try {
this.trackedAddons.add(addon);
if (this.trackedAddons.add(addon)) {
addon.registerChannels(handlers);
}

this.logTrackedAddonSize();
} finally {
lock.unlock();
}
Expand All @@ -145,17 +153,29 @@ public void endSession(AbstractNetworkAddon<H> addon) {
lock.lock();

try {
this.logTrackedAddonSize();
this.trackedAddons.remove(addon);
} finally {
lock.unlock();
}
}

/**
* In practice, trackedAddons should never contain more than the number of players.
*/
private void logTrackedAddonSize() {
if (LOGGER.isTraceEnabled() && this.trackedAddons.size() > 1) {
LOGGER.trace("{} receiver registry tracks {} addon instances", state.getId(), trackedAddons.size());
}
}

private void handleRegistration(Identifier channelName, H handler) {
Lock lock = this.lock.writeLock();
lock.lock();

try {
this.logTrackedAddonSize();

for (AbstractNetworkAddon<H> addon : this.trackedAddons) {
addon.registerChannel(channelName, handler);
}
Expand All @@ -169,6 +189,8 @@ private void handleUnregistration(Identifier channelName) {
lock.lock();

try {
this.logTrackedAddonSize();

for (AbstractNetworkAddon<H> addon : this.trackedAddons) {
addon.unregisterChannel(channelName);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ public static void write(PacketByteBuf byteBuf, PacketByteBuf data) {
byteBuf.writeBytes(data.copy());
}

public static PacketByteBuf read(PacketByteBuf byteBuf) {
public static PacketByteBuf read(PacketByteBuf byteBuf, int maxSize) {
int size = byteBuf.readableBytes();

if (size < 0 || size > maxSize) {
throw new IllegalArgumentException("Payload may not be larger than %d bytes".formatted(maxSize));
}

PacketByteBuf newBuf = PacketByteBufs.create();
newBuf.writeBytes(byteBuf.copy());
byteBuf.skipBytes(byteBuf.readableBytes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import java.util.Collections;
import java.util.List;
import java.util.Map;

import net.minecraft.network.NetworkState;
import net.minecraft.network.PacketByteBuf;
Expand Down Expand Up @@ -52,16 +51,10 @@ public ServerConfigurationNetworkAddon(ServerConfigurationNetworkHandler handler

// Must register pending channels via lateinit
this.registerPendingChannels((ChannelInfoHolder) this.connection, NetworkState.CONFIGURATION);

// Register global receivers and attach to session
this.receiver.startSession(this);
}

@Override
public void lateInit() {
for (Map.Entry<Identifier, ServerConfigurationNetworking.ConfigurationChannelHandler> entry : this.receiver.getHandlers().entrySet()) {
this.registerChannel(entry.getKey(), entry.getValue());
}
protected void invokeInitEvent() {
}

public void preConfiguration() {
Expand Down Expand Up @@ -177,7 +170,6 @@ protected void handleUnregistration(Identifier channelName) {
@Override
protected void invokeDisconnectEvent() {
ServerConfigurationConnectionEvents.DISCONNECT.invoker().onConfigureDisconnect(handler, server);
this.receiver.endSession(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ public ServerLoginNetworkAddon(ServerLoginNetworkHandler handler) {
this.handler = handler;
this.server = ((ServerLoginNetworkHandlerAccessor) handler).getServer();
this.queryIdFactory = QueryIdFactory.create();
}

@Override
protected void invokeInitEvent() {
ServerLoginConnectionEvents.INIT.invoker().onLoginInit(handler, this.server);
this.receiver.startSession(this);
}

// return true if no longer ticks query
Expand Down Expand Up @@ -202,11 +204,6 @@ protected void handleUnregistration(Identifier channelName) {
@Override
protected void invokeDisconnectEvent() {
ServerLoginConnectionEvents.DISCONNECT.invoker().onLoginDisconnect(this.handler, this.server);
this.receiver.endSession(this);
}

public void handleConfigurationTransition() {
this.receiver.endSession(this);
}

@Override
Expand Down
Loading

0 comments on commit 8ed13ef

Please sign in to comment.