Skip to content

Commit

Permalink
Add java-jni client
Browse files Browse the repository at this point in the history
Signed-off-by: acarbonetto <[email protected]>
  • Loading branch information
acarbonetto committed Oct 18, 2023
1 parent 8dcfec0 commit a615f8a
Show file tree
Hide file tree
Showing 12 changed files with 14,871 additions and 3 deletions.
2 changes: 1 addition & 1 deletion java/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

[package]
name = "javababushka"
version = "0.0.0"
Expand All @@ -16,6 +15,7 @@ babushka = { path = "../babushka-core" }
tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] }
logger_core = {path = "../logger_core"}
tracing-subscriber = "0.3.16"
jni = "0.21.1"

[profile.release]
lto = true
Expand Down
4 changes: 4 additions & 0 deletions java/benchmarks/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ dependencies {
implementation 'io.lettuce:lettuce-core:6.2.6.RELEASE'
implementation 'commons-cli:commons-cli:1.5.0'
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0'

// https://mvnrepository.com/artifact/com.google.protobuf/protobuf-java
implementation group: 'com.google.protobuf', name: 'protobuf-java', version: '3.24.3'
}

// Apply a specific Java toolchain to ease working on different environments.
Expand All @@ -30,6 +33,7 @@ java {
application {
// Define the main class for the application.
mainClass = 'javababushka.benchmarks.BenchmarkingApp'
applicationDefaultJvmArgs += "-Djava.library.path=${projectDir}/../target/debug"
}

tasks.withType(Test) {
Expand Down
898 changes: 898 additions & 0 deletions java/benchmarks/hs_err_pid70704.log

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.stream.Stream;
import javababushka.benchmarks.clients.JedisClient;
import javababushka.benchmarks.clients.JedisPseudoAsyncClient;
import javababushka.benchmarks.clients.JniSyncClient;
import javababushka.benchmarks.clients.LettuceAsyncClient;
import javababushka.benchmarks.clients.LettuceClient;
import org.apache.commons.cli.CommandLine;
Expand Down Expand Up @@ -62,6 +63,9 @@ public static void main(String[] args) {
case LETTUCE_ASYNC:
testClientSetGet(LettuceAsyncClient::new, runConfiguration, true);
break;
case BABUSHKA_JNI:
testClientSetGet(JniSyncClient::new, runConfiguration, true);
break;
case BABUSHKA_ASYNC:
System.out.println("Babushka async not yet configured");
break;
Expand Down Expand Up @@ -212,6 +216,7 @@ public enum ClientName {
JEDIS_ASYNC("Jedis async"),
LETTUCE("Lettuce"),
LETTUCE_ASYNC("Lettuce async"),
BABUSHKA_JNI("JNI sync"),
BABUSHKA_ASYNC("Babushka async"),
ALL("All"),
ALL_SYNC("All sync"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
package javababushka.benchmarks.clients;

import java.io.IOException;
import java.net.StandardProtocolFamily;
import java.net.UnixDomainSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import javababushka.benchmarks.utils.ConnectionSettings;
import javababushka.client.RedisClient;
import org.apache.commons.lang3.tuple.MutablePair;
import org.apache.commons.lang3.tuple.Pair;
import redis_request.RedisRequestOuterClass;
import response.ResponseOuterClass;

/** A JNI-built client using Unix Domain Sockets with async capabilities */
public class JniSyncClient implements SyncClient {

private static int MAX_TIMEOUT = 1000;

private RedisClient client;

private SocketChannel channel;

private boolean isChannelWriting = false;

@Override
public void connectToRedis() {
connectToRedis(new ConnectionSettings("localhost", 6379, false));
}

@Override
public void connectToRedis(ConnectionSettings connectionSettings) {

// Create redis client
client = new RedisClient();

// Get socket listener address/path
RedisClient.startSocketListenerExternal(client);

int timeout = 0;
int maxTimeout = 1000;
while (client.socketPath == null && timeout < maxTimeout) {
timeout++;
try {
Thread.sleep(250);
} catch (InterruptedException exception) {
// ignored
}
}

System.out.println("Socket Path: " + client.socketPath);
UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(client.socketPath);

// Start the socket listener
try {
channel = SocketChannel.open(StandardProtocolFamily.UNIX);
channel.connect(socketAddress);
} catch (IOException ioException) {
ioException.printStackTrace();
return;
}

String host = connectionSettings.host;
int port = connectionSettings.port;
connection_request.ConnectionRequestOuterClass.TlsMode tls =
connectionSettings.useSsl
?
// TODO: secure or insecure TLS?
connection_request.ConnectionRequestOuterClass.TlsMode.SecureTls
: connection_request.ConnectionRequestOuterClass.TlsMode.NoTls;

connection_request.ConnectionRequestOuterClass.ConnectionRequest request =
connection_request.ConnectionRequestOuterClass.ConnectionRequest.newBuilder()
.addAddresses(
connection_request.ConnectionRequestOuterClass.AddressInfo.newBuilder()
.setHost(host)
.setPort(port))
.setTlsMode(tls)
.setClusterModeEnabled(false)
// In millis
.setResponseTimeout(250)
// In millis
.setClientCreationTimeout(2500)
.setReadFromReplicaStrategy(
connection_request.ConnectionRequestOuterClass.ReadFromReplicaStrategy
.AlwaysFromPrimary)
.setConnectionRetryStrategy(
connection_request.ConnectionRequestOuterClass.ConnectionRetryStrategy.newBuilder()
.setNumberOfRetries(1)
.setFactor(1)
.setExponentBase(1))
.setAuthenticationInfo(
connection_request.ConnectionRequestOuterClass.AuthenticationInfo.newBuilder()
.setPassword("")
.setUsername("default"))
.setDatabaseId(0)
.build();

makeConnection(request);
}

@Override
public void set(String key, String value) {

int futureIdx = 1;
RedisRequestOuterClass.Command.ArgsArray args =
RedisRequestOuterClass.Command.ArgsArray.newBuilder().addArgs(key).addArgs(value).build();
RedisRequestOuterClass.RedisRequest request =
RedisRequestOuterClass.RedisRequest.newBuilder()
.setCallbackIdx(futureIdx)
.setSingleCommand(
RedisRequestOuterClass.Command.newBuilder()
.setRequestType(RedisRequestOuterClass.RequestType.SetString)
.setArgsArray(args))
.setRoute(
RedisRequestOuterClass.Routes.newBuilder()
.setSimpleRoutes(RedisRequestOuterClass.SimpleRoutes.AllNodes))
.build();

ResponseOuterClass.Response response = makeRedisRequest(request);
// nothing to do with the response
}

@Override
public String get(String key) {
int futureIdx = 1;
RedisRequestOuterClass.RedisRequest getStringRequest =
RedisRequestOuterClass.RedisRequest.newBuilder()
.setCallbackIdx(futureIdx)
.setSingleCommand(
RedisRequestOuterClass.Command.newBuilder()
.setRequestType(RedisRequestOuterClass.RequestType.GetString)
.setArgsArray(
RedisRequestOuterClass.Command.ArgsArray.newBuilder().addArgs(key)))
.setRoute(
RedisRequestOuterClass.Routes.newBuilder()
.setSimpleRoutes(RedisRequestOuterClass.SimpleRoutes.AllNodes))
.build();

ResponseOuterClass.Response response = makeRedisRequest(getStringRequest);
return response.toString();
}

@Override
public void closeConnection() {}

@Override
public String getName() {
return "JNI (with UDS) Sync";
}

// Left is length of message, right is position
private static Pair<Long, Integer> decodeVarint(byte[] buffer, int pos) throws Exception {
long mask = ((long) 1 << 32) - 1;
int shift = 0;
long result = 0;
while (true) {
byte b = buffer[pos];
result |= (b & 0x7F) << shift;
pos += 1;
if ((b & 0x80) == 0) {
result &= mask;
// result = (int) result;
return new MutablePair<>(result, pos);
}
shift += 7;
if (shift >= 64) {
throw new Exception("Too many bytes when decoding varint.");
}
}
}

private static ResponseOuterClass.Response decodeMessage(byte[] buffer) throws Exception {
Pair<Long, Integer> pair = decodeVarint(buffer, 0);
int startIdx = (int) pair.getRight();
byte[] responseBytes =
Arrays.copyOfRange(buffer, startIdx, startIdx + (int) (long) pair.getLeft());
ResponseOuterClass.Response response = ResponseOuterClass.Response.parseFrom(responseBytes);
return response;
}

private static Byte[] varintBytes(int value) {
ArrayList<Byte> output = new ArrayList();
int bits = value & 0x7F;
value >>= 7;
while (value > 0) {
output.add(new Byte((byte) (0x80 | bits)));
bits = value & 0x7F;
value >>= 7;
}
output.add(new Byte((byte) bits));
Byte[] arr = new Byte[] {};
return output.toArray(arr);
}

private static byte[] readSocketMessage(SocketChannel channel) throws IOException {
ByteBuffer buffer = ByteBuffer.allocate(1024);
int bytesRead = channel.read(buffer);
if (bytesRead <= 0) {
return null;
}

byte[] bytes = new byte[bytesRead];
buffer.flip();
buffer.get(bytes);
return bytes;
}

private ResponseOuterClass.Response makeConnection(
connection_request.ConnectionRequestOuterClass.ConnectionRequest request) {
Byte[] varint = varintBytes(request.toByteArray().length);

// System.out.println("Request: \n" + request.toString());
ByteBuffer buffer = ByteBuffer.allocate(1024);
buffer.clear();
for (Byte b : varint) {
buffer.put(b);
}
buffer.put(request.toByteArray());
buffer.flip();
while (isChannelWriting) {
try {
Thread.sleep(250);
} catch (InterruptedException interruptedException) {
// ignore...
}
}
isChannelWriting = true;
while (buffer.hasRemaining()) {
try {
channel.write(buffer);
} catch (IOException ioException) {
// ignore...
}
}
isChannelWriting = false;

ResponseOuterClass.Response response = null;
int timeout = 0;
try {
byte[] responseBuffer = readSocketMessage(channel);
while (responseBuffer == null && timeout < MAX_TIMEOUT) {
Thread.sleep(250);
timeout++;
responseBuffer = readSocketMessage(channel);
}

response = decodeMessage(responseBuffer);
} catch (Exception e) {
e.printStackTrace();
}
return response;
}

private ResponseOuterClass.Response makeRedisRequest(
RedisRequestOuterClass.RedisRequest request) {
Byte[] varint = varintBytes(request.toByteArray().length);

// System.out.println("Request: \n" + request.toString());
ByteBuffer buffer = ByteBuffer.allocate(1024);
buffer.clear();
for (Byte b : varint) {
buffer.put(b);
}
buffer.put(request.toByteArray());
buffer.flip();
while (isChannelWriting) {
try {
Thread.sleep(250);
} catch (InterruptedException interruptedException) {
// ignore...
}
}
isChannelWriting = true;
while (buffer.hasRemaining()) {
try {
channel.write(buffer);
} catch (IOException ioException) {
// ignore...
}
}
isChannelWriting = false;

int timeout = 0;
byte[] responseBuffer = null;
while (responseBuffer == null && timeout < MAX_TIMEOUT) {
timeout++;
try {
responseBuffer = readSocketMessage(channel);
Thread.sleep(250);
} catch (IOException | InterruptedException exception) {
// ignore...
}
}

// nothing to do with the responseBuffer message
ResponseOuterClass.Response response = null;
try {
response = decodeMessage(responseBuffer);
} catch (Exception e) {
e.printStackTrace();
}
return response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ public static Map<ChosenAction, LatencyResults> calculateResults(
percentile(latencies, 50),
percentile(latencies, 90),
percentile(latencies, 99),
stdDeviation(latencies, avgLatency)));
stdDeviation(latencies, avgLatency),
latencies.size()
));
}

return results;
Expand Down Expand Up @@ -161,6 +163,8 @@ public static void printResults(Map<ChosenAction, LatencyResults> resultsMap) {
action + " p99 latency in ms: " + results.p99Latency / LATENCY_NORMALIZATION);
System.out.println(
action + " std dev in ms: " + results.stdDeviation / LATENCY_NORMALIZATION);
System.out.println(
action + " total hits: " + results.totalHits);
}
}

Expand Down
Loading

0 comments on commit a615f8a

Please sign in to comment.