Skip to content

Commit

Permalink
Static sizer and timeout handlers in the pipeline (#833)
Browse files Browse the repository at this point in the history
* Improve pipeline

This simplifies all pipeline code and ensures some listeners like the sizer are always present. The code already assumed that the sizer is always there and thus causes issues. The sizer can be deactivated still now and has pretty much no performance losses from this. The profit from this PR is that there is less logic with modifying the PR and thus developers interacting with the channel can assume specific things about the order and placements of elements in the pipeline. This will be useful once ViaVersion is supported, and it is expected that certain elements always are in the pipeline and don't change. My plan is to also always have an encryption and compression handler in the pipeline that is controlled via AttributeKeys from netty, but for that first #828 needs to be merged. So this PR only completes the goal partially, but that's fine. PR is ready for review like it is right now.

* Revert some stuff

* Fix channel race condition

* Fix closing race condition

* Prevent client race conditions.

* Fix test failure, idk how, idk why, but it works now

* Address review

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/BuiltinFlags.java

Co-authored-by: Konicai <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/BuiltinFlags.java

Co-authored-by: Konicai <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/BuiltinFlags.java

Co-authored-by: chris <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java

Co-authored-by: chris <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/BuiltinFlags.java

Co-authored-by: chris <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/BuiltinFlags.java

Co-authored-by: chris <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java

Co-authored-by: chris <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java

Co-authored-by: chris <[email protected]>

* Update protocol/src/main/java/org/geysermc/mcprotocollib/network/tcp/TcpServer.java

Co-authored-by: chris <[email protected]>

---------

Co-authored-by: Konicai <[email protected]>
Co-authored-by: chris <[email protected]>
  • Loading branch information
3 people authored Sep 10, 2024
1 parent 716f229 commit 4148fa9
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 282 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
* Built-in PacketLib session flags.
*/
public class BuiltinFlags {
public static final Flag<Boolean> ENABLE_CLIENT_PROXY_PROTOCOL = new Flag<>("enable-client-proxy-protocol", Boolean.class);

/**
* Enables HAProxy protocol support.
* When this value is not null it represents the ip and port the client claims the connection is from.
*/
public static final Flag<InetSocketAddress> CLIENT_PROXIED_ADDRESS = new Flag<>("client-proxied-address", InetSocketAddress.class);

/**
Expand All @@ -20,6 +23,24 @@ public class BuiltinFlags {
*/
public static final Flag<Boolean> TCP_FAST_OPEN = new Flag<>("tcp-fast-open", Boolean.class);

/**
* Connection timeout in seconds.
* Only used by the client.
*/
public static final Flag<Integer> CLIENT_CONNECT_TIMEOUT = new Flag<>("client-connect-timeout", Integer.class);

/**
* Read timeout in seconds.
* Used by both the server and client.
*/
public static final Flag<Integer> READ_TIMEOUT = new Flag<>("read-timeout", Integer.class);

/**
* Write timeout in seconds.
* Used by both the server and client.
*/
public static final Flag<Integer> WRITE_TIMEOUT = new Flag<>("write-timeout", Integer.class);

private BuiltinFlags() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public interface Session {
* @param wait Whether to wait for the connection to be established before returning.
* @param transferring Whether the session is a client being transferred.
*/
public void connect(boolean wait, boolean transferring);
void connect(boolean wait, boolean transferring);

/**
* Gets the host the session is connected to.
Expand Down Expand Up @@ -138,7 +138,7 @@ public interface Session {
*
* @param flags Collection of flags
*/
public void setFlags(Map<String, Object> flags);
void setFlags(Map<String, Object> flags);

/**
* Gets the listeners listening on this session.
Expand Down Expand Up @@ -204,48 +204,6 @@ public interface Session {
*/
void enableEncryption(PacketEncryption encryption);

/**
* Gets the connect timeout for this session in seconds.
*
* @return The session's connect timeout.
*/
int getConnectTimeout();

/**
* Sets the connect timeout for this session in seconds.
*
* @param timeout Connect timeout to set.
*/
void setConnectTimeout(int timeout);

/**
* Gets the read timeout for this session in seconds.
*
* @return The session's read timeout.
*/
int getReadTimeout();

/**
* Sets the read timeout for this session in seconds.
*
* @param timeout Read timeout to set.
*/
void setReadTimeout(int timeout);

/**
* Gets the write timeout for this session in seconds.
*
* @return The session's write timeout.
*/
int getWriteTimeout();

/**
* Sets the write timeout for this session in seconds.
*
* @param timeout Write timeout to set.
*/
void setWriteTimeout(int timeout);

/**
* Returns true if the session is connected.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
Expand All @@ -25,6 +22,8 @@
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
Expand All @@ -40,6 +39,7 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -90,56 +90,51 @@ public void connect(boolean wait, boolean transferring) {
createTcpEventLoopGroup();
}

try {
final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getConnectTimeout() * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);

ChannelPipeline pipeline = channel.pipeline();

refreshReadTimeoutHandler(channel);
refreshWriteTimeoutHandler(channel);

addProxy(pipeline);

int size = protocol.getPacketHeader().getLengthSize();
if (size > 0) {
pipeline.addLast("sizer", new TcpPacketSizer(TcpClientSession.this, size));
}

pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);

addHAProxySupport(pipeline);
}
});
final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getFlag(BuiltinFlags.CLIENT_CONNECT_TIMEOUT, 30) * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);

if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}
ChannelPipeline pipeline = channel.pipeline();

ChannelFuture future = bootstrap.connect();
if (wait) {
future.sync();
}
addProxy(pipeline);

future.addListener((futureListener) -> {
if (!futureListener.isSuccess()) {
exceptionCaught(null, futureListener.cause());
initializeHAProxySupport(channel);

pipeline.addLast("read-timeout", new ReadTimeoutHandler(getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));

pipeline.addLast("sizer", new TcpPacketSizer(protocol.getPacketHeader(), getCodecHelper()));

pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);
}
});
} catch (Throwable t) {
exceptionCaught(null, t);

if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}

CompletableFuture<Void> handleFuture = new CompletableFuture<>();
bootstrap.connect().addListener((futureListener) -> {
if (!futureListener.isSuccess()) {
exceptionCaught(null, futureListener.cause());
}

handleFuture.complete(null);
});

if (wait) {
handleFuture.join();
}
}

Expand All @@ -155,8 +150,8 @@ private InetSocketAddress resolveAddress() {
if (getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!this.host.matches(IP_REGEX) && !this.host.equalsIgnoreCase("localhost"))) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = null;
try (DnsNameResolver resolver = new DnsNameResolverBuilder(EVENT_LOOP_GROUP.next())
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();

DnsResponse response = envelope.content();
Expand Down Expand Up @@ -206,54 +201,52 @@ private InetSocketAddress resolveAddress() {
}

private void addProxy(ChannelPipeline pipeline) {
if (proxy != null) {
switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addFirst("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addFirst("proxy", new HttpProxyHandler(proxy.address()));
}
if (proxy == null) {
return;
}

switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address()));
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addFirst("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addFirst("proxy", new Socks4ProxyHandler(proxy.address()));
}
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address()));
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addFirst("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addFirst("proxy", new Socks5ProxyHandler(proxy.address()));
}
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address()));
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
}

private void addHAProxySupport(ChannelPipeline pipeline) {
private void initializeHAProxySupport(Channel channel) {
InetSocketAddress clientAddress = getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS);
if (getFlag(BuiltinFlags.ENABLE_CLIENT_PROXY_PROTOCOL, false) && clientAddress != null) {
pipeline.addFirst("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
ctx.channel().writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
));
ctx.pipeline().remove(this);
ctx.pipeline().remove("proxy-protocol-encoder");
super.channelActive(ctx);
}
});
pipeline.addFirst("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
if (clientAddress == null) {
return;
}

channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress();
channel.writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
)).addListener(future -> {
channel.pipeline().remove("proxy-protocol-encoder");
});
}

private static void createTcpEventLoopGroup() {
Expand All @@ -264,7 +257,7 @@ private static void createTcpEventLoopGroup() {
EVENT_LOOP_GROUP = TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());

Runtime.getRuntime().addShutdownHook(new Thread(
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
}

protected static ThreadFactory newThreadFactory() {
Expand Down
Loading

0 comments on commit 4148fa9

Please sign in to comment.