Skip to content

Commit

Permalink
Change abstraction point for transport protocol
Browse files Browse the repository at this point in the history
The previous implementation had a transport switch point in
InboundPipeline when the bytes were initially pulled off the wire. There
was no implementation for any other protocol as the `canHandleBytes`
method was hardcoded to return true. I believe this is the wrong point
to switch on the protocol. This change makes NativeInboundBytesHandler
protocol agnostic beyond the header. With this change, a complete
message is parsed from the stream of bytes, with the header schema being
unchanged from what exists today. The protocol switch point will now be
at `InboundHandler::inboundMessage`. The header will indicate what
protocol was used to serialize the the non-header bytes of the message
and then invoke the appropriate handler based on that field.

Signed-off-by: Andrew Ross <[email protected]>
  • Loading branch information
andrross committed Aug 27, 2024
1 parent 46a269e commit 0138b6e
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 126 deletions.
10 changes: 9 additions & 1 deletion server/src/main/java/org/opensearch/transport/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class Header {

private static final String RESPONSE_NAME = "NO_ACTION_NAME_FOR_RESPONSES";

private final TransportProtocol protocol;
private final int networkMessageSize;
private final Version version;
private final long requestId;
Expand All @@ -64,13 +65,18 @@ public class Header {
Tuple<Map<String, String>, Map<String, Set<String>>> headers;
Set<String> features;

Header(int networkMessageSize, long requestId, byte status, Version version) {
Header(TransportProtocol protocol, int networkMessageSize, long requestId, byte status, Version version) {
this.protocol = protocol;
this.networkMessageSize = networkMessageSize;
this.version = version;
this.requestId = requestId;
this.status = status;
}

TransportProtocol getTransportProtocol() {
return protocol;
}

public int getNetworkMessageSize() {
return networkMessageSize;
}
Expand Down Expand Up @@ -142,6 +148,8 @@ void finishParsingHeader(StreamInput input) throws IOException {
@Override
public String toString() {
return "Header{"
+ protocol
+ "}{"
+ networkMessageSize
+ "}{"
+ version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public void aggregate(ReleasableBytesReference content) {
}
}

public NativeInboundMessage finishAggregation() throws IOException {
public ProtocolInboundMessage finishAggregation() throws IOException {
ensureOpen();
final ReleasableBytesReference releasableContent;
if (isFirstContent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,12 @@ private int headerBytesToRead(BytesReference reference) {
// exposed for use in tests
static Header readHeader(Version version, int networkMessageSize, BytesReference bytesReference) throws IOException {
try (StreamInput streamInput = bytesReference.streamInput()) {
streamInput.skip(TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
TransportProtocol protocol = TransportProtocol.fromBytes(streamInput.readByte(), streamInput.readByte());
streamInput.skip(TcpHeader.MESSAGE_LENGTH_SIZE);
long requestId = streamInput.readLong();
byte status = streamInput.readByte();
Version remoteVersion = Version.fromId(streamInput.readInt());
Header header = new Header(networkMessageSize, requestId, status, remoteVersion);
Header header = new Header(protocol, networkMessageSize, requestId, status, remoteVersion);
final IllegalStateException invalidVersion = ensureVersionCompatibility(remoteVersion, version, header.isHandshake());
if (invalidVersion != null) {
throw invalidVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public class InboundHandler {

private volatile long slowLogThresholdMs = Long.MAX_VALUE;

private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;
private final Map<TransportProtocol, ProtocolMessageHandler> protocolMessageHandlers;

InboundHandler(
String nodeName,
Expand All @@ -75,7 +75,7 @@ public class InboundHandler {
) {
this.threadPool = threadPool;
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
TransportProtocol.NATIVE,
new NativeMessageHandler(
nodeName,
version,
Expand Down Expand Up @@ -114,9 +114,9 @@ void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws E
}

private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException {
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol());
ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getTransportProtocol());
if (protocolMessageHandler == null) {
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol());
throw new IllegalStateException("No protocol message handler found for protocol: " + message.getTransportProtocol());
}
protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener);
}
Expand Down
28 changes: 4 additions & 24 deletions server/src/main/java/org/opensearch/transport/InboundPipeline.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ public class InboundPipeline implements Releasable {
private final ArrayDeque<ReleasableBytesReference> pending = new ArrayDeque<>(2);
private boolean isClosed = false;
private final BiConsumer<TcpChannel, ProtocolInboundMessage> messageHandler;
private final List<InboundBytesHandler> protocolBytesHandlers;
private InboundBytesHandler currentHandler;
private final InboundBytesHandler bytesHandler;

public InboundPipeline(
Version version,
Expand Down Expand Up @@ -95,17 +94,14 @@ public InboundPipeline(
this.statsTracker = statsTracker;
this.decoder = decoder;
this.aggregator = aggregator;
this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker));
this.bytesHandler = new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker);
this.messageHandler = messageHandler;
}

@Override
public void close() {
isClosed = true;
if (currentHandler != null) {
currentHandler.close();
currentHandler = null;
}
bytesHandler.close();
Releasables.closeWhileHandlingException(decoder, aggregator);
Releasables.closeWhileHandlingException(pending);
pending.clear();
Expand All @@ -127,22 +123,6 @@ public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference
channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong());
statsTracker.markBytesRead(reference.length());
pending.add(reference.retain());

// If we don't have a current handler, we should try to find one based on the protocol of the incoming bytes.
if (currentHandler == null) {
for (InboundBytesHandler handler : protocolBytesHandlers) {
if (handler.canHandleBytes(reference)) {
currentHandler = handler;
break;
}
}
}

// If we have a current handler determined based on protocol, we should continue to use it for the fragmented bytes.
if (currentHandler != null) {
currentHandler.doHandleBytes(channel, reference, messageHandler);
} else {
throw new IllegalStateException("No bytes handler found for the incoming transport protocol");
}
bytesHandler.doHandleBytes(channel, reference, messageHandler);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
package org.opensearch.transport;

import org.opensearch.common.annotation.PublicApi;
import org.opensearch.common.bytes.ReleasableBytesReference;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;

/**
* Base class for inbound data as a message.
Expand All @@ -17,11 +20,89 @@
* @opensearch.internal
*/
@PublicApi(since = "2.14.0")
public interface ProtocolInboundMessage {
public abstract class ProtocolInboundMessage implements Releasable {

/**
* @return the protocol used to encode this message
*/
public String getProtocol();
protected final Header header;
protected final ReleasableBytesReference content;
protected final Exception exception;
protected final boolean isPing;
private Releasable breakerRelease;

public ProtocolInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.header = header;
this.content = content;
this.breakerRelease = breakerRelease;
this.exception = null;
this.isPing = false;
}

public ProtocolInboundMessage(Header header, Exception exception) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = exception;
this.isPing = false;
}

public ProtocolInboundMessage(Header header, boolean isPing) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = null;
this.isPing = isPing;
}

TransportProtocol getTransportProtocol() {
return header.getTransportProtocol();
}

public String getProtocol() {
return header.getTransportProtocol().toString();
}

public Header getHeader() {
return header;
}

public int getContentLength() {
if (content == null) {
return 0;
} else {
return content.length();
}
}

public Exception getException() {
return exception;
}

public boolean isPing() {
return isPing;
}

public boolean isShortCircuit() {
return exception != null;
}

public Releasable takeBreakerReleaseControl() {
final Releasable toReturn = breakerRelease;
breakerRelease = null;
if (toReturn != null) {
return toReturn;
} else {
return () -> {};
}
}



@Override
public void close() {
Releasables.closeWhileHandlingException(content, breakerRelease);
}

@Override
public String toString() {
return "InboundMessage{" + header + "}";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport;

enum TransportProtocol {
NATIVE;

public static TransportProtocol fromBytes(byte b1, byte b2) {
if (b1 == 'E' && b2 == 'S') {
return NATIVE;
}

throw new IllegalArgumentException("Unknown transport protocol: [" + b1 + ", " + b2 + "]");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private void forwardFragments(
messageHandler.accept(channel, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
try (NativeInboundMessage aggregated = aggregator.finishAggregation()) {
try (ProtocolInboundMessage aggregated = aggregator.finishAggregation()) {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,81 +49,25 @@
* @opensearch.api
*/
@PublicApi(since = "2.14.0")
public class NativeInboundMessage implements Releasable, ProtocolInboundMessage {
public class NativeInboundMessage extends ProtocolInboundMessage {

/**
* The protocol used to encode this message
*/
public static String NATIVE_PROTOCOL = "native";

private final Header header;
private final ReleasableBytesReference content;
private final Exception exception;
private final boolean isPing;
private Releasable breakerRelease;
private StreamInput streamInput;

public NativeInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.header = header;
this.content = content;
this.breakerRelease = breakerRelease;
this.exception = null;
this.isPing = false;
super(header, content, breakerRelease);
}

public NativeInboundMessage(Header header, Exception exception) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = exception;
this.isPing = false;
super(header, exception);
}

public NativeInboundMessage(Header header, boolean isPing) {
this.header = header;
this.content = null;
this.breakerRelease = null;
this.exception = null;
this.isPing = isPing;
}

@Override
public String getProtocol() {
return NATIVE_PROTOCOL;
}

public Header getHeader() {
return header;
}

public int getContentLength() {
if (content == null) {
return 0;
} else {
return content.length();
}
}

public Exception getException() {
return exception;
}

public boolean isPing() {
return isPing;
}

public boolean isShortCircuit() {
return exception != null;
}

public Releasable takeBreakerReleaseControl() {
final Releasable toReturn = breakerRelease;
breakerRelease = null;
if (toReturn != null) {
return toReturn;
} else {
return () -> {};
}
super(header, isPing);
}

public StreamInput openOrGetStreamInput() throws IOException {
Expand All @@ -138,12 +82,6 @@ public StreamInput openOrGetStreamInput() throws IOException {
@Override
public void close() {
IOUtils.closeWhileHandlingException(streamInput);
Releasables.closeWhileHandlingException(content, breakerRelease);
super.close();
}

@Override
public String toString() {
return "InboundMessage{" + header + "}";
}

}
Loading

0 comments on commit 0138b6e

Please sign in to comment.