Skip to content

Commit

Permalink
Don't send keep alive signals before kex is done (#934)
Browse files Browse the repository at this point in the history
Otherwise, they could interfere with strict key exchange.

Co-authored-by: Jeroen van Erp <[email protected]>
  • Loading branch information
hpoettker and hierynomus authored Apr 15, 2024
1 parent 70af58d commit 81d77d2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,26 @@
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import com.hierynomus.sshj.SshdContainer;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.LoggerFactory;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
Expand Down Expand Up @@ -62,14 +73,27 @@ private void setUpLogger(String className) {
watchedLoggers.add(logger);
}

@Test
void strictKeyExchange() throws Throwable {
try (SSHClient client = sshd.getConnectedClient()) {
private static Stream<Arguments> strictKeyExchange() {
Config defaultConfig = new DefaultConfig();
Config heartbeaterConfig = new DefaultConfig();
heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new HotLoopHeartbeater(connection);
}
});
return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of);
}

@MethodSource
@ParameterizedTest
void strictKeyExchange(Config config) throws Throwable {
try (SSHClient client = sshd.getConnectedClient(config)) {
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
assertTrue(client.isAuthenticated());
}
List<String> keyExchangerLogs = getLogs("KeyExchanger");
assertThat(keyExchangerLogs).containsSequence(
assertThat(keyExchangerLogs).contains(
"Initiating key exchange",
"Sending SSH_MSG_KEXINIT",
"Received SSH_MSG_KEXINIT",
Expand All @@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable {
List<String> decoderLogs = getLogs("Decoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(decoderLogs).containsExactly(
assertThat(decoderLogs).startsWith(
"Received packet #0",
"Received packet #1",
"Received packet #2",
Expand All @@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable {
List<String> encoderLogs = getLogs("Encoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(encoderLogs).containsExactly(
assertThat(encoderLogs).startsWith(
"Encoding packet #0",
"Encoding packet #1",
"Encoding packet #2",
Expand All @@ -108,4 +132,22 @@ private List<String> getLogs(String className) {
.collect(Collectors.toList());
}

private static class HotLoopHeartbeater extends KeepAlive {

HotLoopHeartbeater(ConnectionImpl conn) {
super(conn, "sshj-Heartbeater");
}

@Override
public boolean isEnabled() {
return true;
}

@Override
protected void doKeepAlive() throws TransportException {
conn.getTransport().write(new SSHPacket(Message.IGNORE));
}

}

}
2 changes: 1 addition & 1 deletion src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -804,12 +804,12 @@ protected void onConnect()
throws IOException {
super.onConnect();
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
doKex();
final KeepAlive keepAliveThread = conn.getKeepAlive();
if (keepAliveThread.isEnabled()) {
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start();
}
doKex();
}

/**
Expand Down

0 comments on commit 81d77d2

Please sign in to comment.