Skip to content

Commit

Permalink
Get PING command working
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanl-bq committed Oct 17, 2023
1 parent 7df9805 commit 4f507a4
Show file tree
Hide file tree
Showing 5 changed files with 9,504 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,151 @@
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import javababushka.client.RedisClient;
import org.apache.commons.lang3.tuple.MutablePair;
import org.apache.commons.lang3.tuple.Pair;
import redis_request.RedisRequestOuterClass.Command;
import redis_request.RedisRequestOuterClass.Command.ArgsArray;
import redis_request.RedisRequestOuterClass.RedisRequest;
import redis_request.RedisRequestOuterClass.RequestType;
import redis_request.RedisRequestOuterClass.Routes;
import redis_request.RedisRequestOuterClass.SimpleRoutes;
import response.ResponseOuterClass.Response;

public class BenchmarkingApp {

/*
def _varint_encoder():
"""Return an encoder for a basic varint value (does not include tag)."""
@classmethod
def _decode_varint_32(cls, buffer, pos):
decoder_func = cls._varint_decoder((1 << 32) - 1, int)
return decoder_func(buffer, pos)
@staticmethod
def _varint_decoder(mask, result_type):
"""Return an encoder for a basic varint value (does not include tag).
local_int2byte = struct.Struct(">B").pack
Decoded values will be bitwise-anded with the given mask before being
returned, e.g. to limit them to 32 bits. The returned decoder does not
take the usual "end" parameter -- the caller is expected to do bounds checking
after the fact (often the caller can defer such checking until later). The
decoder returns a (value, new_pos) pair.
"""
def encode_varint(write, value, unused_deterministic=None):
bits = value & 0x7F
value >>= 7
while value:
write(local_int2byte(0x80 | bits))
bits = value & 0x7F
value >>= 7
return write(local_int2byte(bits))
def decode_varint(buffer, pos):
result = 0
shift = 0
while 1:
b = buffer[pos]
result |= (b & 0x7F) << shift
pos += 1
if not (b & 0x80):
result &= mask
result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
raise message.DecodeError("Too many bytes when decoding varint.")
return encode_varint
return decode_varint
@classmethod
def _varint_bytes(cls, value: int) -> bytes:
"""Encode the given integer as a varint and return the bytes.
TODO: Improve performance
"""
pieces: List[bytes] = []
func = cls._varint_encoder()
func(pieces.append, value, True)
return b"".join(pieces)
def decode_delimited(
cls,
read_bytes: bytearray,
read_bytes_view: memoryview,
offset: int,
message_class: Type[message.Message],
) -> Tuple[message.Message, int]:
try:
msg_len, new_pos = cls._decode_varint_32(read_bytes_view, offset)
except IndexError:
# Didn't read enough bytes to decode the varint
raise PartialMessageException(
"Didn't read enough bytes to decode the varint"
)
required_read_size = new_pos + msg_len
if required_read_size > len(read_bytes):
# Recieved only partial response
raise PartialMessageException("Recieved only a partial response")
offset = new_pos
msg_buf = read_bytes_view[offset : offset + msg_len]
offset += msg_len
message = message_class()
message.ParseFromString(msg_buf)
return (message, offset)
async def _reader_loop(self) -> None:
# Socket reader loop
remaining_read_bytes = bytearray()
while True:
read_bytes = await self._reader.read(DEFAULT_READ_BYTES_SIZE)
if len(read_bytes) == 0:
self.close("The server closed the connection")
raise Exception("read 0 bytes")
read_bytes = remaining_read_bytes + bytearray(read_bytes)
read_bytes_view = memoryview(read_bytes)
offset = 0
while offset <= len(read_bytes):
try:
response, offset = ProtobufCodec.decode_delimited(
read_bytes, read_bytes_view, offset, Response
)
except PartialMessageException:
# Recieved only partial response, break the inner loop
remaining_read_bytes = read_bytes[offset:]
break
response = cast(Response, response)
res_future = self._available_futures.get(response.callback_idx)
if not res_future or response.HasField("closing_error"):
err_msg = (
response.closing_error
if response.HasField("closing_error")
else f"Client Error - closing due to unknown error. callback index: {response.callback_idx}"
)
self.close(err_msg)
else:
if response.HasField("request_error"):
res_future.set_exception(Exception(response.request_error))
elif response.HasField("resp_pointer"):
res_future.set_result(value_from_pointer(response.resp_pointer))
elif response.HasField("constant_response"):
res_future.set_result(OK)
else:
res_future.set_result(None)
*/

// Left is length of message, right is position
private static Pair<Long, Integer> decodeVarint(byte[] buffer, int pos) throws Exception {
long mask = ((long) 1 << 32) - 1;
int shift = 0;
long result = 0;
while (true) {
byte b = buffer[pos];
result |= (b & 0x7F) << shift;
pos += 1;
if ((b & 0x80) == 0) {
result &= mask;
// result = (int) result;
return new MutablePair<>(result, pos);
}
shift += 7;
if (shift >= 64) {
throw new Exception("Too many bytes when decoding varint.");
}
}
}

private static Response decodeMessage(byte[] buffer) throws Exception {
Pair<Long, Integer> pair = decodeVarint(buffer, 0);
int startIdx = (int) pair.getRight();
byte[] responseBytes =
Arrays.copyOfRange(buffer, startIdx, startIdx + (int) (long) pair.getLeft());
Response response = Response.parseFrom(responseBytes);
return response;
}

private static Byte[] varintBytes(int value) {
ArrayList<Byte> output = new ArrayList();
int bits = value & 0x7F;
Expand All @@ -59,17 +171,34 @@ private static Byte[] varintBytes(int value) {
return output.toArray(arr);
}

private static String readSocketMessage(SocketChannel channel) throws IOException {
private static byte[] readSocketMessage(SocketChannel channel) throws IOException {
ByteBuffer buffer = ByteBuffer.allocate(1024);
int bytesRead = channel.read(buffer);
if (bytesRead < 0) return null;

byte[] bytes = new byte[bytesRead];
buffer.flip();
buffer.get(bytes);
String message = new String(bytes);
return message;
return bytes;
}

/*
private static sendMessage() {
Byte[] varint = varintBytes(request.toByteArray().length);
System.out.println("Request: \n" + request.toString());
ByteBuffer buffer = ByteBuffer.allocate(1024);
buffer.clear();
for (Byte b : varint) {
buffer.put(b);
}
buffer.put(request.toByteArray());
buffer.flip();
while (buffer.hasRemaining()) {
channel.write(buffer);
}
}
*/

// main application entrypoint
public static void main(String[] args) throws InterruptedException {
Expand Down Expand Up @@ -118,24 +247,68 @@ public static void main(String[] args) throws InterruptedException {
}
buffer.put(request.toByteArray());
buffer.flip();
// System.out.println("Buffer: \n" + StandardCharsets.UTF_8.decode(buffer).toString());
while (buffer.hasRemaining()) {
channel.write(buffer);
}

timeout = 0;
String response = "";
while (response == "" && timeout < maxTimeout) {
byte[] responseBuffer = null;
while (responseBuffer == null && timeout < maxTimeout) {
timeout++;
System.out.println("iteration");
response = readSocketMessage(channel);
responseBuffer = readSocketMessage(channel);
Thread.sleep(250);
}

if (response == null) {
System.out.println("WARNING: response null");
try {
Response response = decodeMessage(responseBuffer);
System.out.println(response);
} catch (Exception e) {
e.printStackTrace();
}

RedisRequest pingRequest =
RedisRequest.newBuilder()
.setCallbackIdx(0)
.setSingleCommand(
Command.newBuilder()
.setRequestType(RequestType.Ping)
.setArgsArray(ArgsArray.newBuilder()))
.setRoute(Routes.newBuilder().setSimpleRoutes(SimpleRoutes.AllNodes))
.build();

Byte[] varint2 = varintBytes(pingRequest.toByteArray().length);

System.out.println("Request: \n" + pingRequest.toString());
ByteBuffer pingBuffer = ByteBuffer.allocate(1024);
pingBuffer.clear();
for (Byte b : varint2) {
pingBuffer.put(b);
}
pingBuffer.put(pingRequest.toByteArray());
pingBuffer.flip();
while (pingBuffer.hasRemaining()) {
channel.write(pingBuffer);
}

System.out.println("Before read from socket");
timeout = 0;
byte[] pingResponseBuffer = null;
while (pingResponseBuffer == null && timeout < maxTimeout) {
timeout++;
pingResponseBuffer = readSocketMessage(channel);
Thread.sleep(250);
}
System.out.println("After read from socket");

try {
Response pingResponse = decodeMessage(pingResponseBuffer);
System.out.println(pingResponse);
Object o = RedisClient.valueFromPointer(pingResponse.getRespPointer());

System.out.println(o);
} catch (Exception e) {
e.printStackTrace();
}
System.out.println("Response: " + response);

} catch (IOException e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
public class RedisClient {
public static native void startSocketListenerExternal(RedisClient callback);

public static native Object valueFromPointer(long pointer);

static {
System.loadLibrary("javababushka");
}
Expand Down
Loading

0 comments on commit 4f507a4

Please sign in to comment.