Skip to content

Commit

Permalink
V3 bug fix: wss connection closed on a single client request close ea…
Browse files Browse the repository at this point in the history
…rly (#3)

V3 bug fix: on seeing a single client early close, sending rst_stream msg instead of closing the wss tunnel.
  • Loading branch information
jayjlu authored Nov 29, 2023
1 parent 29bcefb commit 91f977c
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 29 deletions.
10 changes: 5 additions & 5 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
<dependency>
<groupId>io.muserver</groupId>
<artifactId>mu-server</artifactId>
<version>0.74.1</version>
<version>0.74.3</version>
<scope>provided</scope>
</dependency>
<dependency>
Expand All @@ -90,26 +90,26 @@
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.10.0</version>
<version>5.10.1</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>4.11.0</version>
<version>4.12.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
<version>4.11.0</version>
<version>4.12.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.hsbc.cranker</groupId>
<artifactId>cranker-connector</artifactId>
<version>1.2.1</version>
<version>1.2.3</version>
<scope>test</scope>
</dependency>

Expand Down
61 changes: 38 additions & 23 deletions src/main/java/com/hsbc/cranker/mucranker/RouterSocketV3.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,15 @@ public void sendRequestOverWebSocketV3(MuRequest clientRequest, MuResponse clien
contextMap.put(requestId, context);

asyncHandle.addResponseCompleteHandler(info -> {
if (!info.completedSuccessfully() && !state().endState()) {
// end client early close
log.info("Closing socket because client request did not complete successfully for " + clientRequest);
if (!info.completedSuccessfully()) {
log.info("Client request did not complete successfully " + clientRequest);
if (context.error == null) {
context.error = new IllegalStateException("Client request did not complete successfully.");
}
raiseCompletionEvent(context);
resetStream(context, ERROR_INTERNAL, "Client closed", DoneCallback.NoOp);
if (!state().endState()) {
resetStream(context, ERROR_INTERNAL, "Client early closed", DoneCallback.NoOp);
}
}
});

Expand Down Expand Up @@ -184,9 +188,9 @@ public void onComplete() {

@Override
public void onError(Throwable t) {
asyncHandle.complete(t);
try {
notifyStreamError(context, t);
notifyClientRequestError(context, t);
resetStream(context, ERROR_INTERNAL, "Client request body read error", DoneCallback.NoOp);
} catch (Exception ignored) {
}
}
Expand Down Expand Up @@ -230,7 +234,7 @@ private void sendData(ByteBuffer byteBuffer, DoneCallback doneCallback) {
}

void socketSessionClose() {
if (contextMap.size() == 0) {
if (contextMap.isEmpty()) {
try {
MuWebSocketSession session = session();
if (session != null) {
Expand All @@ -244,11 +248,15 @@ void socketSessionClose() {
}

void resetStream(RequestContext context, Integer errorCode, String message, DoneCallback doneCallback) {
if (context != null && !context.state.isCompleted()) {
if (context != null && !context.state.isCompleted() && !context.isRstStreamSent) {
final ByteBuffer buffer = rstMessage(context.requestId, errorCode, message);
sendData(buffer, doneCallback);
context.isRstStreamSent = true;
}

if (context != null) {
contextMap.remove(context.requestId);
}
contextMap.remove(context.requestId);
}

@Override
Expand All @@ -270,11 +278,11 @@ public void onClientClosed(int statusCode, String reason) throws Exception {
log.warn("websocket exceptional closed from client: statusCode={}, reason={}", statusCode, reason);
}
for (RequestContext context : contextMap.values()) {
notifyStreamClose(context, statusCode);
notifyClientRequestClose(context, statusCode);
}
}

private void notifyStreamClose(RequestContext context, int statusCode) {
private void notifyClientRequestClose(RequestContext context, int statusCode) {
try {
if (!proxyListeners.isEmpty()) {
for (ProxyListener proxyListener : proxyListeners) {
Expand Down Expand Up @@ -304,6 +312,9 @@ private void notifyStreamClose(RequestContext context, int statusCode) {
}
}
} finally {
if (statusCode != 1000 && context.error == null) {
context.error = new IllegalStateException("Upstream server close with code " + statusCode);
}
raiseCompletionEvent(context);
contextMap.remove(context.requestId);
}
Expand All @@ -330,11 +341,11 @@ public void onError(Throwable cause) throws Exception {
isRemoved = true;
}
for (RequestContext context : contextMap.values()) {
notifyStreamError(context, cause);
notifyClientRequestError(context, cause);
}
}

private void notifyStreamError(RequestContext context, Throwable cause) throws Exception {
private void notifyClientRequestError(RequestContext context, Throwable cause) throws Exception {
try {
context.error = cause;
if (cause instanceof TimeoutException) {
Expand Down Expand Up @@ -425,7 +436,7 @@ public void onBinary(ByteBuffer byteBuffer, boolean isLast, DoneCallback doneAnd
handleHeaderMessage(context, fullContent);
}
if (isStreamEnd) {
notifyStreamClose(context, 1000);
notifyClientRequestClose(context, 1000);
}
sendData(windowUpdateMessage(requestId, byteLength), DoneCallback.NoOp);
releaseBuffer.run();
Expand All @@ -436,7 +447,7 @@ public void onBinary(ByteBuffer byteBuffer, boolean isLast, DoneCallback doneAnd
try {
final int errorCode = getErrorCode(byteBuffer);
String message = getErrorMessage(byteBuffer);
notifyStreamError(context, new RuntimeException(
notifyClientRequestError(context, new RuntimeException(
String.format("stream closed by connector, errorCode=%s, message=%s", errorCode, message)));
} catch (Throwable throwable) {
log.warn("exception on handling rst_stream", throwable);
Expand Down Expand Up @@ -478,7 +489,7 @@ private void handleData(RequestContext context, boolean isLast, boolean isEnd, B

int len = byteBuffer.remaining();
if (len == 0) {
if (isEnd) notifyStreamClose(context, 1000);
if (isEnd) notifyClientRequestClose(context, 1000);
releaseBuffer.run();
doneAndPullData.onComplete(null);
return;
Expand All @@ -488,7 +499,7 @@ private void handleData(RequestContext context, boolean isLast, boolean isEnd, B

WebsocketSessionState websocketState = state();
if (websocketState.endState()) {
if (isEnd) notifyStreamClose(context, 1000);
if (isEnd) notifyClientRequestClose(context, 1000);
releaseBuffer.run();
doneAndPullData.onComplete(new IllegalStateException("Received binary message from connector but state=" + websocketState));
return;
Expand All @@ -505,14 +516,18 @@ private void handleData(RequestContext context, boolean isLast, boolean isEnd, B
context.asyncHandle.write(byteBuffer, errorIfAny -> {
try {
if (errorIfAny == null) {
if (isEnd) notifyStreamClose(context, 1000);
if (isEnd) notifyClientRequestClose(context, 1000);
context.toClientBytes.addAndGet(len);
sendData(windowUpdateMessage(context.requestId, len), DoneCallback.NoOp);
} else {
log.info("routerName=" + route + ", routerSocketID=" + routerSocketID +
", could not write to client response (maybe the user closed their browser)" +
" so will cancel the request. Error message: " + errorIfAny.getMessage());
onError(errorIfAny);

// reset the request context instead of closing everything
// the rst_stream will be sent to wss socket in asyncHandle.addResponseCompleteHandler() callback
context.error = errorIfAny;
context.asyncHandle.complete(errorIfAny);
}
if (!proxyListeners.isEmpty()) {
for (ProxyListener proxyListener : proxyListeners) {
Expand Down Expand Up @@ -604,13 +619,12 @@ static ByteBuffer headerMessage(Integer requestId, boolean isHeaderEnd, boolean
if (isStreamEnd) flags = flags | 1; // first bit 00000001
if (isHeaderEnd) flags = flags | 4; // third bit 00000100
final byte[] bytes = headerLine.getBytes(StandardCharsets.UTF_8);
final ByteBuffer message = ByteBuffer.allocate(6 + bytes.length)
return ByteBuffer.allocate(6 + bytes.length)
.put(MESSAGE_TYPE_HEADER) // 1 byte
.put((byte) flags) // 1 byte
.putInt(requestId) // 4 byte
.put(bytes)
.rewind();
return message;
}

static ByteBuffer dataMessages(Integer requestId, boolean isEnd, ByteBuffer buffer) {
Expand Down Expand Up @@ -665,7 +679,8 @@ public class RequestContext implements ProxyInfo {
final AtomicLong toClientBytes = new AtomicLong();

long durationMillis = 0;
Throwable error = null;
volatile Throwable error = null;
volatile boolean isRstStreamSent = false;
StreamState state = StreamState.OPEN;
StringBuilder headerLineBuilder;

Expand Down Expand Up @@ -703,7 +718,7 @@ void flowControl(Runnable runnable) {
}

private void writeItMaybe() {
if (isWssWritable.get() && wssWriteCallbacks.size() > 0 && isWssWriting.compareAndSet(false, true)) {
if (isWssWritable.get() && !wssWriteCallbacks.isEmpty() && isWssWriting.compareAndSet(false, true)) {
try {
Runnable current;
while (isWssWritable.get() && (current = wssWriteCallbacks.poll()) != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public void MuServer_TargetServerDownInMiddleTest_ClientTalkToRouter(RepetitionI
publisher.send("Number 0");
publisher.send("Number 1");
publisher.send("Number 2");
client.waitMessageListSizeGreaterThan(3, 10, TimeUnit.SECONDS);
targetServer.stop();
})
.start();
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/scaffolding/SseTestClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null

@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
messages.add(String.format("onFailure: message=%s", t.getMessage()));
messages.add(String.format("onFailure: message=%s", t != null ? t.getMessage() : ""));
errorLatch.countDown();
}

Expand Down

0 comments on commit 91f977c

Please sign in to comment.