diff --git a/java/client/build.gradle b/java/client/build.gradle
index be9120cf3f..7721e4b7b7 100644
--- a/java/client/build.gradle
+++ b/java/client/build.gradle
@@ -28,6 +28,8 @@ dependencies {
// junit
testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4'
+
+ testImplementation group: 'io.netty', name: 'netty-codec-redis', version: '4.1.100.Final'
}
tasks.register('protobuf', Exec) {
diff --git a/java/client/src/main/java/babushka/connectors/resources/Platform.java b/java/client/src/main/java/babushka/connectors/resources/Platform.java
index b411f04f50..353b884daa 100644
--- a/java/client/src/main/java/babushka/connectors/resources/Platform.java
+++ b/java/client/src/main/java/babushka/connectors/resources/Platform.java
@@ -4,10 +4,16 @@
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollDomainSocketChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerDomainSocketChannel;
+import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.kqueue.KQueueDomainSocketChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
+import io.netty.channel.kqueue.KQueueServerDomainSocketChannel;
+import io.netty.channel.kqueue.KQueueServerSocketChannel;
+import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.unix.DomainSocketChannel;
+import io.netty.channel.unix.ServerDomainSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.util.Map;
import java.util.OptionalInt;
@@ -121,6 +127,36 @@ public static Class extends DomainSocketChannel> getClientUdsNettyChannelType(
throw new RuntimeException("Current platform supports no known socket types");
}
+ /**
+ * Get a channel class required by Netty to open a server UDS channel.
+ *
+ * @return Return a class, supported by the current native platform.
+ */
+ public static Class extends ServerDomainSocketChannel> getServerUdsNettyChannelType() {
+ if (capabilities.isKQueueAvailable()) {
+ return KQueueServerDomainSocketChannel.class;
+ }
+ if (capabilities.isEPollAvailable()) {
+ return EpollServerDomainSocketChannel.class;
+ }
+ throw new RuntimeException("Current platform supports no known socket types");
+ }
+
+ /**
+ * Get a channel class required by Netty to open a server TCP channel.
+ *
+ * @return Return a class, supported by the current native platform.
+ */
+ public static Class extends ServerSocketChannel> getServerTcpNettyChannelType() {
+ if (capabilities.isKQueueAvailable()) {
+ return KQueueServerSocketChannel.class;
+ }
+ if (capabilities.isEPollAvailable()) {
+ return EpollServerSocketChannel.class;
+ }
+ throw new RuntimeException("Current platform supports no known socket types");
+ }
+
/**
* A JVM shutdown hook to be registered. It is responsible for closing connection and freeing
* resources. It is recommended to use a class instead of lambda to ensure that it is called.
diff --git a/java/client/src/test/java/babushka/utils/Awaiter.java b/java/client/src/test/java/babushka/utils/Awaiter.java
new file mode 100644
index 0000000000..2f54e37a39
--- /dev/null
+++ b/java/client/src/test/java/babushka/utils/Awaiter.java
@@ -0,0 +1,35 @@
+package babushka.utils;
+
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+public class Awaiter {
+ private static final long DEFAULT_TIMEOUT_MILLISECONDS = 30000;
+
+ /** Get the future result with default timeout. */
+ public static T await(Future future) {
+ return await(future, DEFAULT_TIMEOUT_MILLISECONDS);
+ }
+
+ /** Get the future result with given timeout in ms. */
+ public static T await(Future future, long timeout) {
+ try {
+ return future.get(timeout, TimeUnit.MILLISECONDS);
+ } catch (TimeoutException e) {
+ throw new RuntimeException("Request timed out", e);
+ } catch (ExecutionException e) {
+ throw new RuntimeException(e.getMessage(), e.getCause());
+ } catch (InterruptedException e) {
+ if (Thread.currentThread().isInterrupted()) {
+ // restore interrupt
+ Thread.interrupted();
+ }
+ throw new RuntimeException("The thread was interrupted", e);
+ } catch (CancellationException e) {
+ throw new RuntimeException("Request was cancelled", e);
+ }
+ }
+}
diff --git a/java/client/src/test/java/babushka/utils/RedisMockTestBase.java b/java/client/src/test/java/babushka/utils/RedisMockTestBase.java
new file mode 100644
index 0000000000..d194f02b63
--- /dev/null
+++ b/java/client/src/test/java/babushka/utils/RedisMockTestBase.java
@@ -0,0 +1,42 @@
+package babushka.utils;
+
+import lombok.SneakyThrows;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+
+public class RedisMockTestBase {
+
+ public static boolean started = false;
+
+ @SneakyThrows
+ public static void startRedisMock(RedisServerMock.ServerMock serverMock) {
+ assert !started
+ : "Previous `RedisMock` wasn't stopped, probably your test class does not inherit"
+ + " `RedisMockTestBase`.";
+ RedisServerMock.start(serverMock);
+ started = true;
+ }
+
+ @BeforeEach
+ public void preTestCheck() {
+ assert started
+ : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class"
+ + " inherited from `RedisMockTestBase`.";
+ }
+
+ @AfterEach
+ public void afterTestCheck() {
+ assert !RedisServerMock.failed() : "Error occurred in `RedisMock`";
+ }
+
+ @AfterAll
+ @SneakyThrows
+ public static void stopRedisMock() {
+ assert started
+ : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class"
+ + " inherited from `RedisMockTestBase`.";
+ RedisServerMock.stop();
+ started = false;
+ }
+}
diff --git a/java/client/src/test/java/babushka/utils/RedisServerMock.java b/java/client/src/test/java/babushka/utils/RedisServerMock.java
new file mode 100644
index 0000000000..a233659f9d
--- /dev/null
+++ b/java/client/src/test/java/babushka/utils/RedisServerMock.java
@@ -0,0 +1,219 @@
+package babushka.utils;
+
+import babushka.connectors.resources.Platform;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.CodecException;
+import io.netty.handler.codec.redis.ArrayRedisMessage;
+import io.netty.handler.codec.redis.ErrorRedisMessage;
+import io.netty.handler.codec.redis.FullBulkStringRedisMessage;
+import io.netty.handler.codec.redis.IntegerRedisMessage;
+import io.netty.handler.codec.redis.RedisArrayAggregator;
+import io.netty.handler.codec.redis.RedisBulkStringAggregator;
+import io.netty.handler.codec.redis.RedisDecoder;
+import io.netty.handler.codec.redis.RedisEncoder;
+import io.netty.handler.codec.redis.RedisMessage;
+import io.netty.handler.codec.redis.SimpleStringRedisMessage;
+import io.netty.util.CharsetUtil;
+import java.net.InetSocketAddress;
+import java.util.OptionalInt;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Collectors;
+import lombok.Setter;
+
+public class RedisServerMock {
+
+ public abstract static class ServerMock {
+ /** Return `null` to do not reply. */
+ public abstract RedisMessage reply(String cmd);
+
+ protected RedisMessage reply0(String cmd) {
+ return reply(cmd);
+ }
+
+ public static RedisMessage error(String text) {
+ return new ErrorRedisMessage(text);
+ }
+
+ public static RedisMessage error(String prefix, String text) {
+ // https://redis.io/docs/reference/protocol-spec/#simple-errors
+ if (prefix.contains(" ") || prefix.contains("\r") || prefix.contains("\n")) {
+ throw new IllegalArgumentException();
+ }
+ return new ErrorRedisMessage(prefix.toUpperCase() + " " + text);
+ }
+
+ public static RedisMessage simpleString(String text) {
+ return new SimpleStringRedisMessage(text);
+ }
+
+ public static RedisMessage OK() {
+ return simpleString("OK");
+ }
+
+ public static RedisMessage number(long value) {
+ return new IntegerRedisMessage(value);
+ }
+
+ /** A multi-line message. */
+ public static RedisMessage multiString(String text) {
+ return new FullBulkStringRedisMessage(Unpooled.copiedBuffer(text.getBytes()));
+ }
+ }
+
+ public abstract static class ServerMockConnectAll extends ServerMock {
+ @Override
+ protected RedisMessage reply0(String cmd) {
+ if (cmd.startsWith("CLIENT SETINFO")) {
+ return OK();
+ } else if (cmd.startsWith("INFO REPLICATION")) {
+ var response =
+ "# Replication\r\n"
+ + "role:master\r\n"
+ + "connected_slaves:0\r\n"
+ + "master_failover_state:no-failover\r\n"
+ + "master_replid:d7646c8d14901de9347f1f675c70bcf269a503eb\r\n"
+ + "master_replid2:0000000000000000000000000000000000000000\r\n"
+ + "master_repl_offset:0\r\n"
+ + "second_repl_offset:-1\r\n"
+ + "repl_backlog_active:0\r\n"
+ + "repl_backlog_size:1048576\r\n"
+ + "repl_backlog_first_byte_offset:0\r\n"
+ + "repl_backlog_histlen:0\r\n";
+ return multiString(response);
+ }
+ return reply(cmd);
+ }
+ }
+
+ // TODO support configurable port to test cluster mode
+ public static final int PORT = 6380;
+
+ /** Thread pool supplied to Netty to perform all async IO. */
+ private EventLoopGroup group;
+
+ private Channel channel;
+
+ private static RedisServerMock instance;
+
+ private ServerMock messageProcessor;
+
+ /** Update {@link ServerMock} into a running {@link RedisServerMock}. */
+ public static void updateServerMock(ServerMock newMock) {
+ instance.messageProcessor = newMock;
+ }
+
+ private final AtomicBoolean failed = new AtomicBoolean(false);
+
+ /** Get and clear failure status. */
+ public static boolean failed() {
+ return instance.failed.compareAndSet(true, false);
+ }
+
+ @Setter private static boolean debugLogging = false;
+
+ private RedisServerMock() {
+ try {
+ channel =
+ new ServerBootstrap()
+ .group(group = Platform.createNettyThreadPool("RedisMock", OptionalInt.empty()))
+ .channel(Platform.getServerTcpNettyChannelType())
+ .childHandler(
+ new ChannelInitializer() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ ch.pipeline()
+ // https://github.com/netty/netty/blob/4.1/example/src/main/java/io/netty/example/redis/RedisClient.java
+ .addLast(new RedisDecoder())
+ .addLast(new RedisBulkStringAggregator())
+ .addLast(new RedisArrayAggregator())
+ .addLast(new RedisEncoder())
+ .addLast(
+ new ChannelInboundHandlerAdapter() {
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg)
+ throws Exception {
+ RedisMessage redisMessage = (RedisMessage) msg;
+ var str = RedisMessageToString(redisMessage);
+ if (debugLogging) {
+ System.out.printf("-- Received%n %s%n", str);
+ }
+ var response = messageProcessor.reply0(str);
+ if (response != null) {
+ if (debugLogging) {
+ System.out.printf(
+ "-- Replying with%n %s%n",
+ RedisMessageToString(response));
+ }
+ ctx.writeAndFlush(response);
+ } else if (debugLogging) {
+ System.out.printf("-- Ignoring%n");
+ }
+ }
+
+ @Override
+ public void exceptionCaught(
+ ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ cause.printStackTrace();
+ ctx.close();
+ failed.setPlain(true);
+ }
+ });
+ }
+ })
+ .bind(new InetSocketAddress(PORT))
+ // .sync()
+ .channel();
+ } catch (Exception e) {
+ System.err.printf(
+ "Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage());
+ e.printStackTrace(System.err);
+ }
+ }
+
+ public static void start(ServerMock messageProcessor) {
+ if (instance != null) {
+ stop();
+ }
+ instance = new RedisServerMock();
+ instance.messageProcessor = messageProcessor;
+ }
+
+ public static void stop() {
+ instance.channel.close();
+ instance.group.shutdownGracefully();
+ instance = null;
+ }
+
+ private static String RedisMessageToString(RedisMessage msg) {
+ if (msg instanceof SimpleStringRedisMessage) {
+ return ((SimpleStringRedisMessage) msg).content();
+ } else if (msg instanceof ErrorRedisMessage) {
+ return ((ErrorRedisMessage) msg).content();
+ } else if (msg instanceof IntegerRedisMessage) {
+ return String.valueOf(((IntegerRedisMessage) msg).value());
+ } else if (msg instanceof FullBulkStringRedisMessage) {
+ return getString((FullBulkStringRedisMessage) msg);
+ } else if (msg instanceof ArrayRedisMessage) {
+ return ((ArrayRedisMessage) msg)
+ .children().stream()
+ .map(RedisServerMock::RedisMessageToString)
+ .collect(Collectors.joining(" "));
+ } else {
+ throw new CodecException("unknown message type: " + msg);
+ }
+ }
+
+ private static String getString(FullBulkStringRedisMessage msg) {
+ if (msg.isNull()) {
+ return "(null)";
+ }
+ return msg.content().toString(CharsetUtil.UTF_8);
+ }
+}
diff --git a/java/client/src/test/java/babushka/utils/RustCoreLibMockTestBase.java b/java/client/src/test/java/babushka/utils/RustCoreLibMockTestBase.java
new file mode 100644
index 0000000000..24b9daca47
--- /dev/null
+++ b/java/client/src/test/java/babushka/utils/RustCoreLibMockTestBase.java
@@ -0,0 +1,48 @@
+package babushka.utils;
+
+import babushka.connectors.handlers.ChannelHandler;
+import babushka.ffi.resolvers.BabushkaCoreNativeDefinitions;
+import lombok.SneakyThrows;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+
+public class RustCoreLibMockTestBase {
+
+ /**
+ * Pass this socket path to {@link ChannelHandler} or mock {@link
+ * BabushkaCoreNativeDefinitions#getSocket()} to return it.
+ */
+ protected static String socketPath = null;
+
+ @SneakyThrows
+ public static void startRustCoreLibMock(RustCoreMock.BabushkaMock rustCoreLibMock) {
+ assert socketPath == null
+ : "Previous `RustCoreMock` wasn't stopped, probably your test class does not inherit"
+ + " `RustCoreLibMockTestBase`.";
+
+ socketPath = RustCoreMock.start(rustCoreLibMock);
+ }
+
+ @BeforeEach
+ public void preTestCheck() {
+ assert socketPath != null
+ : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class"
+ + " inherited from `RustCoreLibMockTestBase`.";
+ }
+
+ @AfterEach
+ public void afterTestCheck() {
+ assert !RustCoreMock.failed() : "Error occurred in `RustCoreMock`";
+ }
+
+ @AfterAll
+ @SneakyThrows
+ public static void stopRustCoreLibMock() {
+ assert socketPath != null
+ : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class"
+ + " inherited from `RustCoreLibMockTestBase`.";
+ RustCoreMock.stop();
+ socketPath = null;
+ }
+}
diff --git a/java/client/src/test/java/babushka/utils/RustCoreMock.java b/java/client/src/test/java/babushka/utils/RustCoreMock.java
new file mode 100644
index 0000000000..136cf726ca
--- /dev/null
+++ b/java/client/src/test/java/babushka/utils/RustCoreMock.java
@@ -0,0 +1,154 @@
+package babushka.utils;
+
+import babushka.connectors.resources.Platform;
+import connection_request.ConnectionRequestOuterClass.ConnectionRequest;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.unix.DomainSocketAddress;
+import io.netty.channel.unix.DomainSocketChannel;
+import io.netty.handler.codec.protobuf.ProtobufEncoder;
+import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder;
+import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender;
+import java.nio.file.Files;
+import java.util.OptionalInt;
+import java.util.concurrent.atomic.AtomicBoolean;
+import redis_request.RedisRequestOuterClass.RedisRequest;
+import response.ResponseOuterClass.ConstantResponse;
+import response.ResponseOuterClass.Response;
+
+public class RustCoreMock {
+
+ public abstract static class BabushkaMock {
+ /** Return `null` to do not reply. */
+ public abstract Response connection(ConnectionRequest request);
+
+ /** Return `null` to do not reply. */
+ public abstract Response.Builder redisRequest(RedisRequest request);
+
+ public Response redisRequestWithCallbackId(RedisRequest request) {
+ var responseDraft = redisRequest(request);
+ return responseDraft == null
+ ? null
+ : responseDraft.setCallbackIdx(request.getCallbackIdx()).build();
+ }
+
+ public static Response.Builder OK() {
+ return Response.newBuilder().setConstantResponse(ConstantResponse.OK);
+ }
+ }
+
+ public abstract static class BabushkaMockConnectAll extends BabushkaMock {
+ @Override
+ public Response connection(ConnectionRequest request) {
+ return OK().build();
+ }
+ }
+
+ /** Thread pool supplied to Netty to perform all async IO. */
+ private EventLoopGroup group;
+
+ private Channel channel;
+
+ private String socketPath;
+
+ private static RustCoreMock instance;
+
+ private BabushkaMock messageProcessor;
+
+ /** Update {@link BabushkaMock} into a running {@link RustCoreMock}. */
+ public static void updateBabushkaMock(BabushkaMock newMock) {
+ instance.messageProcessor = newMock;
+ }
+
+ private final AtomicBoolean failed = new AtomicBoolean(false);
+
+ /** Get and clear failure status. */
+ public static boolean failed() {
+ return instance.failed.compareAndSet(true, false);
+ }
+
+ private RustCoreMock() {
+ try {
+ socketPath = Files.createTempFile("RustCoreMock", null).toString();
+ channel =
+ new ServerBootstrap()
+ .group(group = Platform.createNettyThreadPool("RustCoreMock", OptionalInt.empty()))
+ .channel(Platform.getServerUdsNettyChannelType())
+ .childHandler(
+ new ChannelInitializer() {
+
+ @Override
+ protected void initChannel(DomainSocketChannel ch) throws Exception {
+ ch.pipeline()
+ // https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html
+ .addLast("frameDecoder", new ProtobufVarint32FrameDecoder())
+ .addLast("frameEncoder", new ProtobufVarint32LengthFieldPrepender())
+ .addLast("protobufEncoder", new ProtobufEncoder())
+ .addLast(
+ new ChannelInboundHandlerAdapter() {
+
+ // This works with only one connected client.
+ // TODO Rework with `channelActive` override.
+ private AtomicBoolean anybodyConnected = new AtomicBoolean(false);
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg)
+ throws Exception {
+ var buf = (ByteBuf) msg;
+ var bytes = new byte[buf.readableBytes()];
+ buf.readBytes(bytes);
+ buf.release();
+ Response response = null;
+ if (!anybodyConnected.get()) {
+ var connection = ConnectionRequest.parseFrom(bytes);
+ response = messageProcessor.connection(connection);
+ anybodyConnected.setPlain(true);
+ } else {
+ var request = RedisRequest.parseFrom(bytes);
+ response = messageProcessor.redisRequestWithCallbackId(request);
+ }
+ if (response != null) {
+ ctx.writeAndFlush(response);
+ }
+ }
+
+ @Override
+ public void exceptionCaught(
+ ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ cause.printStackTrace();
+ ctx.close();
+ failed.setPlain(true);
+ }
+ });
+ }
+ })
+ .bind(new DomainSocketAddress(socketPath))
+ // .sync()
+ .channel();
+ } catch (Exception e) {
+ System.err.printf(
+ "Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage());
+ e.printStackTrace(System.err);
+ }
+ }
+
+ public static String start(BabushkaMock messageProcessor) {
+ if (instance != null) {
+ stop();
+ }
+ instance = new RustCoreMock();
+ instance.messageProcessor = messageProcessor;
+ return instance.socketPath;
+ }
+
+ public static void stop() {
+ instance.channel.close();
+ instance.group.shutdownGracefully();
+ instance = null;
+ }
+}