Skip to content

Commit

Permalink
Add managers layer to client (#44)
Browse files Browse the repository at this point in the history
* Managers

Signed-off-by: Yury-Fridlyand <[email protected]>

* Refactor

Signed-off-by: Yury-Fridlyand <[email protected]>

* Split `ClientState` and `Platform`.

Signed-off-by: Yury-Fridlyand <[email protected]>

* Refactor

Signed-off-by: Yury-Fridlyand <[email protected]>

* Address PR feedback.

Signed-off-by: Yury-Fridlyand <[email protected]>

* Minor fix.

Signed-off-by: Yury-Fridlyand <[email protected]>

* Typo fix.

Signed-off-by: Yury-Fridlyand <[email protected]>

* Refactor

Signed-off-by: Yury-Fridlyand <[email protected]>

* javadocs

Signed-off-by: Yury-Fridlyand <[email protected]>

---------

Signed-off-by: Yury-Fridlyand <[email protected]>
  • Loading branch information
Yury-Fridlyand authored Dec 21, 2023
1 parent 12dcaba commit 68b42a5
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 101 deletions.
3 changes: 3 additions & 0 deletions java/client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,19 @@ tasks.register('cleanProtobuf') {
tasks.register('buildRustRelease', Exec) {
commandLine 'cargo', 'build', '--release'
workingDir project.rootDir
environment 'CARGO_TERM_COLOR', 'always'
}

tasks.register('buildRustReleaseStrip', Exec) {
commandLine 'cargo', 'build', '--release', '--strip'
workingDir project.rootDir
environment 'CARGO_TERM_COLOR', 'always'
}

tasks.register('buildRust', Exec) {
commandLine 'cargo', 'build'
workingDir project.rootDir
environment 'CARGO_TERM_COLOR', 'always'
}

tasks.register('buildWithRust') {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,39 @@
package babushka.connectors.handlers;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.tuple.Pair;
import response.ResponseOuterClass.Response;

/** Holder for resources required to dispatch responses and used by {@link ReadHandler}. */
@RequiredArgsConstructor
public class CallbackDispatcher {
/** Unique request ID (callback ID). Thread-safe. */
private final AtomicInteger requestId = new AtomicInteger(0);

/** Client connection status needed to distinguish connection request. */
private final AtomicBoolean connectionStatus;

/** Reserved callback ID for connection request. */
private final Integer CONNECTION_PROMISE_ID = 0;

/**
* Storage of Futures to handle responses. Map key is callback id, which starts from 1.<br>
* Each future is a promise for every submitted by user request.
* Each future is a promise for every submitted by user request.<br>
* Note: Protobuf packet contains callback ID as uint32, but it stores data as a bit field.<br>
* Negative java values would be shown as positive on rust side. Meanwhile, no data loss happen,
* because callback ID remains unique.
*/
private final Map<Integer, CompletableFuture<Response>> responses = new ConcurrentHashMap<>();
private final ConcurrentHashMap<Integer, CompletableFuture<Response>> responses =
new ConcurrentHashMap<>();

/**
* Storage for connection request similar to {@link #responses}. Unfortunately, connection
* requests can't be stored in the same storage, because callback ID = 0 is hardcoded for
* connection requests.
* Storage of freed callback IDs. It is needed to avoid occupying an ID being used and to speed up
* search for a next free ID.<br>
*/
private final CompletableFuture<Response> connectionPromise = new CompletableFuture<>();
// TODO: Optimize to avoid growing up to 2e32 (16 Gb) https://github.com/aws/babushka/issues/704
private final ConcurrentLinkedQueue<Integer> freeRequestIds = new ConcurrentLinkedQueue<>();

/**
* Register a new request to be sent. Once response received, the given future completes with it.
Expand All @@ -32,14 +42,21 @@ public class CallbackDispatcher {
* response.
*/
public Pair<Integer, CompletableFuture<Response>> registerRequest() {
int callbackId = requestId.incrementAndGet();
var future = new CompletableFuture<Response>();
responses.put(callbackId, future);
Integer callbackId = connectionStatus.get() ? freeRequestIds.poll() : CONNECTION_PROMISE_ID;
synchronized (responses) {
if (callbackId == null) {
long size = responses.mappingCount();
callbackId = (int) (size < Integer.MAX_VALUE ? size : -(size - Integer.MAX_VALUE));
}
responses.put(callbackId, future);
}
return Pair.of(callbackId, future);
}

public CompletableFuture<Response> registerConnection() {
return connectionPromise;
var res = registerRequest();
return res.getValue();
}

/**
Expand All @@ -48,17 +65,22 @@ public CompletableFuture<Response> registerConnection() {
* @param response A response received
*/
public void completeRequest(Response response) {
int callbackId = response.getCallbackIdx();
if (callbackId == 0) {
connectionPromise.completeAsync(() -> response);
// A connection response doesn't contain a callback id
int callbackId = connectionStatus.get() ? response.getCallbackIdx() : CONNECTION_PROMISE_ID;
CompletableFuture<Response> future = responses.get(callbackId);
if (future != null) {
future.completeAsync(() -> response);
} else {
responses.get(callbackId).completeAsync(() -> response);
// TODO: log an error.
// probably a response was received after shutdown or `registerRequest` call was missing
}
synchronized (responses) {
responses.remove(callbackId);
}
freeRequestIds.add(callbackId);
}

public void shutdownGracefully() {
connectionPromise.cancel(false);
responses.values().forEach(future -> future.cancel(false));
responses.clear();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package babushka.connectors.handlers;

import babushka.connectors.resources.Platform;
import babushka.connectors.resources.ThreadPoolAllocator;
import connection_request.ConnectionRequestOuterClass.ConnectionRequest;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.unix.DomainSocketAddress;
import java.util.OptionalInt;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import redis_request.RedisRequestOuterClass.RedisRequest;
import response.ResponseOuterClass.Response;

Expand All @@ -24,7 +24,7 @@ public ChannelHandler(CallbackDispatcher callbackDispatcher, String socketPath)
channel =
new Bootstrap()
// TODO let user specify the thread pool or pool size as an option
.group(Platform.createNettyThreadPool("babushka-channel", OptionalInt.empty()))
.group(ThreadPoolAllocator.createNettyThreadPool("babushka-channel", Optional.empty()))
.channel(Platform.getClientUdsNettyChannelType())
.handler(new ProtobufSocketChannelInitializer(callbackDispatcher))
.connect(new DomainSocketAddress(socketPath))
Expand Down Expand Up @@ -52,13 +52,9 @@ public CompletableFuture<Response> connect(ConnectionRequest request) {
return callbackDispatcher.registerConnection();
}

private final AtomicBoolean closed = new AtomicBoolean(false);

/** Closes the UDS connection and frees corresponding resources. */
public void close() {
if (closed.compareAndSet(false, true)) {
channel.close();
callbackDispatcher.shutdownGracefully();
}
channel.close();
callbackDispatcher.shutdownGracefully();
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
package babushka.connectors.resources;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollDomainSocketChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.kqueue.KQueueDomainSocketChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.unix.DomainSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.util.Map;
import java.util.OptionalInt;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
Expand Down Expand Up @@ -44,12 +36,6 @@ public static class Capabilities {
private static final Capabilities capabilities =
new Capabilities(isKQueueAvailable(), isEPollAvailable(), false, false);

/**
* Thread pools supplied to <em>Netty</em> to perform all async IO.<br>
* Map key is supposed to be pool name + thread count as a string concat product.
*/
private static final Map<String, EventLoopGroup> groups = new ConcurrentHashMap<>();

/** Detect <em>kqueue</em> availability. */
private static boolean isKQueueAvailable() {
try {
Expand All @@ -70,42 +56,6 @@ private static boolean isEPollAvailable() {
}
}

/**
* Allocate Netty thread pool required to manage connection. A thread pool could be shared across
* multiple connections.
*
* @return A new thread pool.
*/
public static EventLoopGroup createNettyThreadPool(String prefix, OptionalInt threadLimit) {
int threadCount = threadLimit.orElse(Runtime.getRuntime().availableProcessors());
if (capabilities.isKQueueAvailable()) {
var name = prefix + "-kqueue-elg";
return getOrCreate(
name + threadCount,
() -> new KQueueEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
} else if (capabilities.isEPollAvailable()) {
var name = prefix + "-epoll-elg";
return getOrCreate(
name + threadCount,
() -> new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
}
// TODO support IO-Uring and NIO

throw new RuntimeException("Current platform supports no known thread pool types");
}

/**
* Get a cached thread pool from {@link #groups} or create a new one by given lambda and cache.
*/
private static EventLoopGroup getOrCreate(String name, Supplier<EventLoopGroup> supplier) {
if (groups.containsKey(name)) {
return groups.get(name);
}
var group = supplier.get();
groups.put(name, group);
return group;
}

/**
* Get a channel class required by Netty to open a client UDS channel.
*
Expand All @@ -120,20 +70,4 @@ public static Class<? extends DomainSocketChannel> getClientUdsNettyChannelType(
}
throw new RuntimeException("Current platform supports no known socket types");
}

/**
* A JVM shutdown hook to be registered. It is responsible for closing connection and freeing
* resources. It is recommended to use a class instead of lambda to ensure that it is called.<br>
* See {@link Runtime#addShutdownHook}.
*/
private static class ShutdownHook implements Runnable {
@Override
public void run() {
groups.values().forEach(EventLoopGroup::shutdownGracefully);
}
}

static {
Runtime.getRuntime().addShutdownHook(new Thread(new ShutdownHook(), "Babushka-shutdown-hook"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package babushka.connectors.resources;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

/** A class responsible to allocating and deallocating shared thread pools. */
public class ThreadPoolAllocator {

/**
* Thread pools supplied to <em>Netty</em> to perform all async IO.<br>
* Map key is supposed to be pool name + thread count as a string concat product.
*/
private static final Map<String, EventLoopGroup> groups = new ConcurrentHashMap<>();

/**
* Allocate (create new or share existing) Netty thread pool required to manage connection. A
* thread pool could be shared across multiple connections.
*
* @return A new thread pool.
*/
public static EventLoopGroup createNettyThreadPool(String prefix, Optional<Integer> threadLimit) {
int threadCount = threadLimit.orElse(Runtime.getRuntime().availableProcessors());
if (Platform.getCapabilities().isKQueueAvailable()) {
String name = prefix + "-kqueue-elg";
return getOrCreate(
name + threadCount,
() -> new KQueueEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
} else if (Platform.getCapabilities().isEPollAvailable()) {
String name = prefix + "-epoll-elg";
return getOrCreate(
name + threadCount,
() -> new EpollEventLoopGroup(threadCount, new DefaultThreadFactory(name, true)));
}
// TODO support IO-Uring and NIO

throw new RuntimeException("Current platform supports no known thread pool types");
}

/**
* Get a cached thread pool from {@link #groups} or create a new one by given lambda and cache.
*/
private static EventLoopGroup getOrCreate(String name, Supplier<EventLoopGroup> supplier) {
if (groups.containsKey(name)) {
return groups.get(name);
}
EventLoopGroup group = supplier.get();
groups.put(name, group);
return group;
}

/**
* A JVM shutdown hook to be registered. It is responsible for closing connection and freeing
* resources. It is recommended to use a class instead of lambda to ensure that it is called.<br>
* See {@link Runtime#addShutdownHook}.
*/
private static class ShutdownHook implements Runnable {
@Override
public void run() {
groups.values().forEach(EventLoopGroup::shutdownGracefully);
}
}

static {
Runtime.getRuntime().addShutdownHook(new Thread(new ShutdownHook(), "Babushka-shutdown-hook"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package babushka.ffi.resolvers;

import response.ResponseOuterClass.Response;

public class RedisValueResolver {
/**
* Resolve a value received from Redis using given C-style pointer.
*
* @param pointer A memory pointer from {@link Response}
* @return A RESP3 value
*/
public static native Object valueFromPointer(long pointer);
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package babushka.ffi.resolvers;

public class BabushkaCoreNativeDefinitions {
public static native String startSocketListenerExternal() throws Exception;
public class SocketListenerResolver {

public static native Object valueFromPointer(long pointer);
/** Make an FFI call to Babushka to open a UDS socket to connect to. */
private static native String startSocketListener() throws Exception;

static {
System.loadLibrary("javababushka");
Expand All @@ -16,7 +16,7 @@ public class BabushkaCoreNativeDefinitions {
*/
public static String getSocket() {
try {
return startSocketListenerExternal();
return startSocketListener();
} catch (Exception | UnsatisfiedLinkError e) {
System.err.printf("Failed to create a UDS connection: %s%n%n", e);
throw new RuntimeException(e);
Expand Down
Loading

0 comments on commit 68b42a5

Please sign in to comment.