diff --git a/java/client/build.gradle b/java/client/build.gradle index be9120cf3f..4aa2415723 100644 --- a/java/client/build.gradle +++ b/java/client/build.gradle @@ -52,16 +52,19 @@ tasks.register('cleanProtobuf') { tasks.register('buildRustRelease', Exec) { commandLine 'cargo', 'build', '--release' workingDir project.rootDir + environment 'CARGO_TERM_COLOR', 'always' } tasks.register('buildRustReleaseStrip', Exec) { commandLine 'cargo', 'build', '--release', '--strip' workingDir project.rootDir + environment 'CARGO_TERM_COLOR', 'always' } tasks.register('buildRust', Exec) { commandLine 'cargo', 'build' workingDir project.rootDir + environment 'CARGO_TERM_COLOR', 'always' } tasks.register('buildWithRust') { diff --git a/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java b/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java index 5ebaa03969..17589b826e 100644 --- a/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java +++ b/java/client/src/main/java/babushka/connectors/handlers/CallbackDispatcher.java @@ -1,29 +1,39 @@ package babushka.connectors.handlers; -import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import lombok.RequiredArgsConstructor; import org.apache.commons.lang3.tuple.Pair; import response.ResponseOuterClass.Response; /** Holder for resources required to dispatch responses and used by {@link ReadHandler}. */ +@RequiredArgsConstructor public class CallbackDispatcher { - /** Unique request ID (callback ID). Thread-safe. */ - private final AtomicInteger requestId = new AtomicInteger(0); + + /** Client connection status needed to distinguish connection request. */ + private final AtomicBoolean connectionStatus; + + /** Reserved callback ID for connection request. */ + private final Integer CONNECTION_PROMISE_ID = 0; /** * Storage of Futures to handle responses. Map key is callback id, which starts from 1.
- * Each future is a promise for every submitted by user request. + * Each future is a promise for every submitted by user request.
+ * Note: Protobuf packet contains callback ID as uint32, but it stores data as a bit field.
+ * Negative java values would be shown as positive on rust side. Meanwhile, no data loss happen, + * because callback ID remains unique. */ - private final Map> responses = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> responses = + new ConcurrentHashMap<>(); /** - * Storage for connection request similar to {@link #responses}. Unfortunately, connection - * requests can't be stored in the same storage, because callback ID = 0 is hardcoded for - * connection requests. + * Storage of freed callback IDs. It is needed to avoid occupying an ID being used and to speed up + * search for a next free ID.
*/ - private final CompletableFuture connectionPromise = new CompletableFuture<>(); + // TODO: Optimize to avoid growing up to 2e32 (16 Gb) https://github.com/aws/babushka/issues/704 + private final ConcurrentLinkedQueue freeRequestIds = new ConcurrentLinkedQueue<>(); /** * Register a new request to be sent. Once response received, the given future completes with it. @@ -32,14 +42,21 @@ public class CallbackDispatcher { * response. */ public Pair> registerRequest() { - int callbackId = requestId.incrementAndGet(); var future = new CompletableFuture(); - responses.put(callbackId, future); + Integer callbackId = connectionStatus.get() ? freeRequestIds.poll() : CONNECTION_PROMISE_ID; + synchronized (responses) { + if (callbackId == null) { + long size = responses.mappingCount(); + callbackId = (int) (size < Integer.MAX_VALUE ? size : -(size - Integer.MAX_VALUE)); + } + responses.put(callbackId, future); + } return Pair.of(callbackId, future); } public CompletableFuture registerConnection() { - return connectionPromise; + var res = registerRequest(); + return res.getValue(); } /** @@ -48,17 +65,22 @@ public CompletableFuture registerConnection() { * @param response A response received */ public void completeRequest(Response response) { - int callbackId = response.getCallbackIdx(); - if (callbackId == 0) { - connectionPromise.completeAsync(() -> response); + // A connection response doesn't contain a callback id + int callbackId = connectionStatus.get() ? response.getCallbackIdx() : CONNECTION_PROMISE_ID; + CompletableFuture future = responses.get(callbackId); + if (future != null) { + future.completeAsync(() -> response); } else { - responses.get(callbackId).completeAsync(() -> response); + // TODO: log an error. + // probably a response was received after shutdown or `registerRequest` call was missing + } + synchronized (responses) { responses.remove(callbackId); } + freeRequestIds.add(callbackId); } public void shutdownGracefully() { - connectionPromise.cancel(false); responses.values().forEach(future -> future.cancel(false)); responses.clear(); } diff --git a/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java b/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java index adfdcbbbcc..3a41a3e20c 100644 --- a/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java +++ b/java/client/src/main/java/babushka/connectors/handlers/ChannelHandler.java @@ -1,13 +1,13 @@ package babushka.connectors.handlers; import babushka.connectors.resources.Platform; +import babushka.connectors.resources.ThreadPoolAllocator; import connection_request.ConnectionRequestOuterClass.ConnectionRequest; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.unix.DomainSocketAddress; -import java.util.OptionalInt; +import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; import redis_request.RedisRequestOuterClass.RedisRequest; import response.ResponseOuterClass.Response; @@ -24,7 +24,7 @@ public ChannelHandler(CallbackDispatcher callbackDispatcher, String socketPath) channel = new Bootstrap() // TODO let user specify the thread pool or pool size as an option - .group(Platform.createNettyThreadPool("babushka-channel", OptionalInt.empty())) + .group(ThreadPoolAllocator.createNettyThreadPool("babushka-channel", Optional.empty())) .channel(Platform.getClientUdsNettyChannelType()) .handler(new ProtobufSocketChannelInitializer(callbackDispatcher)) .connect(new DomainSocketAddress(socketPath)) @@ -52,13 +52,9 @@ public CompletableFuture connect(ConnectionRequest request) { return callbackDispatcher.registerConnection(); } - private final AtomicBoolean closed = new AtomicBoolean(false); - /** Closes the UDS connection and frees corresponding resources. */ public void close() { - if (closed.compareAndSet(false, true)) { - channel.close(); - callbackDispatcher.shutdownGracefully(); - } + channel.close(); + callbackDispatcher.shutdownGracefully(); } } diff --git a/java/client/src/main/java/babushka/connectors/resources/Platform.java b/java/client/src/main/java/babushka/connectors/resources/Platform.java index b411f04f50..4967a9b9f0 100644 --- a/java/client/src/main/java/babushka/connectors/resources/Platform.java +++ b/java/client/src/main/java/babushka/connectors/resources/Platform.java @@ -1,18 +1,10 @@ package babushka.connectors.resources; -import io.netty.channel.EventLoopGroup; import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollDomainSocketChannel; -import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.kqueue.KQueue; import io.netty.channel.kqueue.KQueueDomainSocketChannel; -import io.netty.channel.kqueue.KQueueEventLoopGroup; import io.netty.channel.unix.DomainSocketChannel; -import io.netty.util.concurrent.DefaultThreadFactory; -import java.util.Map; -import java.util.OptionalInt; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Supplier; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; @@ -44,12 +36,6 @@ public static class Capabilities { private static final Capabilities capabilities = new Capabilities(isKQueueAvailable(), isEPollAvailable(), false, false); - /** - * Thread pools supplied to Netty to perform all async IO.
- * Map key is supposed to be pool name + thread count as a string concat product. - */ - private static final Map groups = new ConcurrentHashMap<>(); - /** Detect kqueue availability. */ private static boolean isKQueueAvailable() { try { @@ -70,42 +56,6 @@ private static boolean isEPollAvailable() { } } - /** - * Allocate Netty thread pool required to manage connection. A thread pool could be shared across - * multiple connections. - * - * @return A new thread pool. - */ - public static EventLoopGroup createNettyThreadPool(String prefix, OptionalInt threadLimit) { - int threadCount = threadLimit.orElse(Runtime.getRuntime().availableProcessors()); - if (capabilities.isKQueueAvailable()) { - var name = prefix + "-kqueue-elg"; - return getOrCreate( - name + threadCount, - () -> new KQueueEventLoopGroup(threadCount, new DefaultThreadFactory(name, true))); - } else if (capabilities.isEPollAvailable()) { - var name = prefix + "-epoll-elg"; - return getOrCreate( - name + threadCount, - () -> new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(name, true))); - } - // TODO support IO-Uring and NIO - - throw new RuntimeException("Current platform supports no known thread pool types"); - } - - /** - * Get a cached thread pool from {@link #groups} or create a new one by given lambda and cache. - */ - private static EventLoopGroup getOrCreate(String name, Supplier supplier) { - if (groups.containsKey(name)) { - return groups.get(name); - } - var group = supplier.get(); - groups.put(name, group); - return group; - } - /** * Get a channel class required by Netty to open a client UDS channel. * @@ -120,20 +70,4 @@ public static Class getClientUdsNettyChannelType( } throw new RuntimeException("Current platform supports no known socket types"); } - - /** - * A JVM shutdown hook to be registered. It is responsible for closing connection and freeing - * resources. It is recommended to use a class instead of lambda to ensure that it is called.
- * See {@link Runtime#addShutdownHook}. - */ - private static class ShutdownHook implements Runnable { - @Override - public void run() { - groups.values().forEach(EventLoopGroup::shutdownGracefully); - } - } - - static { - Runtime.getRuntime().addShutdownHook(new Thread(new ShutdownHook(), "Babushka-shutdown-hook")); - } } diff --git a/java/client/src/main/java/babushka/connectors/resources/ThreadPoolAllocator.java b/java/client/src/main/java/babushka/connectors/resources/ThreadPoolAllocator.java new file mode 100644 index 0000000000..daefdb93e5 --- /dev/null +++ b/java/client/src/main/java/babushka/connectors/resources/ThreadPoolAllocator.java @@ -0,0 +1,72 @@ +package babushka.connectors.resources; + +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.kqueue.KQueueEventLoopGroup; +import io.netty.util.concurrent.DefaultThreadFactory; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; + +/** A class responsible to allocating and deallocating shared thread pools. */ +public class ThreadPoolAllocator { + + /** + * Thread pools supplied to Netty to perform all async IO.
+ * Map key is supposed to be pool name + thread count as a string concat product. + */ + private static final Map groups = new ConcurrentHashMap<>(); + + /** + * Allocate (create new or share existing) Netty thread pool required to manage connection. A + * thread pool could be shared across multiple connections. + * + * @return A new thread pool. + */ + public static EventLoopGroup createNettyThreadPool(String prefix, Optional threadLimit) { + int threadCount = threadLimit.orElse(Runtime.getRuntime().availableProcessors()); + if (Platform.getCapabilities().isKQueueAvailable()) { + String name = prefix + "-kqueue-elg"; + return getOrCreate( + name + threadCount, + () -> new KQueueEventLoopGroup(threadCount, new DefaultThreadFactory(name, true))); + } else if (Platform.getCapabilities().isEPollAvailable()) { + String name = prefix + "-epoll-elg"; + return getOrCreate( + name + threadCount, + () -> new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(name, true))); + } + // TODO support IO-Uring and NIO + + throw new RuntimeException("Current platform supports no known thread pool types"); + } + + /** + * Get a cached thread pool from {@link #groups} or create a new one by given lambda and cache. + */ + private static EventLoopGroup getOrCreate(String name, Supplier supplier) { + if (groups.containsKey(name)) { + return groups.get(name); + } + EventLoopGroup group = supplier.get(); + groups.put(name, group); + return group; + } + + /** + * A JVM shutdown hook to be registered. It is responsible for closing connection and freeing + * resources. It is recommended to use a class instead of lambda to ensure that it is called.
+ * See {@link Runtime#addShutdownHook}. + */ + private static class ShutdownHook implements Runnable { + @Override + public void run() { + groups.values().forEach(EventLoopGroup::shutdownGracefully); + } + } + + static { + Runtime.getRuntime().addShutdownHook(new Thread(new ShutdownHook(), "Babushka-shutdown-hook")); + } +} diff --git a/java/client/src/main/java/babushka/ffi/resolvers/RedisValueResolver.java b/java/client/src/main/java/babushka/ffi/resolvers/RedisValueResolver.java new file mode 100644 index 0000000000..133ccef0ab --- /dev/null +++ b/java/client/src/main/java/babushka/ffi/resolvers/RedisValueResolver.java @@ -0,0 +1,13 @@ +package babushka.ffi.resolvers; + +import response.ResponseOuterClass.Response; + +public class RedisValueResolver { + /** + * Resolve a value received from Redis using given C-style pointer. + * + * @param pointer A memory pointer from {@link Response} + * @return A RESP3 value + */ + public static native Object valueFromPointer(long pointer); +} diff --git a/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java b/java/client/src/main/java/babushka/ffi/resolvers/SocketListenerResolver.java similarity index 63% rename from java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java rename to java/client/src/main/java/babushka/ffi/resolvers/SocketListenerResolver.java index 6d4ec45121..112582f9ba 100644 --- a/java/client/src/main/java/babushka/ffi/resolvers/BabushkaCoreNativeDefinitions.java +++ b/java/client/src/main/java/babushka/ffi/resolvers/SocketListenerResolver.java @@ -1,9 +1,9 @@ package babushka.ffi.resolvers; -public class BabushkaCoreNativeDefinitions { - public static native String startSocketListenerExternal() throws Exception; +public class SocketListenerResolver { - public static native Object valueFromPointer(long pointer); + /** Make an FFI call to Babushka to open a UDS socket to connect to. */ + private static native String startSocketListener() throws Exception; static { System.loadLibrary("javababushka"); @@ -16,7 +16,7 @@ public class BabushkaCoreNativeDefinitions { */ public static String getSocket() { try { - return startSocketListenerExternal(); + return startSocketListener(); } catch (Exception | UnsatisfiedLinkError e) { System.err.printf("Failed to create a UDS connection: %s%n%n", e); throw new RuntimeException(e); diff --git a/java/client/src/main/java/babushka/managers/CommandManager.java b/java/client/src/main/java/babushka/managers/CommandManager.java new file mode 100644 index 0000000000..79a7474a3f --- /dev/null +++ b/java/client/src/main/java/babushka/managers/CommandManager.java @@ -0,0 +1,81 @@ +package babushka.managers; + +import babushka.connectors.handlers.ChannelHandler; +import babushka.ffi.resolvers.RedisValueResolver; +import babushka.models.RequestBuilder; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import lombok.RequiredArgsConstructor; +import redis_request.RedisRequestOuterClass.RequestType; +import response.ResponseOuterClass.Response; + +@RequiredArgsConstructor +public class CommandManager { + + /** UDS connection representation. */ + private final ChannelHandler channel; + + /** + * Async (non-blocking) get.
+ * See REDIS docs for GET. + * + * @param key The key name + */ + public CompletableFuture get(String key) { + return submitNewRequest(RequestType.GetString, List.of(key)); + } + + /** + * Async (non-blocking) set.
+ * See REDIS docs for SET. + * + * @param key The key name + * @param value The value to set + */ + public CompletableFuture set(String key, String value) { + return submitNewRequest(RequestType.SetString, List.of(key, value)); + } + + /** + * Build a command and submit it Netty to send. + * + * @param command Command type + * @param args Command arguments + * @return A result promise + */ + private CompletableFuture submitNewRequest(RequestType command, List args) { + // TODO this explicitly uses ForkJoin thread pool. May be we should use another one. + return CompletableFuture.supplyAsync( + () -> channel.write(RequestBuilder.prepareRedisRequest(command, args), true)) + // TODO: is there a better way to execute this? + .thenComposeAsync(f -> f) + .thenApplyAsync(this::extractValueFromResponse); + } + + /** + * Check response and extract data from it. + * + * @param response A response received from Babushka + * @return A String from the Redis RESP2 response, or Ok. Otherwise, returns null + */ + private String extractValueFromResponse(Response response) { + if (response.hasRequestError()) { + // TODO do we need to support different types of exceptions and distinguish them by type? + throw new RuntimeException( + String.format( + "%s: %s", + response.getRequestError().getType(), response.getRequestError().getMessage())); + } else if (response.hasClosingError()) { + CompletableFuture.runAsync(channel::close); + throw new RuntimeException("Connection closed: " + response.getClosingError()); + } else if (response.hasConstantResponse()) { + return response.getConstantResponse().toString(); + } else if (response.hasRespPointer()) { + return RedisValueResolver.valueFromPointer(response.getRespPointer()).toString(); + } + // TODO commented out due to #710 https://github.com/aws/babushka/issues/710 + // empty response means a successful command + // throw new IllegalStateException("A malformed response received: " + response.toString()); + return "OK"; + } +} diff --git a/java/client/src/main/java/babushka/managers/ConnectionManager.java b/java/client/src/main/java/babushka/managers/ConnectionManager.java new file mode 100644 index 0000000000..4bb38c27ee --- /dev/null +++ b/java/client/src/main/java/babushka/managers/ConnectionManager.java @@ -0,0 +1,68 @@ +package babushka.managers; + +import babushka.connectors.handlers.ChannelHandler; +import babushka.ffi.resolvers.RedisValueResolver; +import babushka.models.RequestBuilder; +import connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import lombok.RequiredArgsConstructor; +import response.ResponseOuterClass.ConstantResponse; +import response.ResponseOuterClass.Response; + +@RequiredArgsConstructor +public class ConnectionManager { + + /** UDS connection representation. */ + private final ChannelHandler channel; + + /** Client connection status to update when connection established. */ + private final AtomicBoolean connectionStatus; + + /** + * Connect to Redis using a ProtoBuf connection request. + * + * @param host Server address + * @param port Server port + * @param useSsl true if communication with the server or cluster should use Transport Level + * Security + * @param clusterMode true if REDIS instance runs in the cluster mode + */ + // TODO support more parameters and/or configuration object + public CompletableFuture connectToRedis( + String host, int port, boolean useSsl, boolean clusterMode) { + ConnectionRequest request = + RequestBuilder.createConnectionRequest(host, port, useSsl, clusterMode); + return channel.connect(request).thenApplyAsync(this::checkBabushkaResponse); + } + + /** Check a response received from Babushka. */ + private boolean checkBabushkaResponse(Response response) { + // TODO do we need to check callback value? It could be -1 or 0 + if (response.hasRequestError()) { + // TODO do we need to support different types of exceptions and distinguish them by type? + throw new RuntimeException( + String.format( + "%s: %s", + response.getRequestError().getType(), response.getRequestError().getMessage())); + } else if (response.hasClosingError()) { + throw new RuntimeException("Connection closed: " + response.getClosingError()); + } else if (response.hasConstantResponse()) { + return connectionStatus.compareAndSet( + false, response.getConstantResponse() == ConstantResponse.OK); + } else if (response.hasRespPointer()) { + throw new RuntimeException( + "Unexpected response data: " + + RedisValueResolver.valueFromPointer(response.getRespPointer())); + } + // TODO commented out due to #710 https://github.com/aws/babushka/issues/710 + // empty response means a successful connection + // throw new IllegalStateException("A malformed response received: " + response.toString()); + return connectionStatus.compareAndSet(false, true); + } + + /** Close the connection and the corresponding channel. */ + public CompletableFuture closeConnection() { + return CompletableFuture.runAsync(channel::close); + } +} diff --git a/java/client/src/main/java/babushka/models/RequestBuilder.java b/java/client/src/main/java/babushka/models/RequestBuilder.java new file mode 100644 index 0000000000..2ec729e4eb --- /dev/null +++ b/java/client/src/main/java/babushka/models/RequestBuilder.java @@ -0,0 +1,60 @@ +package babushka.models; + +import babushka.connectors.handlers.CallbackDispatcher; +import babushka.managers.CommandManager; +import babushka.managers.ConnectionManager; +import connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import connection_request.ConnectionRequestOuterClass.NodeAddress; +import connection_request.ConnectionRequestOuterClass.ReadFrom; +import connection_request.ConnectionRequestOuterClass.TlsMode; +import java.util.List; +import redis_request.RedisRequestOuterClass.Command; +import redis_request.RedisRequestOuterClass.Command.ArgsArray; +import redis_request.RedisRequestOuterClass.RedisRequest; +import redis_request.RedisRequestOuterClass.RequestType; +import redis_request.RedisRequestOuterClass.Routes; +import redis_request.RedisRequestOuterClass.SimpleRoutes; + +public class RequestBuilder { + + /** + * Build a protobuf connection request.
+ * Used by {@link ConnectionManager#connectToRedis}. + */ + // TODO support more parameters and/or configuration object + public static ConnectionRequest createConnectionRequest( + String host, int port, boolean useSsl, boolean clusterMode) { + return ConnectionRequest.newBuilder() + .addAddresses(NodeAddress.newBuilder().setHost(host).setPort(port).build()) + .setTlsMode(useSsl ? TlsMode.SecureTls : TlsMode.NoTls) + .setClusterModeEnabled(clusterMode) + .setReadFrom(ReadFrom.Primary) + .setDatabaseId(0) + .build(); + } + + /** + * Build a protobuf command/transaction request draft.
+ * Used by {@link CommandManager}. + * + * @return An uncompleted request. {@link CallbackDispatcher} is responsible to complete it by + * adding a callback id. + */ + public static RedisRequest.Builder prepareRedisRequest(RequestType command, List args) { + var commandArgs = ArgsArray.newBuilder(); + for (var arg : args) { + commandArgs.addArgs(arg); + } + + return RedisRequest.newBuilder() + .setSingleCommand( // set command + Command.newBuilder() + .setRequestType(command) // set command name + .setArgsArray(commandArgs.build()) // set arguments + .build()) + .setRoute( // set route + Routes.newBuilder() + .setSimpleRoutes(SimpleRoutes.AllNodes) // set route type + .build()); + } +} diff --git a/java/src/lib.rs b/java/src/lib.rs index 13577f0805..8ff3b684fb 100644 --- a/java/src/lib.rs +++ b/java/src/lib.rs @@ -42,9 +42,7 @@ fn redis_value_to_java(mut env: JNIEnv, val: Value) -> JObject { } #[no_mangle] -pub extern "system" fn Java_babushka_ffi_resolvers_BabushkaCoreNativeDefinitions_valueFromPointer< - 'local, ->( +pub extern "system" fn Java_babushka_ffi_resolvers_RedisValueResolver_valueFromPointer<'local>( env: JNIEnv<'local>, _class: JClass<'local>, pointer: jlong, @@ -54,7 +52,7 @@ pub extern "system" fn Java_babushka_ffi_resolvers_BabushkaCoreNativeDefinitions } #[no_mangle] -pub extern "system" fn Java_babushka_ffi_resolvers_BabushkaCoreNativeDefinitions_startSocketListenerExternal< +pub extern "system" fn Java_babushka_ffi_resolvers_SocketListenerResolver_startSocketListener< 'local, >( env: JNIEnv<'local>,