From 370ac4e6672423a5f7ee17ab08a648b4077a2f69 Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Thu, 16 Nov 2023 09:34:43 -0800 Subject: [PATCH] Clean up Redis close connection Signed-off-by: Andrew Carbonetto --- java/benchmarks/build.gradle | 4 +- .../clients/babushka/JniNettyClient.java | 389 ++++++++++-------- .../benchmarks/utils/Benchmarking.java | 20 +- 3 files changed, 223 insertions(+), 190 deletions(-) diff --git a/java/benchmarks/build.gradle b/java/benchmarks/build.gradle index 90c633d9f1..f07b2b8e29 100644 --- a/java/benchmarks/build.gradle +++ b/java/benchmarks/build.gradle @@ -51,7 +51,7 @@ application { // Define the main class for the application. mainClass = 'javababushka.benchmarks.BenchmarkingApp' mainClass = 'javababushka.benchmarks.clients.babushka.JniNettyClient' - applicationDefaultJvmArgs += "-Djava.library.path=${projectDir}/../target/debug" + applicationDefaultJvmArgs += "-Djava.library.path=${projectDir}/../target/release:${projectDir}/../target/debug" } tasks.withType(Test) { @@ -60,5 +60,5 @@ tasks.withType(Test) { events "started", "skipped", "passed", "failed" showStandardStreams true } - jvmArgs "-Djava.library.path=${projectDir}/../target/debug" + jvmArgs "-Djava.library.path=${projectDir}/../target/release:${projectDir}/../target/debug" } diff --git a/java/benchmarks/src/main/java/javababushka/benchmarks/clients/babushka/JniNettyClient.java b/java/benchmarks/src/main/java/javababushka/benchmarks/clients/babushka/JniNettyClient.java index 82bfb48a74..edbafd1890 100644 --- a/java/benchmarks/src/main/java/javababushka/benchmarks/clients/babushka/JniNettyClient.java +++ b/java/benchmarks/src/main/java/javababushka/benchmarks/clients/babushka/JniNettyClient.java @@ -1,18 +1,18 @@ package javababushka.benchmarks.clients.babushka; -import static connection_request.ConnectionRequestOuterClass.ConnectionRequest; import static connection_request.ConnectionRequestOuterClass.AddressInfo; -import static connection_request.ConnectionRequestOuterClass.ReadFromReplicaStrategy; -import static connection_request.ConnectionRequestOuterClass.ConnectionRetryStrategy; import static connection_request.ConnectionRequestOuterClass.AuthenticationInfo; +import static connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import static connection_request.ConnectionRequestOuterClass.ConnectionRetryStrategy; +import static connection_request.ConnectionRequestOuterClass.ReadFromReplicaStrategy; import static connection_request.ConnectionRequestOuterClass.TlsMode; -import static response.ResponseOuterClass.Response; -import static redis_request.RedisRequestOuterClass.Command.ArgsArray; import static redis_request.RedisRequestOuterClass.Command; -import static redis_request.RedisRequestOuterClass.RequestType; +import static redis_request.RedisRequestOuterClass.Command.ArgsArray; import static redis_request.RedisRequestOuterClass.RedisRequest; -import static redis_request.RedisRequestOuterClass.SimpleRoutes; +import static redis_request.RedisRequestOuterClass.RequestType; import static redis_request.RedisRequestOuterClass.Routes; +import static redis_request.RedisRequestOuterClass.SimpleRoutes; +import static response.ResponseOuterClass.Response; import com.google.common.annotations.VisibleForTesting; import io.netty.bootstrap.Bootstrap; @@ -33,6 +33,7 @@ import io.netty.channel.kqueue.KQueue; import io.netty.channel.kqueue.KQueueDomainSocketChannel; import io.netty.channel.kqueue.KQueueEventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; import 
io.netty.channel.unix.UnixChannel; import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder; import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; @@ -40,13 +41,6 @@ import io.netty.handler.logging.LoggingHandler; import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.Slf4JLoggerFactory; -import javababushka.benchmarks.clients.AsyncClient; -import javababushka.benchmarks.clients.SyncClient; -import javababushka.benchmarks.utils.ConnectionSettings; -import io.netty.channel.unix.DomainSocketAddress; -import javababushka.client.RedisClient; -import org.apache.commons.lang3.tuple.Pair; - import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -54,6 +48,11 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import javababushka.benchmarks.clients.AsyncClient; +import javababushka.benchmarks.clients.SyncClient; +import javababushka.benchmarks.utils.ConnectionSettings; +import javababushka.client.RedisClient; +import org.apache.commons.lang3.tuple.Pair; @VisibleForTesting public class JniNettyClient implements SyncClient, AsyncClient, AutoCloseable { @@ -62,21 +61,23 @@ public class JniNettyClient implements SyncClient, AsyncClient, AutoCl // https://netty.io/3.6/api/org/jboss/netty/handler/queue/BufferedWriteHandler.html // Flush every N bytes if !ALWAYS_FLUSH_ON_WRITE - public static int AUTO_FLUSH_THRESHOLD_BYTES = 512;//1024; + public static int AUTO_FLUSH_THRESHOLD_BYTES = 512; // 1024; private final AtomicInteger nonFlushedBytesCounter = new AtomicInteger(0); // Flush every N writes if !ALWAYS_FLUSH_ON_WRITE public static int AUTO_FLUSH_THRESHOLD_WRITES = 10; private final AtomicInteger nonFlushedWritesCounter = new AtomicInteger(0); - // If !ALWAYS_FLUSH_ON_WRITE and a command has no response in N millis, flush (probably it wasn't send) + // If !ALWAYS_FLUSH_ON_WRITE and a command has no response in N millis, flush (probably it wasn't + // send) public static int AUTO_FLUSH_RESPONSE_TIMEOUT_MILLIS = 100; // If !ALWAYS_FLUSH_ON_WRITE flush on timer (like a cron) public static int AUTO_FLUSH_TIMER_MILLIS = 200; public static int PENDING_RESPONSES_ON_CLOSE_TIMEOUT_MILLIS = 1000; - // Futures to handle responses. Index is callback id, starting from 1 (0 index is for connection request always). + // Futures to handle responses. Index is callback id, starting from 1 (0 index is for connection + // request always). // Is it not a concurrent nor sync collection, but it is synced on adding. No removes. // TODO clean up completed futures private final List> responses = new ArrayList<>(); @@ -98,9 +99,11 @@ private static String getSocket() { private Channel channel = null; private EventLoopGroup group = null; - // We support MacOS and Linux only, because Babushka does not support Windows, because tokio does not support it. + // We support MacOS and Linux only, because Babushka does not support Windows, because tokio does + // not support it. // Probably we should use NIO (NioEventLoopGroup) for Windows. - private final static boolean isMacOs = isMacOs(); + private static final boolean isMacOs = isMacOs(); + private static boolean isMacOs() { try { Class.forName("io.netty.channel.kqueue.KQueue"); @@ -141,95 +144,109 @@ public void connectToRedis(ConnectionSettings connectionSettings) { private void createChannel() { // TODO maybe move to constructor or to static? 
try { - channel = new Bootstrap() - .option(ChannelOption.WRITE_BUFFER_WATER_MARK, - new WriteBufferWaterMark(1024, 4096)) - .option(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT) - .group(group = isMacOs ? new KQueueEventLoopGroup() : new EpollEventLoopGroup()) - .channel(isMacOs ? KQueueDomainSocketChannel.class : EpollDomainSocketChannel.class) - .handler(new ChannelInitializer() { - @Override - public void initChannel(UnixChannel ch) throws Exception { - ch - .pipeline() - .addLast("logger", new LoggingHandler(LogLevel.DEBUG)) - //https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html - .addLast("protobufDecoder", new ProtobufVarint32FrameDecoder()) - .addLast("protobufEncoder", new ProtobufVarint32LengthFieldPrepender()) - - .addLast(new ChannelInboundHandlerAdapter() { - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - //System.out.printf("=== channelRead %s %s %n", ctx, msg); - var buf = (ByteBuf) msg; - var bytes = new byte[buf.readableBytes()]; - buf.readBytes(bytes); - // TODO surround parsing with try-catch, set error to future if parsing failed. - var response = Response.parseFrom(bytes); - int callbackId = response.getCallbackIdx(); - if (callbackId != 0) { - // connection request has hardcoded callback id = 0 - // https://github.com/aws/babushka/issues/600 - callbackId -= callbackOffset; - } - //System.out.printf("== Received response with callback %d%n", response.getCallbackIdx()); - responses.get(callbackId).complete(response); - responses.set(callbackId, null); - super.channelRead(ctx, bytes); - } - + channel = + new Bootstrap() + .option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1024, 4096)) + .option(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT) + .group(group = isMacOs ? new KQueueEventLoopGroup() : new EpollEventLoopGroup()) + .channel(isMacOs ? KQueueDomainSocketChannel.class : EpollDomainSocketChannel.class) + .handler( + new ChannelInitializer() { @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - System.out.printf("=== exceptionCaught %s %s %n", ctx, cause); - cause.printStackTrace(); - super.exceptionCaught(ctx, cause); + public void initChannel(UnixChannel ch) throws Exception { + ch.pipeline() + .addLast("logger", new LoggingHandler(LogLevel.DEBUG)) + // https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html + .addLast("protobufDecoder", new ProtobufVarint32FrameDecoder()) + .addLast("protobufEncoder", new ProtobufVarint32LengthFieldPrepender()) + .addLast( + new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + throws Exception { + // System.out.printf("=== channelRead %s %s %n", ctx, msg); + var buf = (ByteBuf) msg; + var bytes = new byte[buf.readableBytes()]; + buf.readBytes(bytes); + // TODO surround parsing with try-catch, set error to future if + // parsing failed. 
+ var response = Response.parseFrom(bytes); + int callbackId = response.getCallbackIdx(); + if (callbackId != 0) { + // connection request has hardcoded callback id = 0 + // https://github.com/aws/babushka/issues/600 + callbackId -= callbackOffset; + } + // System.out.printf("== Received response with callback %d%n", + // response.getCallbackIdx()); + responses.get(callbackId).complete(response); + responses.set(callbackId, null); + super.channelRead(ctx, bytes); + } + + @Override + public void exceptionCaught( + ChannelHandlerContext ctx, Throwable cause) throws Exception { + System.out.printf("=== exceptionCaught %s %s %n", ctx, cause); + cause.printStackTrace(); + super.exceptionCaught(ctx, cause); + } + }) + .addLast( + new ChannelOutboundHandlerAdapter() { + @Override + public void write( + ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + // System.out.printf("=== write %s %s %s %n", ctx, msg, promise); + var bytes = (byte[]) msg; + + boolean needFlush = false; + if (!ALWAYS_FLUSH_ON_WRITE) { + synchronized (nonFlushedBytesCounter) { + if (nonFlushedBytesCounter.addAndGet(bytes.length) + >= AUTO_FLUSH_THRESHOLD_BYTES + || nonFlushedWritesCounter.incrementAndGet() + >= AUTO_FLUSH_THRESHOLD_WRITES) { + nonFlushedBytesCounter.set(0); + nonFlushedWritesCounter.set(0); + needFlush = true; + } + } + } + super.write(ctx, Unpooled.copiedBuffer(bytes), promise); + if (needFlush) { + // flush outside the sync block + flush(ctx); + // System.out.println("-- auto flush - buffer"); + } + } + }); } }) - .addLast(new ChannelOutboundHandlerAdapter() { - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - //System.out.printf("=== write %s %s %s %n", ctx, msg, promise); - var bytes = (byte[])msg; - - boolean needFlush = false; - if (!ALWAYS_FLUSH_ON_WRITE) { - synchronized (nonFlushedBytesCounter) { - if (nonFlushedBytesCounter.addAndGet(bytes.length) >= AUTO_FLUSH_THRESHOLD_BYTES - || nonFlushedWritesCounter.incrementAndGet() >= AUTO_FLUSH_THRESHOLD_WRITES) { - nonFlushedBytesCounter.set(0); - nonFlushedWritesCounter.set(0); - needFlush = true; - } - } - } - super.write(ctx, Unpooled.copiedBuffer(bytes), promise); - if (needFlush) { - // flush outside the sync block - flush(ctx); - //System.out.println("-- auto flush - buffer"); - } - } + .connect(new DomainSocketAddress(unixSocket)) + .sync() + .channel(); - }); - } - }) - .connect(new DomainSocketAddress(unixSocket)).sync().channel(); - - } - catch (Exception e) { - System.err.printf("Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage()); + } catch (Exception e) { + System.err.printf( + "Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage()); e.printStackTrace(System.err); } if (!ALWAYS_FLUSH_ON_WRITE) { - new Timer(true).scheduleAtFixedRate(new TimerTask() { - @Override - public void run() { - channel.flush(); - nonFlushedBytesCounter.set(0); - nonFlushedWritesCounter.set(0); - } - }, 0, AUTO_FLUSH_TIMER_MILLIS); + new Timer(true) + .scheduleAtFixedRate( + new TimerTask() { + @Override + public void run() { + channel.flush(); + nonFlushedBytesCounter.set(0); + nonFlushedWritesCounter.set(0); + } + }, + 0, + AUTO_FLUSH_TIMER_MILLIS); } } @@ -239,7 +256,8 @@ public void closeConnection() { channel.flush(); long waitStarted = System.nanoTime(); - long waitUntil = waitStarted + PENDING_RESPONSES_ON_CLOSE_TIMEOUT_MILLIS * 100_000; // in nanos + long waitUntil = + waitStarted + 
PENDING_RESPONSES_ON_CLOSE_TIMEOUT_MILLIS * 100_000; // in nanos for (var future : responses) { if (future == null || future.isDone()) { continue; @@ -254,7 +272,17 @@ public void closeConnection() { } } } finally { - group.shutdownGracefully(); + var shuttingDown = group.shutdownGracefully(); + try { + shuttingDown.get(); + } catch (Exception e) { + e.printStackTrace(); + } + if (group.isShutdown()) { + System.out.println("Done shutdownGracefully"); + } else { + System.out.println("Something went wrong"); + } } } @@ -353,34 +381,33 @@ public void close() throws Exception { public Future asyncConnectToRedis(ConnectionSettings connectionSettings) { createChannel(); - var request = ConnectionRequest.newBuilder() - .addAddresses( - AddressInfo.newBuilder() - .setHost(connectionSettings.host) - .setPort(connectionSettings.port) - .build()) - .setTlsMode(connectionSettings.useSsl // TODO: secure or insecure TLS? - ? TlsMode.SecureTls - : TlsMode.NoTls) - .setClusterModeEnabled(false) - // In millis - .setResponseTimeout(250) - // In millis - .setClientCreationTimeout(2500) - .setReadFromReplicaStrategy(ReadFromReplicaStrategy.AlwaysFromPrimary) - .setConnectionRetryStrategy( - ConnectionRetryStrategy.newBuilder() - .setNumberOfRetries(1) - .setFactor(1) - .setExponentBase(1) - .build()) - .setAuthenticationInfo( - AuthenticationInfo.newBuilder() - .setPassword("") - .setUsername("default") - .build()) - .setDatabaseId(0) - .build(); + var request = + ConnectionRequest.newBuilder() + .addAddresses( + AddressInfo.newBuilder() + .setHost(connectionSettings.host) + .setPort(connectionSettings.port) + .build()) + .setTlsMode( + connectionSettings.useSsl // TODO: secure or insecure TLS? + ? TlsMode.SecureTls + : TlsMode.NoTls) + .setClusterModeEnabled(false) + // In millis + .setResponseTimeout(250) + // In millis + .setClientCreationTimeout(2500) + .setReadFromReplicaStrategy(ReadFromReplicaStrategy.AlwaysFromPrimary) + .setConnectionRetryStrategy( + ConnectionRetryStrategy.newBuilder() + .setNumberOfRetries(1) + .setFactor(1) + .setExponentBase(1) + .build()) + .setAuthenticationInfo( + AuthenticationInfo.newBuilder().setPassword("").setUsername("default").build()) + .setDatabaseId(0) + .build(); var future = new CompletableFuture(); responses.add(future); @@ -390,68 +417,70 @@ public Future asyncConnectToRedis(ConnectionSettings connectionSetting private CompletableFuture submitNewCommand(RequestType command, List args) { var commandId = getNextCallback(); - //System.out.printf("== %s(%s), callback %d%n", command, String.join(", ", args), commandId); - - return CompletableFuture.supplyAsync(() -> { - var commandArgs = ArgsArray.newBuilder(); - for (var arg : args) { - commandArgs.addArgs(arg); - } - - RedisRequest request = - RedisRequest.newBuilder() - .setCallbackIdx(commandId.getKey() + callbackOffset) - .setSingleCommand( - Command.newBuilder() - .setRequestType(command) - .setArgsArray(commandArgs.build()) - .build()) - .setRoute( - Routes.newBuilder() - .setSimpleRoutes(SimpleRoutes.AllNodes) - .build()) - .build(); - if (ALWAYS_FLUSH_ON_WRITE) { - channel.writeAndFlush(request.toByteArray()); - return commandId.getRight(); - } - channel.write(request.toByteArray()); - return autoFlushFutureWrapper(commandId.getRight()); - }).thenCompose(f -> f); + // System.out.printf("== %s(%s), callback %d%n", command, String.join(", ", args), commandId); + + return CompletableFuture.supplyAsync( + () -> { + var commandArgs = ArgsArray.newBuilder(); + for (var arg : args) { + commandArgs.addArgs(arg); + } 
+ + RedisRequest request = + RedisRequest.newBuilder() + .setCallbackIdx(commandId.getKey() + callbackOffset) + .setSingleCommand( + Command.newBuilder() + .setRequestType(command) + .setArgsArray(commandArgs.build()) + .build()) + .setRoute(Routes.newBuilder().setSimpleRoutes(SimpleRoutes.AllNodes).build()) + .build(); + if (ALWAYS_FLUSH_ON_WRITE) { + channel.writeAndFlush(request.toByteArray()); + return commandId.getRight(); + } + channel.write(request.toByteArray()); + return autoFlushFutureWrapper(commandId.getRight()); + }) + .thenCompose(f -> f); } private CompletableFuture autoFlushFutureWrapper(Future future) { - return CompletableFuture.supplyAsync(() -> { - try { - return future.get(AUTO_FLUSH_RESPONSE_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } catch (TimeoutException e) { - //System.out.println("-- auto flush - timeout"); - channel.flush(); - nonFlushedBytesCounter.set(0); - nonFlushedWritesCounter.set(0); - } - try { - return future.get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }); + return CompletableFuture.supplyAsync( + () -> { + try { + return future.get(AUTO_FLUSH_RESPONSE_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } catch (TimeoutException e) { + // System.out.println("-- auto flush - timeout"); + channel.flush(); + nonFlushedBytesCounter.set(0); + nonFlushedWritesCounter.set(0); + } + try { + return future.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }); } @Override public Future asyncSet(String key, String value) { - //System.out.printf("== set(%s, %s), callback %d%n", key, value, callbackId); + // System.out.printf("== set(%s, %s), callback %d%n", key, value, callbackId); return submitNewCommand(RequestType.SetString, List.of(key, value)); } @Override public Future asyncGet(String key) { - //System.out.printf("== get(%s), callback %d%n", key, callbackId); + // System.out.printf("== get(%s), callback %d%n", key, callbackId); return submitNewCommand(RequestType.GetString, List.of(key)) - .thenApply(response -> response.hasRespPointer() - ? RedisClient.valueFromPointer(response.getRespPointer()).toString() - : null); + .thenApply( + response -> + response.hasRespPointer() + ? 
RedisClient.valueFromPointer(response.getRespPointer()).toString() + : null); } } diff --git a/java/benchmarks/src/main/java/javababushka/benchmarks/utils/Benchmarking.java b/java/benchmarks/src/main/java/javababushka/benchmarks/utils/Benchmarking.java index 95d3a8ffac..bfe084204c 100644 --- a/java/benchmarks/src/main/java/javababushka/benchmarks/utils/Benchmarking.java +++ b/java/benchmarks/src/main/java/javababushka/benchmarks/utils/Benchmarking.java @@ -149,14 +149,11 @@ public static void testClientSetGet( Supplier clientCreator, BenchmarkingApp.RunConfiguration config, boolean async) { for (int concurrentNum : config.concurrentTasks) { int iterations = 100000; - Math.min(Math.max(LATENCY_MIN, concurrentNum * LATENCY_MULTIPLIER), LATENCY_MAX); + Math.min(Math.max(LATENCY_MIN, concurrentNum * LATENCY_MULTIPLIER), LATENCY_MAX); for (int clientCount : config.clientCount) { for (int dataSize : config.dataSize) { - System.out.printf( - "%n =====> %s <===== %d clients %d concurrent %d data %n%n", - clientCreator.get().getName(), clientCount, concurrentNum, dataSize); AtomicInteger iterationCounter = new AtomicInteger(0); - // Collections.synchronizedList + Map> actionResults = Map.of( ChosenAction.GET_EXISTING, new ArrayList<>(), @@ -172,6 +169,12 @@ public static void testClientSetGet( clients.add(newClient); } + String clientName = clients.get(0).getName(); + + System.out.printf( + "%n =====> %s <===== %d clients %d concurrent %d data %n%n", + clientName, clientCount, concurrentNum, dataSize); + for (int taskNum = 0; taskNum < concurrentNum; taskNum++) { final int taskNumDebugging = taskNum; tasks.add( @@ -214,7 +217,7 @@ public static void testClientSetGet( }); } if (config.debugLogging) { - System.out.printf("%s client Benchmarking: %n", clientCreator.get().getName()); + System.out.printf("%s client Benchmarking: %n", clientName); System.out.printf( "===> concurrentNum = %d, clientNum = %d, tasks = %d%n", concurrentNum, clientCount, tasks.size()); @@ -257,7 +260,7 @@ public static void testClientSetGet( calculatedResults, config.resultsFile.get(), dataSize, - clientCreator.get().getName(), + clientName, clientCount, concurrentNum, iterations / ((after - before) / TPS_NORMALIZATION)); @@ -265,7 +268,8 @@ public static void testClientSetGet( printResults(calculatedResults, (after - before) / TPS_NORMALIZATION, iterations); try { Thread.sleep(2000); - } catch (InterruptedException ignored) {} + } catch (InterruptedException ignored) { + } } } }
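
The behavioral part of this change is the close path in JniNettyClient.closeConnection(): flush the channel, give pending response futures a bounded window to complete, then block on the future returned by EventLoopGroup.shutdownGracefully() instead of firing it and returning immediately. The sketch below is illustrative only and is not part of the patch: it assumes Netty 4.1 on the classpath, uses NioEventLoopGroup so it compiles on any OS, converts the millisecond timeout to nanos via TimeUnit, and the class and parameter names are stand-ins invented for the example.

import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public final class CloseConnectionSketch {

  private static final long PENDING_RESPONSES_ON_CLOSE_TIMEOUT_MILLIS = 1000;

  // `responses` stands in for the client's list of in-flight command futures.
  static void closeConnection(
      Channel channel, EventLoopGroup group, List<CompletableFuture<?>> responses) {
    try {
      // Push out anything still sitting in the write buffer.
      channel.flush();

      // Give pending responses a bounded window (tracked in nanos) before tearing down.
      long waitUntil =
          System.nanoTime()
              + TimeUnit.MILLISECONDS.toNanos(PENDING_RESPONSES_ON_CLOSE_TIMEOUT_MILLIS);
      for (var future : responses) {
        if (future == null || future.isDone()) {
          continue;
        }
        long remainingNanos = waitUntil - System.nanoTime();
        if (remainingNanos <= 0) {
          break;
        }
        try {
          future.get(remainingNanos, TimeUnit.NANOSECONDS);
        } catch (Exception ignored) {
          // A slow or failed response should not block shutdown.
        }
      }
    } finally {
      // Wait until the event loop threads have actually terminated, not just been asked to stop.
      var shuttingDown = group.shutdownGracefully();
      try {
        shuttingDown.get();
      } catch (Exception e) {
        e.printStackTrace();
      }
      System.out.println(group.isShutdown() ? "Done shutdownGracefully" : "Something went wrong");
    }
  }

  public static void main(String[] args) throws Exception {
    // Usage: with no channel open there is nothing to flush, so only the shutdown wait runs here.
    EventLoopGroup group = new NioEventLoopGroup();
    group.shutdownGracefully().get();
    System.out.println("group terminated: " + group.isTerminated());
  }
}

Blocking on the shutdownGracefully() future means the caller only returns once Netty's threads are gone, so successive benchmark clients do not leak event-loop threads into each other's runs.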
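
The auto-flush fallback in autoFlushFutureWrapper() is easy to lose in the reformatted hunk, so the same idea is shown in isolation below. Again a sketch rather than patch content: `withAutoFlush` and `flushAndResetCounters` are hypothetical names, the latter standing in for channel.flush() plus zeroing the non-flushed byte and write counters.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

final class AutoFlushSketch {

  private static final int AUTO_FLUSH_RESPONSE_TIMEOUT_MILLIS = 100;

  // Wrap a pending response: if it does not complete quickly, assume the request is still
  // sitting in the write buffer, flush, and then wait without a timeout.
  static <T> CompletableFuture<T> withAutoFlush(
      Future<T> response, Runnable flushAndResetCounters) {
    return CompletableFuture.supplyAsync(
        () -> {
          try {
            return response.get(AUTO_FLUSH_RESPONSE_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
          } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
          } catch (TimeoutException e) {
            // No response yet: the write probably never left the buffer, so force a flush.
            flushAndResetCounters.run();
          }
          try {
            return response.get();
          } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
          }
        });
  }
}

In the client this wrapper is only applied on the !ALWAYS_FLUSH_ON_WRITE path, where writes are batched until a byte-count or write-count threshold (or the periodic flush timer) kicks in.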