diff --git a/java/benchmarks/src/main/java/babushka/benchmarks/clients/babushka/JniNettyClient.java b/java/benchmarks/src/main/java/babushka/benchmarks/clients/babushka/JniNettyClient.java index 3731e16206..8ad0971585 100644 --- a/java/benchmarks/src/main/java/babushka/benchmarks/clients/babushka/JniNettyClient.java +++ b/java/benchmarks/src/main/java/babushka/benchmarks/clients/babushka/JniNettyClient.java @@ -1,21 +1,18 @@ package babushka.benchmarks.clients.babushka; +import babushka.Client; import babushka.benchmarks.clients.AsyncClient; import babushka.benchmarks.clients.SyncClient; import babushka.benchmarks.utils.ConnectionSettings; -import babushka.client.Commands; -import babushka.client.Connection; import java.util.concurrent.Future; public class JniNettyClient implements SyncClient, AsyncClient { - private final Connection connection; - private Commands commands = null; + private final Client testClient = new Client(); private String name = "JNI Netty"; public JniNettyClient(boolean async) { name += async ? " async" : " sync"; - connection = new Connection(); } @Override @@ -25,36 +22,37 @@ public String getName() { @Override public void closeConnection() { - connection.closeConnection(); + testClient.getConnection().closeConnection(); } @Override public void connectToRedis(ConnectionSettings connectionSettings) { - connection.connectToRedis( - connectionSettings.host, - connectionSettings.port, - connectionSettings.useSsl, - connectionSettings.clusterMode); - commands = connection.getCommands(); + testClient + .getConnection() + .connectToRedis( + connectionSettings.host, + connectionSettings.port, + connectionSettings.useSsl, + connectionSettings.clusterMode); } @Override public Future asyncSet(String key, String value) { - return commands.asyncSet(key, value); + return testClient.getCommands().asyncSet(key, value); } @Override public Future asyncGet(String key) { - return commands.asyncGet(key); + return testClient.getCommands().asyncGet(key); } @Override public void set(String key, String value) { - commands.set(key, value); + testClient.getCommands().set(key, value); } @Override public String get(String key) { - return commands.get(key); + return testClient.getCommands().get(key); } } diff --git a/java/client/src/main/java/babushka/Client.java b/java/client/src/main/java/babushka/Client.java new file mode 100644 index 0000000000..a82dc75e95 --- /dev/null +++ b/java/client/src/main/java/babushka/Client.java @@ -0,0 +1,32 @@ +package babushka; + +import babushka.client.ChannelHolder; +import babushka.client.Commands; +import babushka.client.Connection; +import babushka.connection.CallbackManager; +import babushka.connection.SocketManager; +import java.nio.channels.NotYetConnectedException; +import lombok.Getter; + +public class Client { + + private final ChannelHolder channelHolder; + private final Commands commands; + @Getter private final Connection connection; + + public Client() { + var callBackManager = new CallbackManager(); + channelHolder = + new ChannelHolder( + SocketManager.getInstance().openNewChannel(callBackManager), callBackManager); + connection = new Connection(channelHolder); + commands = new Commands(channelHolder); + } + + public Commands getCommands() { + if (!connection.isConnected()) { + throw new NotYetConnectedException(); + } + return commands; + } +} diff --git a/java/client/src/main/java/babushka/client/ChannelHolder.java b/java/client/src/main/java/babushka/client/ChannelHolder.java new file mode 100644 index 0000000000..e87c958dd8 --- /dev/null +++ b/java/client/src/main/java/babushka/client/ChannelHolder.java @@ -0,0 +1,40 @@ +package babushka.client; + +import babushka.connection.CallbackManager; +import connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import io.netty.channel.Channel; +import java.util.concurrent.CompletableFuture; +import lombok.RequiredArgsConstructor; +import redis_request.RedisRequestOuterClass.RedisRequest; +import response.ResponseOuterClass.Response; + +@RequiredArgsConstructor +public class ChannelHolder { + private final Channel channel; + private final CallbackManager callbackManager; + + /** Write a protobuf message to the socket. */ + public CompletableFuture write(RedisRequest.Builder request, boolean flush) { + var commandId = callbackManager.registerRequest(); + request.setCallbackIdx(commandId.getKey()); + + if (flush) { + channel.writeAndFlush(request.build().toByteArray()); + } else { + channel.write(request.build().toByteArray()); + } + return commandId.getValue(); + } + + /** Write a protobuf message to the socket. */ + public CompletableFuture connect(ConnectionRequest request) { + channel.writeAndFlush(request.toByteArray()); + return callbackManager.getConnectionPromise(); + } + + /** Closes the UDS connection and frees corresponding resources. */ + public void close() { + channel.close(); + callbackManager.clean(); + } +} diff --git a/java/client/src/main/java/babushka/client/Commands.java b/java/client/src/main/java/babushka/client/Commands.java index ed74f77294..c6ad7826d3 100644 --- a/java/client/src/main/java/babushka/client/Commands.java +++ b/java/client/src/main/java/babushka/client/Commands.java @@ -1,22 +1,20 @@ package babushka.client; import babushka.FFI.BabushkaCoreNativeDefinitions; -import babushka.connection.SocketManager; import babushka.tools.Awaiter; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; +import lombok.RequiredArgsConstructor; import redis_request.RedisRequestOuterClass.RedisRequest; import redis_request.RedisRequestOuterClass.RequestType; import response.ResponseOuterClass.ConstantResponse; import response.ResponseOuterClass.Response; +@RequiredArgsConstructor public class Commands { - private final SocketManager socketManager; - public Commands(SocketManager socketManager) { - this.socketManager = socketManager; - } + private final ChannelHolder channel; /** * Sync (blocking) set. See async option in {@link #asyncSet}.
@@ -71,7 +69,7 @@ public Future asyncGet(String key) { private CompletableFuture submitRequest(RedisRequest.Builder builder) { // TODO this explicitly uses ForkJoin thread pool. May be we should use another one. - return CompletableFuture.supplyAsync(() -> socketManager.write(builder, true)) + return CompletableFuture.supplyAsync(() -> channel.write(builder, true)) .thenComposeAsync(f -> f); } } diff --git a/java/client/src/main/java/babushka/client/Connection.java b/java/client/src/main/java/babushka/client/Connection.java index 5bbeeae140..564ecc0cf5 100644 --- a/java/client/src/main/java/babushka/client/Connection.java +++ b/java/client/src/main/java/babushka/client/Connection.java @@ -1,25 +1,17 @@ package babushka.client; -import babushka.connection.SocketManager; import babushka.tools.Awaiter; -import java.nio.channels.NotYetConnectedException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; +import lombok.RequiredArgsConstructor; import response.ResponseOuterClass.ConstantResponse; +@RequiredArgsConstructor public class Connection { - // TODO: not used yet, not implemented on rust side - // https://github.com/aws/babushka/issues/635 - private final int connectionId = 0; - private final AtomicBoolean isConnected = new AtomicBoolean(false); - private final SocketManager socketManager; - - public Connection() { - socketManager = SocketManager.getInstance(); - } + private final ChannelHolder channel; /** * Sync (blocking) connect to REDIS. See async option in {@link #asyncConnectToRedis}. @@ -48,7 +40,7 @@ public boolean connectToRedis(String host, int port, boolean useSsl, boolean clu public CompletableFuture asyncConnectToRedis( String host, int port, boolean useSsl, boolean clusterMode) { var request = RequestBuilder.createConnectionRequest(host, port, useSsl, clusterMode); - return socketManager + return channel .connect(request) .thenApplyAsync( response -> @@ -62,16 +54,13 @@ public void closeConnection() { } /** Async (non-blocking) disconnect. See sync option in {@link #closeConnection}. */ - // TODO Not implemented yet in rust core lib. public CompletableFuture asyncCloseConnection() { isConnected.setPlain(false); - return CompletableFuture.runAsync(() -> {}); + return CompletableFuture.runAsync(channel::close); } - public Commands getCommands() { - if (!isConnected.get()) { - throw new NotYetConnectedException(); - } - return new Commands(socketManager); + /** Check that connection established. This doesn't validate whether it is alive. */ + public boolean isConnected() { + return isConnected.get(); } } diff --git a/java/client/src/main/java/babushka/connection/CallbackManager.java b/java/client/src/main/java/babushka/connection/CallbackManager.java index 9053143507..ece1db7f76 100644 --- a/java/client/src/main/java/babushka/connection/CallbackManager.java +++ b/java/client/src/main/java/babushka/connection/CallbackManager.java @@ -1,35 +1,31 @@ package babushka.connection; -import java.util.Deque; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.atomic.AtomicInteger; +import lombok.Getter; import org.apache.commons.lang3.tuple.Pair; import response.ResponseOuterClass.Response; /** Holder for resources owned by {@link SocketManager} and used by {@link ReadHandler}. */ -class CallbackManager { +public class CallbackManager { /** Unique request ID (callback ID). Thread-safe. */ - private static final AtomicInteger requestId = new AtomicInteger(0); + private final AtomicInteger requestId = new AtomicInteger(0); /** * Storage of Futures to handle responses. Map key is callback id, which starts from 1.
* Each future is a promise for every submitted by user request. */ - private static final Map> responses = - new ConcurrentHashMap<>(); + private final Map> responses = new ConcurrentHashMap<>(); /** - * Storage for connection requests similar to {@link #responses}. Unfortunately, connection + * Storage for connection request similar to {@link #responses}. Unfortunately, connection * requests can't be stored in the same storage, because callback ID = 0 is hardcoded for - * connection requests. Will be removed once issue #600 on GH fixed. + * connection requests. */ - private static final Deque> connectionRequests = - new ConcurrentLinkedDeque<>(); + @Getter private final CompletableFuture connectionPromise = new CompletableFuture<>(); /** * Register a new request to be sent. Once response received, the given future completes with it. @@ -37,45 +33,30 @@ class CallbackManager { * @return A pair of unique callback ID which should set into request and a client promise for * response. */ - public static Pair> registerRequest() { + public Pair> registerRequest() { int callbackId = requestId.incrementAndGet(); var future = new CompletableFuture(); responses.put(callbackId, future); return Pair.of(callbackId, future); } - /** - * Register a new connection request similar to {@link #registerRequest}.
- * No callback ID returned, because connection request/response pair have no such field (subject - * to change). Track issue #600 for more - * details. - */ - public static CompletableFuture registerConnection() { - var future = new CompletableFuture(); - connectionRequests.add(future); - return future; - } - /** * Complete the corresponding client promise and free resources. + * * @param response A response received */ - public static void completeRequest(Response response) { + public void completeRequest(Response response) { int callbackId = response.getCallbackIdx(); if (callbackId == 0) { - // can't distinguish connection requests since they have no - // callback ID - // https://github.com/aws/babushka/issues/600 - connectionRequests.pop().completeAsync(() -> response); + connectionPromise.completeAsync(() -> response); } else { responses.get(callbackId).completeAsync(() -> response); responses.remove(callbackId); } } - public static void clean() { + public void clean() { // TODO should we reply in uncompleted futures? - connectionRequests.clear(); responses.clear(); } } diff --git a/java/client/src/main/java/babushka/connection/ChannelBuilder.java b/java/client/src/main/java/babushka/connection/ChannelBuilder.java index efe074b1c7..d865abce1c 100644 --- a/java/client/src/main/java/babushka/connection/ChannelBuilder.java +++ b/java/client/src/main/java/babushka/connection/ChannelBuilder.java @@ -5,16 +5,21 @@ import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder; import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; import lombok.NonNull; +import lombok.RequiredArgsConstructor; /** Builder for the channel used by {@link SocketManager}. */ +@RequiredArgsConstructor public class ChannelBuilder extends ChannelInitializer { + + private final CallbackManager callbackManager; + @Override public void initChannel(@NonNull UnixChannel ch) { ch.pipeline() // https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html .addLast("protobufDecoder", new ProtobufVarint32FrameDecoder()) .addLast("protobufEncoder", new ProtobufVarint32LengthFieldPrepender()) - .addLast(new ReadHandler()) + .addLast(new ReadHandler(callbackManager)) .addLast(new WriteHandler()); } } diff --git a/java/client/src/main/java/babushka/connection/ReadHandler.java b/java/client/src/main/java/babushka/connection/ReadHandler.java index 7467116e5c..d6eba71235 100644 --- a/java/client/src/main/java/babushka/connection/ReadHandler.java +++ b/java/client/src/main/java/babushka/connection/ReadHandler.java @@ -4,10 +4,15 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import lombok.NonNull; +import lombok.RequiredArgsConstructor; import response.ResponseOuterClass.Response; /** Handler for inbound traffic though UDS. Used by Netty. */ +@RequiredArgsConstructor public class ReadHandler extends ChannelInboundHandlerAdapter { + + private final CallbackManager callbackManager; + /** * Handles responses from babushka core: * @@ -25,7 +30,7 @@ public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) buf.readBytes(bytes); // TODO surround parsing with try-catch, set error to future if parsing failed. var response = Response.parseFrom(bytes); - CallbackManager.completeRequest(response); + callbackManager.completeRequest(response); buf.release(); } diff --git a/java/client/src/main/java/babushka/connection/SocketManager.java b/java/client/src/main/java/babushka/connection/SocketManager.java index feccd8907e..9d538415fc 100644 --- a/java/client/src/main/java/babushka/connection/SocketManager.java +++ b/java/client/src/main/java/babushka/connection/SocketManager.java @@ -1,17 +1,11 @@ package babushka.connection; -import static response.ResponseOuterClass.Response; - import babushka.FFI.BabushkaCoreNativeDefinitions; import babushka.client.Commands; import babushka.client.Connection; -import connection_request.ConnectionRequestOuterClass.ConnectionRequest; import io.netty.bootstrap.Bootstrap; -import io.netty.buffer.ByteBufAllocator; import io.netty.channel.Channel; -import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; -import io.netty.channel.WriteBufferWaterMark; import io.netty.channel.epoll.EpollDomainSocketChannel; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.kqueue.KQueue; @@ -19,8 +13,6 @@ import io.netty.channel.kqueue.KQueueEventLoopGroup; import io.netty.channel.unix.DomainSocketAddress; import io.netty.util.concurrent.DefaultThreadFactory; -import java.util.concurrent.CompletableFuture; -import redis_request.RedisRequestOuterClass.RedisRequest; /** * A UDS connection manager. This class is responsible for: @@ -39,6 +31,8 @@ */ public class SocketManager { + private static final String socketPath = getSocket(); + /** * Make an FFI call to obtain the socket path. * @@ -55,6 +49,8 @@ private static String getSocket() { // At the moment, Windows is not supported // Probably we should use NIO (NioEventLoopGroup) for Windows. + private static final boolean isMacOs = isMacOs(); + // TODO support IO-Uring and NIO /** * Detect platform to identify which native implementation to use for UDS interaction. Currently @@ -70,9 +66,6 @@ private static boolean isMacOs() { } } - /** A channel to make socket interactions with. */ - private Channel channel = null; - /** Thread pool supplied to Netty to perform all async IO. */ private EventLoopGroup group = null; @@ -93,26 +86,14 @@ public static synchronized SocketManager getInstance() { /** Constructor for the single instance. */ private SocketManager() { - boolean isMacOs = isMacOs(); try { int cpuCount = Runtime.getRuntime().availableProcessors(); group = isMacOs ? new KQueueEventLoopGroup( - cpuCount, new DefaultThreadFactory("NettyWrapper-kqueue-elg", true)) + cpuCount, new DefaultThreadFactory("SocketManager-kqueue-elg", true)) : new EpollEventLoopGroup( - cpuCount, new DefaultThreadFactory("NettyWrapper-epoll-elg", true)); - channel = - new Bootstrap() - .option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(1024, 4096)) - .option(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT) - .group(group) - .channel(isMacOs ? KQueueDomainSocketChannel.class : EpollDomainSocketChannel.class) - .handler(new ChannelBuilder()) - .connect(new DomainSocketAddress(getSocket())) - .sync() - .channel(); - + cpuCount, new DefaultThreadFactory("SocketManager-epoll-elg", true)); } catch (Exception e) { System.err.printf( "Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage()); @@ -120,23 +101,21 @@ cpuCount, new DefaultThreadFactory("NettyWrapper-kqueue-elg", true)) } } - /** Write a protobuf message to the socket. */ - public CompletableFuture write(RedisRequest.Builder request, boolean flush) { - var commandId = CallbackManager.registerRequest(); - request.setCallbackIdx(commandId.getKey()); - - if (flush) { - channel.writeAndFlush(request.build().toByteArray()); - } else { - channel.write(request.build().toByteArray()); + public Channel openNewChannel(CallbackManager callbackManager) { + try { + return new Bootstrap() + .group(group) + .channel(isMacOs ? KQueueDomainSocketChannel.class : EpollDomainSocketChannel.class) + .handler(new ChannelBuilder(callbackManager)) + .connect(new DomainSocketAddress(socketPath)) + .sync() + .channel(); + } catch (InterruptedException e) { + System.err.printf( + "Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage()); + e.printStackTrace(System.err); + throw new RuntimeException(e); } - return commandId.getValue(); - } - - /** Write a protobuf message to the socket. */ - public CompletableFuture connect(ConnectionRequest request) { - channel.writeAndFlush(request.toByteArray()); - return CallbackManager.registerConnection(); } /** @@ -144,10 +123,8 @@ public CompletableFuture connect(ConnectionRequest request) { * #getInstance()} will create a new connection with new resource pool. */ public void close() { - channel.close(); group.shutdownGracefully(); INSTANCE = null; - CallbackManager.clean(); } /** @@ -168,6 +145,6 @@ public void run() { static { Runtime.getRuntime() - .addShutdownHook(new Thread(new ShutdownHook(), "NettyWrapper-shutdown-hook")); + .addShutdownHook(new Thread(new ShutdownHook(), "SocketManager-shutdown-hook")); } }