diff --git a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java index 2abe71a7..9d207c0e 100644 --- a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java +++ b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java @@ -18,15 +18,26 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.Stream; import ch.qos.logback.classic.Logger; import ch.qos.logback.classic.spi.ILoggingEvent; import ch.qos.logback.core.read.ListAppender; import com.hierynomus.sshj.SshdContainer; +import net.schmizz.keepalive.KeepAlive; +import net.schmizz.keepalive.KeepAliveProvider; +import net.schmizz.sshj.Config; +import net.schmizz.sshj.DefaultConfig; import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.connection.ConnectionImpl; +import net.schmizz.sshj.transport.TransportException; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.LoggerFactory; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -62,14 +73,27 @@ private void setUpLogger(String className) { watchedLoggers.add(logger); } - @Test - void strictKeyExchange() throws Throwable { - try (SSHClient client = sshd.getConnectedClient()) { + private static Stream strictKeyExchange() { + Config defaultConfig = new DefaultConfig(); + Config heartbeaterConfig = new DefaultConfig(); + heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() { + @Override + public KeepAlive provide(ConnectionImpl connection) { + return new HotLoopHeartbeater(connection); + } + }); + return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of); + } + + @MethodSource + @ParameterizedTest + void strictKeyExchange(Config config) throws Throwable { + try (SSHClient client = sshd.getConnectedClient(config)) { client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); assertTrue(client.isAuthenticated()); } List keyExchangerLogs = getLogs("KeyExchanger"); - assertThat(keyExchangerLogs).containsSequence( + assertThat(keyExchangerLogs).contains( "Initiating key exchange", "Sending SSH_MSG_KEXINIT", "Received SSH_MSG_KEXINIT", @@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable { List decoderLogs = getLogs("Decoder").stream() .map(log -> log.split(":")[0]) .collect(Collectors.toList()); - assertThat(decoderLogs).containsExactly( + assertThat(decoderLogs).startsWith( "Received packet #0", "Received packet #1", "Received packet #2", @@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable { List encoderLogs = getLogs("Encoder").stream() .map(log -> log.split(":")[0]) .collect(Collectors.toList()); - assertThat(encoderLogs).containsExactly( + assertThat(encoderLogs).startsWith( "Encoding packet #0", "Encoding packet #1", "Encoding packet #2", @@ -108,4 +132,22 @@ private List getLogs(String className) { .collect(Collectors.toList()); } + private static class HotLoopHeartbeater extends KeepAlive { + + HotLoopHeartbeater(ConnectionImpl conn) { + super(conn, "sshj-Heartbeater"); + } + + @Override + public boolean isEnabled() { + return true; + } + + @Override + protected void doKeepAlive() throws TransportException { + conn.getTransport().write(new SSHPacket(Message.IGNORE)); + } + + } + } diff --git a/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java b/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java index 37fdaef5..cb664563 100644 --- a/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java +++ b/src/main/java/com/hierynomus/sshj/transport/verification/KnownHostMatchers.java @@ -15,6 +15,8 @@ */ package com.hierynomus.sshj.transport.verification; +import net.schmizz.sshj.common.Base64DecodingException; +import net.schmizz.sshj.common.Base64Decoder; import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.transport.mac.MAC; @@ -26,9 +28,13 @@ import java.util.regex.Pattern; import com.hierynomus.sshj.transport.mac.Macs; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class KnownHostMatchers { + private static final Logger log = LoggerFactory.getLogger(KnownHostMatchers.class); + public static HostMatcher createMatcher(String hostEntry) throws SSHException { if (hostEntry.contains(",")) { return new AnyHostMatcher(hostEntry); @@ -80,17 +86,22 @@ private static class HashedHostMatcher implements HostMatcher { @Override public boolean match(String hostname) throws IOException { - return hash.equals(hashHost(hostname)); + try { + return hash.equals(hashHost(hostname)); + } catch (Base64DecodingException err) { + log.warn("Hostname [{}] not matched: salt decoding failed", hostname, err); + return false; + } } - private String hashHost(String host) throws IOException { + private String hashHost(String host) throws IOException, Base64DecodingException { sha1.init(getSaltyBytes()); return "|1|" + salt + "|" + Base64.getEncoder().encodeToString(sha1.doFinal(host.getBytes(IOUtils.UTF8))); } - private byte[] getSaltyBytes() { + private byte[] getSaltyBytes() throws IOException, Base64DecodingException { if (saltyBytes == null) { - saltyBytes = Base64.getDecoder().decode(salt); + saltyBytes = Base64Decoder.decode(salt); } return saltyBytes; } diff --git a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java index edb56ef3..94802c41 100644 --- a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java +++ b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyFileUtil.java @@ -15,6 +15,8 @@ */ package com.hierynomus.sshj.userauth.keyprovider; +import net.schmizz.sshj.common.Base64DecodingException; +import net.schmizz.sshj.common.Base64Decoder; import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.KeyType; @@ -23,7 +25,6 @@ import java.io.IOException; import java.io.Reader; import java.security.PublicKey; -import java.util.Base64; public class OpenSSHKeyFileUtil { private OpenSSHKeyFileUtil() { @@ -54,9 +55,10 @@ public static ParsedPubKey initPubKey(Reader publicKey) throws IOException { if (!keydata.isEmpty()) { String[] parts = keydata.trim().split("\\s+"); if (parts.length >= 2) { + byte[] decodedPublicKey = Base64Decoder.decode(parts[1]); return new ParsedPubKey( KeyType.fromString(parts[0]), - new Buffer.PlainBuffer(Base64.getDecoder().decode(parts[1])).readPublicKey() + new Buffer.PlainBuffer(decodedPublicKey).readPublicKey() ); } else { throw new IOException("Got line with only one column"); @@ -64,6 +66,8 @@ public static ParsedPubKey initPubKey(Reader publicKey) throws IOException { } } throw new IOException("Public key file is blank"); + } catch (Base64DecodingException err) { + throw new IOException("Public key decoding failed", err); } finally { br.close(); } diff --git a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java index 9229fa4a..5d89356f 100644 --- a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java +++ b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java @@ -23,13 +23,8 @@ import net.i2p.crypto.eddsa.EdDSAPrivateKey; import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable; import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec; -import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.*; import net.schmizz.sshj.common.Buffer.PlainBuffer; -import net.schmizz.sshj.common.ByteArrayUtils; -import net.schmizz.sshj.common.IOUtils; -import net.schmizz.sshj.common.KeyType; -import net.schmizz.sshj.common.SSHRuntimeException; -import net.schmizz.sshj.common.SecurityUtils; import net.schmizz.sshj.transport.cipher.Cipher; import net.schmizz.sshj.userauth.keyprovider.BaseFileKeyProvider; import net.schmizz.sshj.userauth.keyprovider.FileKeyProvider; @@ -55,7 +50,6 @@ import java.security.spec.ECPrivateKeySpec; import java.security.spec.RSAPrivateCrtKeySpec; import java.util.Arrays; -import java.util.Base64; import java.util.HashMap; import java.util.Map; @@ -124,7 +118,7 @@ protected KeyPair readKeyPair() throws IOException { try { if (checkHeader(reader)) { final String encodedPrivateKey = readEncodedKey(reader); - byte[] decodedPrivateKey = Base64.getDecoder().decode(encodedPrivateKey); + byte[] decodedPrivateKey = Base64Decoder.decode(encodedPrivateKey); final PlainBuffer bufferedPrivateKey = new PlainBuffer(decodedPrivateKey); return readDecodedKeyPair(bufferedPrivateKey); } else { @@ -133,6 +127,8 @@ protected KeyPair readKeyPair() throws IOException { } } catch (final GeneralSecurityException e) { throw new SSHRuntimeException("Read OpenSSH Version 1 Key failed", e); + } catch (Base64DecodingException e) { + throw new SSHRuntimeException("Private Key decoding failed", e); } finally { IOUtils.closeQuietly(reader); } diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index dd0e3817..78b91c5f 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -804,12 +804,12 @@ protected void onConnect() throws IOException { super.onConnect(); trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream()); + doKex(); final KeepAlive keepAliveThread = conn.getKeepAlive(); if (keepAliveThread.isEnabled()) { ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); keepAliveThread.start(); } - doKex(); } /** diff --git a/src/main/java/net/schmizz/sshj/common/Base64Decoder.java b/src/main/java/net/schmizz/sshj/common/Base64Decoder.java new file mode 100644 index 00000000..e29608ad --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/Base64Decoder.java @@ -0,0 +1,47 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.common; + +import java.io.IOException; +import java.util.Base64; + +/** + *

Wraps {@link java.util.Base64.Decoder} in order to wrap unchecked {@code IllegalArgumentException} thrown by + * the default Java Base64 decoder here and there.

+ * + *

Please use this class instead of {@link java.util.Base64.Decoder}.

+ */ +public class Base64Decoder { + private Base64Decoder() { + } + + public static byte[] decode(byte[] source) throws Base64DecodingException { + try { + return Base64.getDecoder().decode(source); + } catch (IllegalArgumentException err) { + throw new Base64DecodingException(err); + } + } + + public static byte[] decode(String src) throws Base64DecodingException { + try { + return Base64.getDecoder().decode(src); + } catch (IllegalArgumentException err) { + throw new Base64DecodingException(err); + } + } +} diff --git a/src/main/java/net/schmizz/sshj/common/Base64DecodingException.java b/src/main/java/net/schmizz/sshj/common/Base64DecodingException.java new file mode 100644 index 00000000..cc18ead7 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/Base64DecodingException.java @@ -0,0 +1,28 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.common; + +/** + * A checked wrapper for all {@link IllegalArgumentException}, thrown by {@link java.util.Base64.Decoder}. + * + * @see Base64Decoder + */ +public class Base64DecodingException extends Exception { + public Base64DecodingException(IllegalArgumentException cause) { + super("Failed to decode base64: " + cause.getMessage(), cause); + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java index c41b83d7..a5821908 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java @@ -18,13 +18,7 @@ import com.hierynomus.sshj.common.KeyAlgorithm; import com.hierynomus.sshj.transport.verification.KnownHostMatchers; import com.hierynomus.sshj.userauth.certificate.Certificate; -import net.schmizz.sshj.common.Buffer; -import net.schmizz.sshj.common.IOUtils; -import net.schmizz.sshj.common.KeyType; -import net.schmizz.sshj.common.LoggerFactory; -import net.schmizz.sshj.common.SSHException; -import net.schmizz.sshj.common.SSHRuntimeException; -import net.schmizz.sshj.common.SecurityUtils; +import net.schmizz.sshj.common.*; import org.slf4j.Logger; import java.io.BufferedOutputStream; @@ -290,9 +284,9 @@ public KnownHostEntry parseEntry(String line) if (type != KeyType.UNKNOWN) { final String sKey = split[i++]; try { - byte[] keyBytes = Base64.getDecoder().decode(sKey); + byte[] keyBytes = Base64Decoder.decode(sKey); key = new Buffer.PlainBuffer(keyBytes).readPublicKey(); - } catch (IOException | IllegalArgumentException exception) { + } catch (IOException | Base64DecodingException exception) { log.warn("Error decoding Base64 key bytes", exception); return new BadHostEntry(line); } diff --git a/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java b/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java index 9794da0f..444c222a 100644 --- a/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java +++ b/src/main/java/net/schmizz/sshj/userauth/keyprovider/PuTTYKeyFile.java @@ -22,9 +22,7 @@ import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable; import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec; import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec; -import net.schmizz.sshj.common.Buffer; -import net.schmizz.sshj.common.KeyType; -import net.schmizz.sshj.common.SecurityUtils; +import net.schmizz.sshj.common.*; import net.schmizz.sshj.userauth.password.PasswordUtils; import org.bouncycastle.asn1.nist.NISTNamedCurves; import org.bouncycastle.asn1.x9.X9ECParameters; @@ -42,7 +40,6 @@ import java.security.*; import java.security.spec.*; import java.util.Arrays; -import java.util.Base64; import java.util.HashMap; import java.util.Map; @@ -240,29 +237,34 @@ protected void parseKeyPair() throws IOException { if (this.keyFileVersion == null) { throw new IOException("Invalid key file format: missing \"PuTTY-User-Key-File-?\" entry"); } - // Retrieve keys from payload - publicKey = Base64.getDecoder().decode(payload.get("Public-Lines")); - if (this.isEncrypted()) { - final char[] passphrase; - if (pwdf != null) { - passphrase = pwdf.reqPassword(resource); - } else { - passphrase = "".toCharArray(); - } - try { - privateKey = this.decrypt(Base64.getDecoder().decode(payload.get("Private-Lines")), passphrase); - Mac mac; - if (this.keyFileVersion <= 2) { - mac = this.prepareVerifyMacV2(passphrase); + try { + // Retrieve keys from payload + publicKey = Base64Decoder.decode(payload.get("Public-Lines")); + if (this.isEncrypted()) { + final char[] passphrase; + if (pwdf != null) { + passphrase = pwdf.reqPassword(resource); } else { - mac = this.prepareVerifyMacV3(); + passphrase = "".toCharArray(); + } + try { + privateKey = this.decrypt(Base64Decoder.decode(payload.get("Private-Lines")), passphrase); + Mac mac; + if (this.keyFileVersion <= 2) { + mac = this.prepareVerifyMacV2(passphrase); + } else { + mac = this.prepareVerifyMacV3(); + } + this.verify(mac); + } finally { + PasswordUtils.blankOut(passphrase); } - this.verify(mac); - } finally { - PasswordUtils.blankOut(passphrase); + } else { + privateKey = Base64Decoder.decode(payload.get("Private-Lines")); } - } else { - privateKey = Base64.getDecoder().decode(payload.get("Private-Lines")); + } + catch (Base64DecodingException e) { + throw new IOException("PuTTY key decoding failed", e); } } diff --git a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java index 01dbe2f5..8bfebdae 100644 --- a/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java +++ b/src/test/java/com/hierynomus/sshj/transport/verification/OpenSSHKnownHostsTest.java @@ -15,11 +15,16 @@ */ package com.hierynomus.sshj.transport.verification; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.assertj.core.api.Assertions.*; +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.SecurityUtils; +import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts; +import net.schmizz.sshj.util.KeyUtil; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import java.io.File; import java.io.IOException; @@ -29,17 +34,8 @@ import java.util.Base64; import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; - -import net.schmizz.sshj.common.Buffer; -import net.schmizz.sshj.common.SecurityUtils; -import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts; -import net.schmizz.sshj.util.KeyUtil; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.*; public class OpenSSHKnownHostsTest { @TempDir @@ -118,6 +114,24 @@ public void shouldNotFailOnMalformedBase64String() throws IOException { assertThat(ohk.entries().get(0)).isInstanceOf(OpenSSHKnownHosts.BadHostEntry.class); } + @Test + public void shouldNotFailOnMalformeSaltBase64String() throws IOException { + // A record with broken base64 inside the salt part of the hash. + // No matter how it could be generated, such broken strings must not cause unexpected errors. + String hostName = "example.com"; + File knownHosts = knownHosts( + "|1|2gujgGa6gJnK7wGPCX8zuGttvCMXX|Oqkbjtxd9RFxKQv6y3l3GIxLNiU= ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBGVVnyoAD5/uWiiuTSM3RuW8dEWRrqOXYobAMKHhAA6kuOBoPK+LoAYyUcN26bdMiCxg+VOaLHxPNWv5SlhbMWw=\n" + ); + OpenSSHKnownHosts ohk = new OpenSSHKnownHosts(knownHosts); + assertEquals(1, ohk.entries().size()); + + // Some random valid public key. It doesn't matter for the test if it matches the broken host key record or not. + PublicKey k = new Buffer.PlainBuffer(Base64.getDecoder().decode( + "AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBLTjA7hduYGmvV9smEEsIdGLdghSPD7kL8QarIIOkeXmBh+LTtT/T1K+Ot/rmXCZsP8hoUXxbvN+Tks440Ci0ck=")) + .readPublicKey(); + assertFalse(ohk.verify(hostName, 22, k)); + } + @Test public void shouldMarkBadLineAndNotFail() throws Exception { File knownHosts = knownHosts( diff --git a/src/test/java/net/schmizz/sshj/keyprovider/CorruptedPublicKeyTest.java b/src/test/java/net/schmizz/sshj/keyprovider/CorruptedPublicKeyTest.java new file mode 100644 index 00000000..cf5a6c23 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/keyprovider/CorruptedPublicKeyTest.java @@ -0,0 +1,73 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.keyprovider; + +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.util.CorruptBase64; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; + +public class CorruptedPublicKeyTest { + private final Path keyRoot = Path.of("src/test/resources"); + + @TempDir + public Path tempDir; + + @ParameterizedTest + @CsvSource({ + "keyformats/ecdsa_opensshv1,", + "keyformats/openssh,", + "keytypes/test_ecdsa_nistp521_2,", + "keytypes/ed25519_protected, sshjtest", + }) + public void corruptedPublicKey(String privateKeyFileName, String passphrase) throws IOException { + Files.createDirectories(tempDir.resolve(privateKeyFileName).getParent()); + Files.copy(keyRoot.resolve(privateKeyFileName), tempDir.resolve(privateKeyFileName)); + + { + String publicKeyText; + try (var reader = new BufferedReader(new FileReader( + keyRoot.resolve(privateKeyFileName + ".pub").toFile()))) { + publicKeyText = reader.readLine(); + } + + String[] parts = publicKeyText.split("\\s+"); + parts[1] = CorruptBase64.corruptBase64(parts[1]); + + try (var writer = new FileWriter(tempDir.resolve(privateKeyFileName + ".pub").toFile())) { + writer.write(String.join(" ", parts)); + } + } + + // Must not throw an exception. + try (var sshClient = new SSHClient()) { + sshClient.loadKeys( + tempDir.resolve(privateKeyFileName).toString(), + Optional.ofNullable(passphrase).map(String::toCharArray).orElse(null) + ).getPublic(); + } + } +} diff --git a/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java b/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java index cfed5537..3b1d5218 100644 --- a/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java +++ b/src/test/java/net/schmizz/sshj/keyprovider/PuTTYKeyFileTest.java @@ -18,15 +18,19 @@ import com.hierynomus.sshj.userauth.keyprovider.OpenSSHKeyV1KeyFile; import net.schmizz.sshj.userauth.keyprovider.PKCS8KeyFile; import net.schmizz.sshj.userauth.keyprovider.PuTTYKeyFile; +import net.schmizz.sshj.util.CorruptBase64; import net.schmizz.sshj.util.UnitTestPasswordFinder; import org.junit.jupiter.api.Test; +import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.StringReader; import java.security.interfaces.RSAPrivateKey; import java.security.interfaces.RSAPublicKey; +import java.util.Objects; +import static java.lang.Math.min; import static org.junit.jupiter.api.Assertions.*; public class PuTTYKeyFileTest { @@ -558,4 +562,61 @@ public void testWrongPassphraseDsa() throws Exception { assertNull(key.getPrivate()); }); } + + @Test + public void corruptedPublicLines() throws Exception { + assertThrows(IOException.class, () -> { + PuTTYKeyFile key = new PuTTYKeyFile(); + key.init(new StringReader(corruptBase64InPuttyKey(ppk2048, "Public-Lines: "))); + key.getPublic(); + }); + } + + @Test + public void corruptedPrivateLines() throws Exception { + assertThrows(IOException.class, () -> { + PuTTYKeyFile key = new PuTTYKeyFile(); + key.init(new StringReader(corruptBase64InPuttyKey(ppk2048, "Private-Lines: "))); + key.getPublic(); + }); + } + + private String corruptBase64InPuttyKey( + @SuppressWarnings("SameParameterValue") String source, + String sectionPrefix + ) throws IOException { + try (var reader = new BufferedReader(new StringReader(source))) { + StringBuilder result = new StringBuilder(); + while (true) { + String line = reader.readLine(); + if (line == null) { + break; + } else if (line.startsWith(sectionPrefix)) { + int base64LineCount = Integer.parseInt(line.substring(sectionPrefix.length())); + StringBuilder base64 = new StringBuilder(); + for (int i = 0; i < base64LineCount; ++i) { + base64.append(Objects.requireNonNull(reader.readLine())); + } + String corruptedBase64 = CorruptBase64.corruptBase64(base64.toString()); + + // 64 is the length of base64 lines in PuTTY keys generated by puttygen. + // It's not clear if it's some standard or not. + // It doesn't match the MIME Base64 standard. + int chunkSize = 64; + + result.append(sectionPrefix); + result.append((corruptedBase64.length() + chunkSize - 1) / chunkSize); + result.append('\n'); + for (int offset = 0; offset < corruptedBase64.length(); offset += chunkSize) { + result.append(corruptedBase64, offset, min(corruptedBase64.length(), offset + chunkSize)); + result.append('\n'); + } + } else { + result.append(line); + result.append('\n'); + } + } + return result.toString(); + } + } } diff --git a/src/test/java/net/schmizz/sshj/util/CorruptBase64.java b/src/test/java/net/schmizz/sshj/util/CorruptBase64.java new file mode 100644 index 00000000..edab8a5b --- /dev/null +++ b/src/test/java/net/schmizz/sshj/util/CorruptBase64.java @@ -0,0 +1,42 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package net.schmizz.sshj.util; + +import net.schmizz.sshj.common.Base64DecodingException; +import net.schmizz.sshj.common.Base64Decoder; + +import java.io.IOException; + +public class CorruptBase64 { + private CorruptBase64() { + } + + public static String corruptBase64(String source) throws IOException { + while (true) { + try { + Base64Decoder.decode(source); + } catch (Base64DecodingException e) { + return source; + } + + if (source.endsWith("=")) { + source = source.substring(0, source.length() - 1); + } + source += "X"; + } + } +}