From 30c402e832dae7c810514e42060e6b9a1f5fc098 Mon Sep 17 00:00:00 2001 From: David Kocher Date: Mon, 23 Oct 2023 21:51:34 +0200 Subject: [PATCH] Add overloaded init methods that take the public key from a stream and properly initialize. Resolves #907. --- .../keyprovider/OpenSSHKeyV1KeyFile.java | 58 ++++++++++++++----- .../keyprovider/BaseFileKeyProvider.java | 29 ++++++---- .../userauth/keyprovider/FileKeyProvider.java | 4 ++ .../userauth/keyprovider/OpenSSHKeyFile.java | 25 ++++++-- .../sshj/keyprovider/OpenSSHKeyFileTest.java | 23 ++++++++ 5 files changed, 108 insertions(+), 31 deletions(-) diff --git a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java index 9229fa4af..04de5b409 100644 --- a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java +++ b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java @@ -20,38 +20,33 @@ import com.hierynomus.sshj.transport.cipher.BlockCiphers; import com.hierynomus.sshj.transport.cipher.ChachaPolyCiphers; import com.hierynomus.sshj.transport.cipher.GcmCiphers; +import com.hierynomus.sshj.userauth.keyprovider.bcrypt.BCrypt; import net.i2p.crypto.eddsa.EdDSAPrivateKey; import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable; import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec; -import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.*; import net.schmizz.sshj.common.Buffer.PlainBuffer; -import net.schmizz.sshj.common.ByteArrayUtils; -import net.schmizz.sshj.common.IOUtils; -import net.schmizz.sshj.common.KeyType; -import net.schmizz.sshj.common.SSHRuntimeException; -import net.schmizz.sshj.common.SecurityUtils; import net.schmizz.sshj.transport.cipher.Cipher; import net.schmizz.sshj.userauth.keyprovider.BaseFileKeyProvider; import net.schmizz.sshj.userauth.keyprovider.FileKeyProvider; import net.schmizz.sshj.userauth.keyprovider.KeyFormat; +import net.schmizz.sshj.userauth.password.PasswordFinder; import org.bouncycastle.asn1.nist.NISTNamedCurves; import org.bouncycastle.asn1.x9.X9ECParameters; import org.bouncycastle.jce.spec.ECNamedCurveSpec; -import com.hierynomus.sshj.userauth.keyprovider.bcrypt.BCrypt; import org.bouncycastle.openssl.EncryptionException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.BufferedReader; -import java.io.File; -import java.io.FileReader; -import java.io.IOException; -import java.io.Reader; +import java.io.*; import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.StandardCharsets; -import java.security.*; +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; import java.security.spec.ECPrivateKeySpec; import java.security.spec.RSAPrivateCrtKeySpec; import java.util.Arrays; @@ -89,6 +84,12 @@ public class OpenSSHKeyV1KeyFile extends BaseFileKeyProvider { private PublicKey pubKey; + @Override + public PublicKey getPublic() + throws IOException { + return pubKey != null ? pubKey : super.getPublic(); + } + public static class Factory implements net.schmizz.sshj.common.Factory.Named { @@ -106,16 +107,41 @@ public String getName() { protected final Logger log = LoggerFactory.getLogger(getClass()); @Override - public void init(File location) { + public void init(File location, PasswordFinder pwdf) { File pubKey = OpenSSHKeyFileUtil.getPublicKeyFile(location); - if (pubKey != null) + if (pubKey != null) { try { initPubKey(new FileReader(pubKey)); } catch (IOException e) { // let super provide both public & private key log.warn("Error reading public key file: {}", e.toString()); } - super.init(location); + } + super.init(location, pwdf); + } + + @Override + public void init(String privateKey, String publicKey, PasswordFinder pwdf) { + if (pubKey != null) { + try { + initPubKey(new StringReader(publicKey)); + } catch (IOException e) { + log.warn("Error reading public key file: {}", e.toString()); + } + } + super.init(privateKey, null, pwdf); + } + + @Override + public void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf) { + if (pubKey != null) { + try { + initPubKey(publicKey); + } catch (IOException e) { + log.warn("Error reading public key file: {}", e.toString()); + } + } + super.init(privateKey, null, pwdf); } @Override diff --git a/src/main/java/net/schmizz/sshj/userauth/keyprovider/BaseFileKeyProvider.java b/src/main/java/net/schmizz/sshj/userauth/keyprovider/BaseFileKeyProvider.java index f4e7580ea..bb67206b3 100644 --- a/src/main/java/net/schmizz/sshj/userauth/keyprovider/BaseFileKeyProvider.java +++ b/src/main/java/net/schmizz/sshj/userauth/keyprovider/BaseFileKeyProvider.java @@ -34,38 +34,47 @@ public abstract class BaseFileKeyProvider implements FileKeyProvider { @Override public void init(Reader location) { - assert location != null; - resource = new PrivateKeyReaderResource(location); + this.init(location, (PasswordFinder) null); } @Override public void init(Reader location, PasswordFinder pwdf) { - init(location); + this.init(location, null, pwdf); + } + + @Override + public void init(Reader privateKey, Reader publicKey) { + this.init(privateKey, publicKey, null); + } + + @Override + public void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf) { + assert publicKey == null; + this.resource = new PrivateKeyReaderResource(privateKey); this.pwdf = pwdf; } @Override public void init(File location) { - assert location != null; - resource = new PrivateKeyFileResource(location.getAbsoluteFile()); + this.init(location, null); } @Override public void init(File location, PasswordFinder pwdf) { - init(location); + this.resource = new PrivateKeyFileResource(location.getAbsoluteFile()); this.pwdf = pwdf; } @Override public void init(String privateKey, String publicKey) { - assert privateKey != null; - assert publicKey == null; - resource = new PrivateKeyStringResource(privateKey); + this.init(privateKey, publicKey, null); } @Override public void init(String privateKey, String publicKey, PasswordFinder pwdf) { - init(privateKey, publicKey); + assert privateKey != null; + assert publicKey == null; + this.resource = new PrivateKeyStringResource(privateKey); this.pwdf = pwdf; } diff --git a/src/main/java/net/schmizz/sshj/userauth/keyprovider/FileKeyProvider.java b/src/main/java/net/schmizz/sshj/userauth/keyprovider/FileKeyProvider.java index 1fcaa2423..4bab4e9f9 100644 --- a/src/main/java/net/schmizz/sshj/userauth/keyprovider/FileKeyProvider.java +++ b/src/main/java/net/schmizz/sshj/userauth/keyprovider/FileKeyProvider.java @@ -30,6 +30,10 @@ public interface FileKeyProvider void init(Reader location); + void init(Reader privateKey, Reader publicKey); + + void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf); + void init(Reader location, PasswordFinder pwdf); void init(String privateKey, String publicKey); diff --git a/src/main/java/net/schmizz/sshj/userauth/keyprovider/OpenSSHKeyFile.java b/src/main/java/net/schmizz/sshj/userauth/keyprovider/OpenSSHKeyFile.java index a7aec1fa7..48bd09708 100644 --- a/src/main/java/net/schmizz/sshj/userauth/keyprovider/OpenSSHKeyFile.java +++ b/src/main/java/net/schmizz/sshj/userauth/keyprovider/OpenSSHKeyFile.java @@ -16,6 +16,7 @@ package net.schmizz.sshj.userauth.keyprovider; import com.hierynomus.sshj.userauth.keyprovider.OpenSSHKeyFileUtil; +import net.schmizz.sshj.userauth.password.PasswordFinder; import java.io.*; import java.security.PublicKey; @@ -54,21 +55,22 @@ public PublicKey getPublic() } @Override - public void init(File location) { + public void init(File location, PasswordFinder pwdf) { // try cert key location first File pubKey = OpenSSHKeyFileUtil.getPublicKeyFile(location); - if (pubKey != null) + if (pubKey != null) { try { initPubKey(new FileReader(pubKey)); } catch (IOException e) { // let super provide both public & private key log.warn("Error reading public key file: {}", e.toString()); } - super.init(location); + } + super.init(location, pwdf); } @Override - public void init(String privateKey, String publicKey) { + public void init(String privateKey, String publicKey, PasswordFinder pwdf) { if (publicKey != null) { try { initPubKey(new StringReader(publicKey)); @@ -77,7 +79,20 @@ public void init(String privateKey, String publicKey) { log.warn("Error reading public key: {}", e.toString()); } } - super.init(privateKey, null); + super.init(privateKey, null, pwdf); + } + + @Override + public void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf) { + if (publicKey != null) { + try { + initPubKey(publicKey); + } catch (IOException e) { + // let super provide both public & private key + log.warn("Error reading public key: {}", e.toString()); + } + } + super.init(privateKey, null, pwdf); } /** diff --git a/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java b/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java index 3e8cbae8d..3268e73cf 100644 --- a/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java +++ b/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java @@ -381,6 +381,18 @@ public void shouldSuccessfullyLoadSignedRSAPublicKey() throws IOException { } + @Test + public void shouldSuccessfullyLoadSignedRSAPublicKeyFromStream() throws IOException { + FileKeyProvider keyFile = new OpenSSHKeyFile(); + keyFile.init(new FileReader("src/test/resources/keytypes/certificate/test_rsa"), + new FileReader("src/test/resources/keytypes/certificate/test_rsa.pub"), + PasswordUtils.createOneOff(correctPassphrase)); + assertNotNull(keyFile.getPrivate()); + PublicKey pubKey = keyFile.getPublic(); + assertNotNull(pubKey); + assertEquals("RSA", pubKey.getAlgorithm()); + } + @Test public void shouldSuccessfullyLoadSignedRSAPublicKeyWithMaxDate() throws IOException { FileKeyProvider keyFile = new OpenSSHKeyFile(); @@ -422,6 +434,17 @@ public void shouldSuccessfullyLoadSignedDSAPublicKey() throws IOException { assertEquals("", certificate.getExtensions().get("permit-pty")); } + @Test + public void shouldSuccessfullyLoadSignedDSAPublicKeyFromStream() throws IOException { + FileKeyProvider keyFile = new OpenSSHKeyFile(); + keyFile.init(new FileReader("src/test/resources/keytypes/certificate/test_dsa"), + new FileReader("src/test/resources/keytypes/certificate/test_dsa-cert.pub"), + PasswordUtils.createOneOff(correctPassphrase)); + assertNotNull(keyFile.getPrivate()); + PublicKey pubKey = keyFile.getPublic(); + assertEquals("DSA", pubKey.getAlgorithm()); + } + /** * Sometimes users copy-pastes private and public keys in text editors. It leads to redundant * spaces and newlines. OpenSSH can easily read such keys, so users expect from SSHJ the same.