diff --git a/java/client/build.gradle b/java/client/build.gradle index be9120cf3f..7721e4b7b7 100644 --- a/java/client/build.gradle +++ b/java/client/build.gradle @@ -28,6 +28,8 @@ dependencies { // junit testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4' + + testImplementation group: 'io.netty', name: 'netty-codec-redis', version: '4.1.100.Final' } tasks.register('protobuf', Exec) { diff --git a/java/client/src/main/java/babushka/connectors/resources/Platform.java b/java/client/src/main/java/babushka/connectors/resources/Platform.java index b411f04f50..353b884daa 100644 --- a/java/client/src/main/java/babushka/connectors/resources/Platform.java +++ b/java/client/src/main/java/babushka/connectors/resources/Platform.java @@ -4,10 +4,16 @@ import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollDomainSocketChannel; import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerDomainSocketChannel; +import io.netty.channel.epoll.EpollServerSocketChannel; import io.netty.channel.kqueue.KQueue; import io.netty.channel.kqueue.KQueueDomainSocketChannel; import io.netty.channel.kqueue.KQueueEventLoopGroup; +import io.netty.channel.kqueue.KQueueServerDomainSocketChannel; +import io.netty.channel.kqueue.KQueueServerSocketChannel; +import io.netty.channel.socket.ServerSocketChannel; import io.netty.channel.unix.DomainSocketChannel; +import io.netty.channel.unix.ServerDomainSocketChannel; import io.netty.util.concurrent.DefaultThreadFactory; import java.util.Map; import java.util.OptionalInt; @@ -121,6 +127,36 @@ public static Class getClientUdsNettyChannelType( throw new RuntimeException("Current platform supports no known socket types"); } + /** + * Get a channel class required by Netty to open a server UDS channel. + * + * @return Return a class, supported by the current native platform. + */ + public static Class getServerUdsNettyChannelType() { + if (capabilities.isKQueueAvailable()) { + return KQueueServerDomainSocketChannel.class; + } + if (capabilities.isEPollAvailable()) { + return EpollServerDomainSocketChannel.class; + } + throw new RuntimeException("Current platform supports no known socket types"); + } + + /** + * Get a channel class required by Netty to open a server TCP channel. + * + * @return Return a class, supported by the current native platform. + */ + public static Class getServerTcpNettyChannelType() { + if (capabilities.isKQueueAvailable()) { + return KQueueServerSocketChannel.class; + } + if (capabilities.isEPollAvailable()) { + return EpollServerSocketChannel.class; + } + throw new RuntimeException("Current platform supports no known socket types"); + } + /** * A JVM shutdown hook to be registered. It is responsible for closing connection and freeing * resources. It is recommended to use a class instead of lambda to ensure that it is called.
diff --git a/java/client/src/test/java/babushka/utils/Awaiter.java b/java/client/src/test/java/babushka/utils/Awaiter.java new file mode 100644 index 0000000000..2f54e37a39 --- /dev/null +++ b/java/client/src/test/java/babushka/utils/Awaiter.java @@ -0,0 +1,35 @@ +package babushka.utils; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class Awaiter { + private static final long DEFAULT_TIMEOUT_MILLISECONDS = 30000; + + /** Get the future result with default timeout. */ + public static T await(Future future) { + return await(future, DEFAULT_TIMEOUT_MILLISECONDS); + } + + /** Get the future result with given timeout in ms. */ + public static T await(Future future, long timeout) { + try { + return future.get(timeout, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + throw new RuntimeException("Request timed out", e); + } catch (ExecutionException e) { + throw new RuntimeException(e.getMessage(), e.getCause()); + } catch (InterruptedException e) { + if (Thread.currentThread().isInterrupted()) { + // restore interrupt + Thread.interrupted(); + } + throw new RuntimeException("The thread was interrupted", e); + } catch (CancellationException e) { + throw new RuntimeException("Request was cancelled", e); + } + } +} diff --git a/java/client/src/test/java/babushka/utils/RedisMockTestBase.java b/java/client/src/test/java/babushka/utils/RedisMockTestBase.java new file mode 100644 index 0000000000..d194f02b63 --- /dev/null +++ b/java/client/src/test/java/babushka/utils/RedisMockTestBase.java @@ -0,0 +1,42 @@ +package babushka.utils; + +import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +public class RedisMockTestBase { + + public static boolean started = false; + + @SneakyThrows + public static void startRedisMock(RedisServerMock.ServerMock serverMock) { + assert !started + : "Previous `RedisMock` wasn't stopped, probably your test class does not inherit" + + " `RedisMockTestBase`."; + RedisServerMock.start(serverMock); + started = true; + } + + @BeforeEach + public void preTestCheck() { + assert started + : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class" + + " inherited from `RedisMockTestBase`."; + } + + @AfterEach + public void afterTestCheck() { + assert !RedisServerMock.failed() : "Error occurred in `RedisMock`"; + } + + @AfterAll + @SneakyThrows + public static void stopRedisMock() { + assert started + : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class" + + " inherited from `RedisMockTestBase`."; + RedisServerMock.stop(); + started = false; + } +} diff --git a/java/client/src/test/java/babushka/utils/RedisServerMock.java b/java/client/src/test/java/babushka/utils/RedisServerMock.java new file mode 100644 index 0000000000..a233659f9d --- /dev/null +++ b/java/client/src/test/java/babushka/utils/RedisServerMock.java @@ -0,0 +1,219 @@ +package babushka.utils; + +import babushka.connectors.resources.Platform; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.redis.ArrayRedisMessage; +import io.netty.handler.codec.redis.ErrorRedisMessage; +import io.netty.handler.codec.redis.FullBulkStringRedisMessage; +import io.netty.handler.codec.redis.IntegerRedisMessage; +import io.netty.handler.codec.redis.RedisArrayAggregator; +import io.netty.handler.codec.redis.RedisBulkStringAggregator; +import io.netty.handler.codec.redis.RedisDecoder; +import io.netty.handler.codec.redis.RedisEncoder; +import io.netty.handler.codec.redis.RedisMessage; +import io.netty.handler.codec.redis.SimpleStringRedisMessage; +import io.netty.util.CharsetUtil; +import java.net.InetSocketAddress; +import java.util.OptionalInt; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import lombok.Setter; + +public class RedisServerMock { + + public abstract static class ServerMock { + /** Return `null` to do not reply. */ + public abstract RedisMessage reply(String cmd); + + protected RedisMessage reply0(String cmd) { + return reply(cmd); + } + + public static RedisMessage error(String text) { + return new ErrorRedisMessage(text); + } + + public static RedisMessage error(String prefix, String text) { + // https://redis.io/docs/reference/protocol-spec/#simple-errors + if (prefix.contains(" ") || prefix.contains("\r") || prefix.contains("\n")) { + throw new IllegalArgumentException(); + } + return new ErrorRedisMessage(prefix.toUpperCase() + " " + text); + } + + public static RedisMessage simpleString(String text) { + return new SimpleStringRedisMessage(text); + } + + public static RedisMessage OK() { + return simpleString("OK"); + } + + public static RedisMessage number(long value) { + return new IntegerRedisMessage(value); + } + + /** A multi-line message. */ + public static RedisMessage multiString(String text) { + return new FullBulkStringRedisMessage(Unpooled.copiedBuffer(text.getBytes())); + } + } + + public abstract static class ServerMockConnectAll extends ServerMock { + @Override + protected RedisMessage reply0(String cmd) { + if (cmd.startsWith("CLIENT SETINFO")) { + return OK(); + } else if (cmd.startsWith("INFO REPLICATION")) { + var response = + "# Replication\r\n" + + "role:master\r\n" + + "connected_slaves:0\r\n" + + "master_failover_state:no-failover\r\n" + + "master_replid:d7646c8d14901de9347f1f675c70bcf269a503eb\r\n" + + "master_replid2:0000000000000000000000000000000000000000\r\n" + + "master_repl_offset:0\r\n" + + "second_repl_offset:-1\r\n" + + "repl_backlog_active:0\r\n" + + "repl_backlog_size:1048576\r\n" + + "repl_backlog_first_byte_offset:0\r\n" + + "repl_backlog_histlen:0\r\n"; + return multiString(response); + } + return reply(cmd); + } + } + + // TODO support configurable port to test cluster mode + public static final int PORT = 6380; + + /** Thread pool supplied to Netty to perform all async IO. */ + private EventLoopGroup group; + + private Channel channel; + + private static RedisServerMock instance; + + private ServerMock messageProcessor; + + /** Update {@link ServerMock} into a running {@link RedisServerMock}. */ + public static void updateServerMock(ServerMock newMock) { + instance.messageProcessor = newMock; + } + + private final AtomicBoolean failed = new AtomicBoolean(false); + + /** Get and clear failure status. */ + public static boolean failed() { + return instance.failed.compareAndSet(true, false); + } + + @Setter private static boolean debugLogging = false; + + private RedisServerMock() { + try { + channel = + new ServerBootstrap() + .group(group = Platform.createNettyThreadPool("RedisMock", OptionalInt.empty())) + .channel(Platform.getServerTcpNettyChannelType()) + .childHandler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline() + // https://github.com/netty/netty/blob/4.1/example/src/main/java/io/netty/example/redis/RedisClient.java + .addLast(new RedisDecoder()) + .addLast(new RedisBulkStringAggregator()) + .addLast(new RedisArrayAggregator()) + .addLast(new RedisEncoder()) + .addLast( + new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + throws Exception { + RedisMessage redisMessage = (RedisMessage) msg; + var str = RedisMessageToString(redisMessage); + if (debugLogging) { + System.out.printf("-- Received%n %s%n", str); + } + var response = messageProcessor.reply0(str); + if (response != null) { + if (debugLogging) { + System.out.printf( + "-- Replying with%n %s%n", + RedisMessageToString(response)); + } + ctx.writeAndFlush(response); + } else if (debugLogging) { + System.out.printf("-- Ignoring%n"); + } + } + + @Override + public void exceptionCaught( + ChannelHandlerContext ctx, Throwable cause) throws Exception { + cause.printStackTrace(); + ctx.close(); + failed.setPlain(true); + } + }); + } + }) + .bind(new InetSocketAddress(PORT)) + // .sync() + .channel(); + } catch (Exception e) { + System.err.printf( + "Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage()); + e.printStackTrace(System.err); + } + } + + public static void start(ServerMock messageProcessor) { + if (instance != null) { + stop(); + } + instance = new RedisServerMock(); + instance.messageProcessor = messageProcessor; + } + + public static void stop() { + instance.channel.close(); + instance.group.shutdownGracefully(); + instance = null; + } + + private static String RedisMessageToString(RedisMessage msg) { + if (msg instanceof SimpleStringRedisMessage) { + return ((SimpleStringRedisMessage) msg).content(); + } else if (msg instanceof ErrorRedisMessage) { + return ((ErrorRedisMessage) msg).content(); + } else if (msg instanceof IntegerRedisMessage) { + return String.valueOf(((IntegerRedisMessage) msg).value()); + } else if (msg instanceof FullBulkStringRedisMessage) { + return getString((FullBulkStringRedisMessage) msg); + } else if (msg instanceof ArrayRedisMessage) { + return ((ArrayRedisMessage) msg) + .children().stream() + .map(RedisServerMock::RedisMessageToString) + .collect(Collectors.joining(" ")); + } else { + throw new CodecException("unknown message type: " + msg); + } + } + + private static String getString(FullBulkStringRedisMessage msg) { + if (msg.isNull()) { + return "(null)"; + } + return msg.content().toString(CharsetUtil.UTF_8); + } +} diff --git a/java/client/src/test/java/babushka/utils/RustCoreLibMockTestBase.java b/java/client/src/test/java/babushka/utils/RustCoreLibMockTestBase.java new file mode 100644 index 0000000000..24b9daca47 --- /dev/null +++ b/java/client/src/test/java/babushka/utils/RustCoreLibMockTestBase.java @@ -0,0 +1,48 @@ +package babushka.utils; + +import babushka.connectors.handlers.ChannelHandler; +import babushka.ffi.resolvers.BabushkaCoreNativeDefinitions; +import lombok.SneakyThrows; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +public class RustCoreLibMockTestBase { + + /** + * Pass this socket path to {@link ChannelHandler} or mock {@link + * BabushkaCoreNativeDefinitions#getSocket()} to return it. + */ + protected static String socketPath = null; + + @SneakyThrows + public static void startRustCoreLibMock(RustCoreMock.BabushkaMock rustCoreLibMock) { + assert socketPath == null + : "Previous `RustCoreMock` wasn't stopped, probably your test class does not inherit" + + " `RustCoreLibMockTestBase`."; + + socketPath = RustCoreMock.start(rustCoreLibMock); + } + + @BeforeEach + public void preTestCheck() { + assert socketPath != null + : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class" + + " inherited from `RustCoreLibMockTestBase`."; + } + + @AfterEach + public void afterTestCheck() { + assert !RustCoreMock.failed() : "Error occurred in `RustCoreMock`"; + } + + @AfterAll + @SneakyThrows + public static void stopRustCoreLibMock() { + assert socketPath != null + : "You missed to call `startRustCoreLibMock` in a `@BeforeAll` method of your test class" + + " inherited from `RustCoreLibMockTestBase`."; + RustCoreMock.stop(); + socketPath = null; + } +} diff --git a/java/client/src/test/java/babushka/utils/RustCoreMock.java b/java/client/src/test/java/babushka/utils/RustCoreMock.java new file mode 100644 index 0000000000..136cf726ca --- /dev/null +++ b/java/client/src/test/java/babushka/utils/RustCoreMock.java @@ -0,0 +1,154 @@ +package babushka.utils; + +import babushka.connectors.resources.Platform; +import connection_request.ConnectionRequestOuterClass.ConnectionRequest; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.channel.unix.DomainSocketChannel; +import io.netty.handler.codec.protobuf.ProtobufEncoder; +import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder; +import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; +import java.nio.file.Files; +import java.util.OptionalInt; +import java.util.concurrent.atomic.AtomicBoolean; +import redis_request.RedisRequestOuterClass.RedisRequest; +import response.ResponseOuterClass.ConstantResponse; +import response.ResponseOuterClass.Response; + +public class RustCoreMock { + + public abstract static class BabushkaMock { + /** Return `null` to do not reply. */ + public abstract Response connection(ConnectionRequest request); + + /** Return `null` to do not reply. */ + public abstract Response.Builder redisRequest(RedisRequest request); + + public Response redisRequestWithCallbackId(RedisRequest request) { + var responseDraft = redisRequest(request); + return responseDraft == null + ? null + : responseDraft.setCallbackIdx(request.getCallbackIdx()).build(); + } + + public static Response.Builder OK() { + return Response.newBuilder().setConstantResponse(ConstantResponse.OK); + } + } + + public abstract static class BabushkaMockConnectAll extends BabushkaMock { + @Override + public Response connection(ConnectionRequest request) { + return OK().build(); + } + } + + /** Thread pool supplied to Netty to perform all async IO. */ + private EventLoopGroup group; + + private Channel channel; + + private String socketPath; + + private static RustCoreMock instance; + + private BabushkaMock messageProcessor; + + /** Update {@link BabushkaMock} into a running {@link RustCoreMock}. */ + public static void updateBabushkaMock(BabushkaMock newMock) { + instance.messageProcessor = newMock; + } + + private final AtomicBoolean failed = new AtomicBoolean(false); + + /** Get and clear failure status. */ + public static boolean failed() { + return instance.failed.compareAndSet(true, false); + } + + private RustCoreMock() { + try { + socketPath = Files.createTempFile("RustCoreMock", null).toString(); + channel = + new ServerBootstrap() + .group(group = Platform.createNettyThreadPool("RustCoreMock", OptionalInt.empty())) + .channel(Platform.getServerUdsNettyChannelType()) + .childHandler( + new ChannelInitializer() { + + @Override + protected void initChannel(DomainSocketChannel ch) throws Exception { + ch.pipeline() + // https://netty.io/4.1/api/io/netty/handler/codec/protobuf/ProtobufEncoder.html + .addLast("frameDecoder", new ProtobufVarint32FrameDecoder()) + .addLast("frameEncoder", new ProtobufVarint32LengthFieldPrepender()) + .addLast("protobufEncoder", new ProtobufEncoder()) + .addLast( + new ChannelInboundHandlerAdapter() { + + // This works with only one connected client. + // TODO Rework with `channelActive` override. + private AtomicBoolean anybodyConnected = new AtomicBoolean(false); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + throws Exception { + var buf = (ByteBuf) msg; + var bytes = new byte[buf.readableBytes()]; + buf.readBytes(bytes); + buf.release(); + Response response = null; + if (!anybodyConnected.get()) { + var connection = ConnectionRequest.parseFrom(bytes); + response = messageProcessor.connection(connection); + anybodyConnected.setPlain(true); + } else { + var request = RedisRequest.parseFrom(bytes); + response = messageProcessor.redisRequestWithCallbackId(request); + } + if (response != null) { + ctx.writeAndFlush(response); + } + } + + @Override + public void exceptionCaught( + ChannelHandlerContext ctx, Throwable cause) throws Exception { + cause.printStackTrace(); + ctx.close(); + failed.setPlain(true); + } + }); + } + }) + .bind(new DomainSocketAddress(socketPath)) + // .sync() + .channel(); + } catch (Exception e) { + System.err.printf( + "Failed to create a channel %s: %s%n", e.getClass().getSimpleName(), e.getMessage()); + e.printStackTrace(System.err); + } + } + + public static String start(BabushkaMock messageProcessor) { + if (instance != null) { + stop(); + } + instance = new RustCoreMock(); + instance.messageProcessor = messageProcessor; + return instance.socketPath; + } + + public static void stop() { + instance.channel.close(); + instance.group.shutdownGracefully(); + instance = null; + } +}