Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
hierynomus authored Apr 15, 2024
2 parents 56e9dc2 + 81d77d2 commit c3f92a8
Show file tree
Hide file tree
Showing 13 changed files with 385 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,26 @@
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import com.hierynomus.sshj.SshdContainer;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.LoggerFactory;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
Expand Down Expand Up @@ -62,14 +73,27 @@ private void setUpLogger(String className) {
watchedLoggers.add(logger);
}

@Test
void strictKeyExchange() throws Throwable {
try (SSHClient client = sshd.getConnectedClient()) {
private static Stream<Arguments> strictKeyExchange() {
Config defaultConfig = new DefaultConfig();
Config heartbeaterConfig = new DefaultConfig();
heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new HotLoopHeartbeater(connection);
}
});
return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of);
}

@MethodSource
@ParameterizedTest
void strictKeyExchange(Config config) throws Throwable {
try (SSHClient client = sshd.getConnectedClient(config)) {
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
assertTrue(client.isAuthenticated());
}
List<String> keyExchangerLogs = getLogs("KeyExchanger");
assertThat(keyExchangerLogs).containsSequence(
assertThat(keyExchangerLogs).contains(
"Initiating key exchange",
"Sending SSH_MSG_KEXINIT",
"Received SSH_MSG_KEXINIT",
Expand All @@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable {
List<String> decoderLogs = getLogs("Decoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(decoderLogs).containsExactly(
assertThat(decoderLogs).startsWith(
"Received packet #0",
"Received packet #1",
"Received packet #2",
Expand All @@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable {
List<String> encoderLogs = getLogs("Encoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(encoderLogs).containsExactly(
assertThat(encoderLogs).startsWith(
"Encoding packet #0",
"Encoding packet #1",
"Encoding packet #2",
Expand All @@ -108,4 +132,22 @@ private List<String> getLogs(String className) {
.collect(Collectors.toList());
}

private static class HotLoopHeartbeater extends KeepAlive {

HotLoopHeartbeater(ConnectionImpl conn) {
super(conn, "sshj-Heartbeater");
}

@Override
public boolean isEnabled() {
return true;
}

@Override
protected void doKeepAlive() throws TransportException {
conn.getTransport().write(new SSHPacket(Message.IGNORE));
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.hierynomus.sshj.transport.verification;

import net.schmizz.sshj.common.Base64DecodingException;
import net.schmizz.sshj.common.Base64Decoder;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.transport.mac.MAC;
Expand All @@ -26,9 +28,13 @@
import java.util.regex.Pattern;

import com.hierynomus.sshj.transport.mac.Macs;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KnownHostMatchers {

private static final Logger log = LoggerFactory.getLogger(KnownHostMatchers.class);

public static HostMatcher createMatcher(String hostEntry) throws SSHException {
if (hostEntry.contains(",")) {
return new AnyHostMatcher(hostEntry);
Expand Down Expand Up @@ -80,17 +86,22 @@ private static class HashedHostMatcher implements HostMatcher {

@Override
public boolean match(String hostname) throws IOException {
return hash.equals(hashHost(hostname));
try {
return hash.equals(hashHost(hostname));
} catch (Base64DecodingException err) {
log.warn("Hostname [{}] not matched: salt decoding failed", hostname, err);
return false;
}
}

private String hashHost(String host) throws IOException {
private String hashHost(String host) throws IOException, Base64DecodingException {
sha1.init(getSaltyBytes());
return "|1|" + salt + "|" + Base64.getEncoder().encodeToString(sha1.doFinal(host.getBytes(IOUtils.UTF8)));
}

private byte[] getSaltyBytes() {
private byte[] getSaltyBytes() throws IOException, Base64DecodingException {
if (saltyBytes == null) {
saltyBytes = Base64.getDecoder().decode(salt);
saltyBytes = Base64Decoder.decode(salt);
}
return saltyBytes;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.hierynomus.sshj.userauth.keyprovider;

import net.schmizz.sshj.common.Base64DecodingException;
import net.schmizz.sshj.common.Base64Decoder;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.KeyType;

Expand All @@ -23,7 +25,6 @@
import java.io.IOException;
import java.io.Reader;
import java.security.PublicKey;
import java.util.Base64;

public class OpenSSHKeyFileUtil {
private OpenSSHKeyFileUtil() {
Expand Down Expand Up @@ -54,16 +55,19 @@ public static ParsedPubKey initPubKey(Reader publicKey) throws IOException {
if (!keydata.isEmpty()) {
String[] parts = keydata.trim().split("\\s+");
if (parts.length >= 2) {
byte[] decodedPublicKey = Base64Decoder.decode(parts[1]);
return new ParsedPubKey(
KeyType.fromString(parts[0]),
new Buffer.PlainBuffer(Base64.getDecoder().decode(parts[1])).readPublicKey()
new Buffer.PlainBuffer(decodedPublicKey).readPublicKey()
);
} else {
throw new IOException("Got line with only one column");
}
}
}
throw new IOException("Public key file is blank");
} catch (Base64DecodingException err) {
throw new IOException("Public key decoding failed", err);
} finally {
br.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,8 @@
import net.i2p.crypto.eddsa.EdDSAPrivateKey;
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable;
import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.common.Buffer.PlainBuffer;
import net.schmizz.sshj.common.ByteArrayUtils;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.SSHRuntimeException;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.cipher.Cipher;
import net.schmizz.sshj.userauth.keyprovider.BaseFileKeyProvider;
import net.schmizz.sshj.userauth.keyprovider.FileKeyProvider;
Expand All @@ -55,7 +50,6 @@
import java.security.spec.ECPrivateKeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -124,7 +118,7 @@ protected KeyPair readKeyPair() throws IOException {
try {
if (checkHeader(reader)) {
final String encodedPrivateKey = readEncodedKey(reader);
byte[] decodedPrivateKey = Base64.getDecoder().decode(encodedPrivateKey);
byte[] decodedPrivateKey = Base64Decoder.decode(encodedPrivateKey);
final PlainBuffer bufferedPrivateKey = new PlainBuffer(decodedPrivateKey);
return readDecodedKeyPair(bufferedPrivateKey);
} else {
Expand All @@ -133,6 +127,8 @@ protected KeyPair readKeyPair() throws IOException {
}
} catch (final GeneralSecurityException e) {
throw new SSHRuntimeException("Read OpenSSH Version 1 Key failed", e);
} catch (Base64DecodingException e) {
throw new SSHRuntimeException("Private Key decoding failed", e);
} finally {
IOUtils.closeQuietly(reader);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -804,12 +804,12 @@ protected void onConnect()
throws IOException {
super.onConnect();
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
doKex();
final KeepAlive keepAliveThread = conn.getKeepAlive();
if (keepAliveThread.isEnabled()) {
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start();
}
doKex();
}

/**
Expand Down
47 changes: 47 additions & 0 deletions src/main/java/net/schmizz/sshj/common/Base64Decoder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.schmizz.sshj.common;

import java.io.IOException;
import java.util.Base64;

/**
* <p>Wraps {@link java.util.Base64.Decoder} in order to wrap unchecked {@code IllegalArgumentException} thrown by
* the default Java Base64 decoder here and there.</p>
*
* <p>Please use this class instead of {@link java.util.Base64.Decoder}.</p>
*/
public class Base64Decoder {
private Base64Decoder() {
}

public static byte[] decode(byte[] source) throws Base64DecodingException {
try {
return Base64.getDecoder().decode(source);
} catch (IllegalArgumentException err) {
throw new Base64DecodingException(err);
}
}

public static byte[] decode(String src) throws Base64DecodingException {
try {
return Base64.getDecoder().decode(src);
} catch (IllegalArgumentException err) {
throw new Base64DecodingException(err);
}
}
}
28 changes: 28 additions & 0 deletions src/main/java/net/schmizz/sshj/common/Base64DecodingException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.schmizz.sshj.common;

/**
* A checked wrapper for all {@link IllegalArgumentException}, thrown by {@link java.util.Base64.Decoder}.
*
* @see Base64Decoder
*/
public class Base64DecodingException extends Exception {
public Base64DecodingException(IllegalArgumentException cause) {
super("Failed to decode base64: " + cause.getMessage(), cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@
import com.hierynomus.sshj.common.KeyAlgorithm;
import com.hierynomus.sshj.transport.verification.KnownHostMatchers;
import com.hierynomus.sshj.userauth.certificate.Certificate;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.LoggerFactory;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHRuntimeException;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.common.*;
import org.slf4j.Logger;

import java.io.BufferedOutputStream;
Expand Down Expand Up @@ -290,9 +284,9 @@ public KnownHostEntry parseEntry(String line)
if (type != KeyType.UNKNOWN) {
final String sKey = split[i++];
try {
byte[] keyBytes = Base64.getDecoder().decode(sKey);
byte[] keyBytes = Base64Decoder.decode(sKey);
key = new Buffer.PlainBuffer(keyBytes).readPublicKey();
} catch (IOException | IllegalArgumentException exception) {
} catch (IOException | Base64DecodingException exception) {
log.warn("Error decoding Base64 key bytes", exception);
return new BadHostEntry(line);
}
Expand Down
Loading

0 comments on commit c3f92a8

Please sign in to comment.