diff --git a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java index 9235ffbf6..3e3412c22 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java @@ -33,12 +33,12 @@ import org.apache.sshd.common.config.keys.OpenSshCertificate; import org.apache.sshd.common.digest.Digest; import org.apache.sshd.common.kex.AbstractDH; +import org.apache.sshd.common.kex.CurveSizeIndicator; import org.apache.sshd.common.kex.DHFactory; import org.apache.sshd.common.kex.KexProposalOption; import org.apache.sshd.common.kex.KeyEncapsulationMethod; import org.apache.sshd.common.kex.KeyExchange; import org.apache.sshd.common.kex.KeyExchangeFactory; -import org.apache.sshd.common.kex.XDH; import org.apache.sshd.common.keyprovider.KeyPairProvider; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.signature.Signature; @@ -154,14 +154,15 @@ public boolean next(int cmd, Buffer buffer) throws Exception { } else { try { int l = kemClient.getEncapsulationLength(); - if (dh instanceof XDH) { - if (f.length != l + ((XDH) dh).getKeySize()) { + if (dh instanceof CurveSizeIndicator) { + int expectedLength = l + ((CurveSizeIndicator) dh).getByteLength(); + if (f.length != expectedLength) { throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, - "Wrong F length (should be 1071 bytes): " + f.length); + "Wrong F length (should be " + expectedLength + " bytes): " + f.length); } - } else { + } else if (f.length <= l) { throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, - "Key encapsulation only supported for XDH"); + "Strange F length: " + f.length + " <= " + l); } dh.setF(Arrays.copyOfRange(f, l, f.length)); Digest keyHash = dh.getHash(); @@ -170,6 +171,7 @@ public boolean next(int cmd, Buffer buffer) throws Exception { keyHash.update(dh.getK()); k = keyHash.digest(); } catch (IllegalArgumentException ex) { + log.error("Key encapsulation error", ex); throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, "Key encapsulation error: " + ex.getMessage()); } diff --git a/sshd-core/src/main/java/org/apache/sshd/common/BaseBuilder.java b/sshd-core/src/main/java/org/apache/sshd/common/BaseBuilder.java index 6453420fb..c46fb2c65 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/BaseBuilder.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/BaseBuilder.java @@ -88,6 +88,10 @@ public class BaseBuilder DEFAULT_KEX_PREFERENCE = Collections.unmodifiableList( Arrays.asList( BuiltinDHFactories.sntrup761x25519, + BuiltinDHFactories.sntrup761x25519_openssh, + BuiltinDHFactories.mlkem768x25519, + BuiltinDHFactories.mlkem1024nistp384, + BuiltinDHFactories.mlkem768nistp256, BuiltinDHFactories.curve25519, BuiltinDHFactories.curve25519_libssh, BuiltinDHFactories.curve448, diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinDHFactories.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinDHFactories.java index d14dc8fd5..bbe82da0c 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinDHFactories.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinDHFactories.java @@ -302,6 +302,86 @@ public boolean isSupported() { return MontgomeryCurve.x448.isSupported() && BuiltinDigests.sha512.isSupported(); } }, + /** + * @see PQ/T Hybrid Key + * Exchange in SSH + */ + mlkem768x25519(Constants.MLKEM768_25519_SHA256) { + @Override + public XDH create(Object... params) throws Exception { + if (!GenericUtils.isEmpty(params)) { + throw new IllegalArgumentException("No accepted parameters for " + getName()); + } + return new XDH(MontgomeryCurve.x25519, true) { + + @Override + public KeyEncapsulationMethod getKeyEncapsulation() { + return BuiltinKEM.mlkem768; + } + + @Override + public Digest getHash() throws Exception { + return BuiltinDigests.sha256.create(); + } + }; + } + + @Override + public boolean isSupported() { + return MontgomeryCurve.x25519.isSupported() && BuiltinDigests.sha256.isSupported() + && BuiltinKEM.mlkem768.isSupported(); + } + }, + /** + * @see PQ/T Hybrid Key + * Exchange in SSH + */ + mlkem768nistp256(Constants.MLKEM768_NISTP256_SHA256) { + @Override + public ECDH create(Object... params) throws Exception { + if (!GenericUtils.isEmpty(params)) { + throw new IllegalArgumentException("No accepted parameters for " + getName()); + } + return new ECDH(ECCurves.nistp256, true) { + + @Override + public KeyEncapsulationMethod getKeyEncapsulation() { + return BuiltinKEM.mlkem768; + } + + }; + } + + @Override + public boolean isSupported() { + return ECCurves.nistp256.isSupported() && BuiltinKEM.mlkem768.isSupported(); + } + }, + /** + * @see PQ/T Hybrid Key + * Exchange in SSH + */ + mlkem1024nistp384(Constants.MLKEM1024_NISTP384_SHA384) { + @Override + public ECDH create(Object... params) throws Exception { + if (!GenericUtils.isEmpty(params)) { + throw new IllegalArgumentException("No accepted parameters for " + getName()); + } + return new ECDH(ECCurves.nistp384, true) { + + @Override + public KeyEncapsulationMethod getKeyEncapsulation() { + return BuiltinKEM.mlkem1024; + } + + }; + } + + @Override + public boolean isSupported() { + return ECCurves.nistp384.isSupported() && BuiltinKEM.mlkem1024.isSupported(); + } + }, /** * @see draft-josefsson-ntruprime-ssh-02.html @@ -524,6 +604,9 @@ public static final class Constants { public static final String CURVE25519_SHA256 = "curve25519-sha256"; public static final String CURVE25519_SHA256_LIBSSH = CURVE25519_SHA256 + "@libssh.org"; public static final String CURVE448_SHA512 = "curve448-sha512"; + public static final String MLKEM768_25519_SHA256 = "mlkem768x25519-sha256"; + public static final String MLKEM768_NISTP256_SHA256 = "mlkem768nistp256-sha256"; + public static final String MLKEM1024_NISTP384_SHA384 = "mlkem1024nistp384-sha384"; public static final String SNTRUP761_25519_SHA512 = "sntrup761x25519-sha512"; public static final String SNTRUP761_25519_SHA512_OPENSSH = SNTRUP761_25519_SHA512 + "@openssh.com"; diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinKEM.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinKEM.java index 33997f52f..b8f2af753 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinKEM.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/BuiltinKEM.java @@ -26,6 +26,44 @@ */ public enum BuiltinKEM implements KeyEncapsulationMethod, NamedResource, OptionalFeature { + mlkem768("mlkem768") { + + @Override + public Client getClient() { + return MLKEM.getClient(MLKEM.Parameters.mlkem768); + } + + @Override + public Server getServer() { + return MLKEM.getServer(MLKEM.Parameters.mlkem768); + } + + @Override + public boolean isSupported() { + return MLKEM.Parameters.mlkem768.isSupported(); + } + + }, + + mlkem1024("mlkem1024") { + + @Override + public Client getClient() { + return MLKEM.getClient(MLKEM.Parameters.mlkem1024); + } + + @Override + public Server getServer() { + return MLKEM.getServer(MLKEM.Parameters.mlkem1024); + } + + @Override + public boolean isSupported() { + return MLKEM.Parameters.mlkem1024.isSupported(); + } + + }, + sntrup761("sntrup761") { @Override diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/CurveSizeIndicator.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/CurveSizeIndicator.java new file mode 100644 index 000000000..682402776 --- /dev/null +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/CurveSizeIndicator.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sshd.common.kex; + +/** + * @author Apache MINA SSHD Project + */ +public interface CurveSizeIndicator { + + /** + * Retrieves the length of a point coordinate in bytes. + * + * @return the length + */ + int getByteLength(); +} diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/ECDH.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/ECDH.java index bd889b2ef..b4e8f959c 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/kex/ECDH.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/ECDH.java @@ -43,31 +43,41 @@ public class ECDH extends AbstractDH { public static final String KEX_TYPE = "ECDH"; + private final boolean raw; + private ECCurves curve; private ECParameterSpec params; private ECPoint f; - public ECDH() throws Exception { - this((ECParameterSpec) null); - } - public ECDH(String curveName) throws Exception { - this(ValidateUtils.checkNotNull(ECCurves.fromCurveName(curveName), "Unknown curve name: %s", curveName)); + this(curveName, false); } public ECDH(ECCurves curve) throws Exception { - this(Objects.requireNonNull(curve, "No known curve instance provided").getParameters()); - this.curve = curve; + this(curve, false); } public ECDH(ECParameterSpec paramSpec) throws Exception { + this(paramSpec, false); + } + + public ECDH(String curveName, boolean raw) throws Exception { + this(ValidateUtils.checkNotNull(ECCurves.fromCurveName(curveName), "Unknown curve name: %s", curveName), raw); + } + + public ECDH(ECCurves curve, boolean raw) throws Exception { + this(Objects.requireNonNull(curve, "No known curve instance provided").getParameters(), raw); + this.curve = curve; + } + + public ECDH(ECParameterSpec paramSpec, boolean raw) throws Exception { myKeyAgree = SecurityUtils.getKeyAgreement(KEX_TYPE); - params = paramSpec; // do not check for null-ity since in some cases it can be + params = Objects.requireNonNull(paramSpec, "No EC curve parameters provided"); + this.raw = raw; } @Override protected byte[] calculateE() throws Exception { - Objects.requireNonNull(params, "No ECParameterSpec(s)"); KeyPairGenerator myKpairGen = SecurityUtils.getKeyPairGenerator(KeyUtils.EC_ALGORITHM); myKpairGen.initialize(params); @@ -81,22 +91,17 @@ protected byte[] calculateE() throws Exception { @Override protected byte[] calculateK() throws Exception { - Objects.requireNonNull(params, "No ECParameterSpec(s)"); Objects.requireNonNull(f, "Missing 'f' value"); ECPublicKeySpec keySpec = new ECPublicKeySpec(f, params); KeyFactory myKeyFac = SecurityUtils.getKeyFactory(KeyUtils.EC_ALGORITHM); PublicKey yourPubKey = myKeyFac.generatePublic(keySpec); myKeyAgree.doPhase(yourPubKey, true); - return stripLeadingZeroes(myKeyAgree.generateSecret()); - } - - public void setCurveParameters(ECParameterSpec paramSpec) { - params = paramSpec; + byte[] secret = myKeyAgree.generateSecret(); + return raw ? secret : stripLeadingZeroes(secret); } @Override public void setF(byte[] f) { - Objects.requireNonNull(params, "No ECParameterSpec(s)"); Objects.requireNonNull(f, "No 'f' value specified"); this.f = ECCurves.octetStringToEcPoint(f); } @@ -117,12 +122,14 @@ public void putF(Buffer buffer, byte[] f) { @Override public Digest getHash() throws Exception { + return findCurve().getDigestForParams(); + } + + private ECCurves findCurve() { if (curve == null) { - Objects.requireNonNull(params, "No ECParameterSpec(s)"); curve = Objects.requireNonNull(ECCurves.fromCurveParameters(params), "Unknown curve parameters"); } - - return curve.getDigestForParams(); + return curve; } @Override diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyEncapsulationMethod.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyEncapsulationMethod.java index a1cb39b70..d03cfb4c9 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyEncapsulationMethod.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyEncapsulationMethod.java @@ -63,6 +63,13 @@ interface Client { */ interface Server { + /** + * Retrieves the required length of the KEM public key, in bytes. + * + * @return the length of the key + */ + int getPublicKeyLength(); + /** * Initializes the KEM with a public key received from a client and prepares an encapsulated secret. * diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/MLKEM.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/MLKEM.java new file mode 100644 index 000000000..4ef2847f9 --- /dev/null +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/MLKEM.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sshd.common.kex; + +import java.util.Arrays; +import java.util.Objects; + +import org.apache.sshd.common.OptionalFeature; +import org.apache.sshd.common.random.JceRandom; +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMExtractor; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMGenerator; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyGenerationParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyPairGenerator; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; + +/** + * An implementation of the mlkem768 key encapsulation method (KEM), formerly known as Kyber, using Bouncy Castle. But + * see appendix C of FIPS 203 ("Differences From the CRYSTALS-Kyber Submission"). + *

+ * NIST specifies that they removed a hash in the encapsulation/decapsulation methods. + *

+ * + * @see NIST FIPS 203 + */ +final class MLKEM { + + enum Parameters implements OptionalFeature { + // For key sizes see NIST FIPS 203, section 8, table 3. Bouncy Castle does not expose the + // public key sizes through its API. (Though they compute them internally.) + mlkem768(1184) { + + @Override + Object getMLKEMParameters() { + return MLKEMParameters.ml_kem_768; + } + }, + mlkem1024(1568) { + + @Override + Object getMLKEMParameters() { + return MLKEMParameters.ml_kem_1024; + } + }; + + private final int publicKeySize; + + Parameters(int publicKeySize) { + this.publicKeySize = publicKeySize; + } + + // Return type is Object on purpose. We want delayed class loading here so that we can use this + // even if Bouncy Castle is not present. (If it isn't, we'll return false from isSupported at + // run-time, and then never use this algorithm.) + abstract Object getMLKEMParameters(); + + int getPublicKeySize() { + return publicKeySize; + } + + @Override + public boolean isSupported() { + try { + // If we get a ClassNotFoundException or some such, we return false. + return getMLKEMParameters() != null; + } catch (Throwable e) { + return false; + } + } + } + + private MLKEM() { + // No instantiation + } + + static KeyEncapsulationMethod.Client getClient(Parameters parameters) { + return new Client(parameters); + } + + static KeyEncapsulationMethod.Server getServer(Parameters parameters) { + return new Server(parameters); + } + + private static class Client implements KeyEncapsulationMethod.Client { + + private final Parameters parameters; + + private MLKEMExtractor extractor; + private MLKEMPublicKeyParameters publicKey; + + Client(Parameters parameters) { + this.parameters = Objects.requireNonNull(parameters, "No MLKEM.Parameters given"); + } + + @Override + public void init() { + MLKEMKeyPairGenerator gen = new MLKEMKeyPairGenerator(); + gen.init(new MLKEMKeyGenerationParameters(JceRandom.getGlobalInstance(), + (MLKEMParameters) parameters.getMLKEMParameters())); + AsymmetricCipherKeyPair pair = gen.generateKeyPair(); + extractor = new MLKEMExtractor((MLKEMPrivateKeyParameters) pair.getPrivate()); + publicKey = (MLKEMPublicKeyParameters) pair.getPublic(); + } + + @Override + public byte[] getPublicKey() { + return publicKey.getEncoded(); + } + + @Override + public byte[] extractSecret(byte[] encapsulated) { + if (encapsulated.length != getEncapsulationLength()) { + throw new IllegalArgumentException("KEM encpsulation has wrong length: " + encapsulated.length); + } + return extractor.extractSecret(encapsulated); + } + + @Override + public int getEncapsulationLength() { + return extractor.getEncapsulationLength(); + } + } + + private static class Server implements KeyEncapsulationMethod.Server { + + private final Parameters parameters; + + private SecretWithEncapsulation value; + + Server(Parameters parameters) { + this.parameters = Objects.requireNonNull(parameters, "No MLKEM.Parameters given"); + } + + @Override + public int getPublicKeyLength() { + return parameters.getPublicKeySize(); + } + + @Override + public byte[] init(byte[] publicKey) { + int pkBytes = getPublicKeyLength(); + if (publicKey.length < pkBytes) { + throw new IllegalArgumentException("KEM public key too short: " + publicKey.length); + } + byte[] pk = Arrays.copyOf(publicKey, pkBytes); + MLKEMGenerator kemGenerator = new MLKEMGenerator(JceRandom.getGlobalInstance()); + MLKEMPublicKeyParameters params = new MLKEMPublicKeyParameters((MLKEMParameters) parameters.getMLKEMParameters(), + pk); + value = kemGenerator.generateEncapsulated(params); + return Arrays.copyOfRange(publicKey, pkBytes, publicKey.length); + } + + @Override + public byte[] getSecret() { + return value.getSecret(); + } + + @Override + public byte[] getEncapsulation() { + return value.getEncapsulation(); + } + } +} diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/MontgomeryCurve.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/MontgomeryCurve.java index 2d3bb8a32..8505832ac 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/kex/MontgomeryCurve.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/MontgomeryCurve.java @@ -42,7 +42,7 @@ * @see RFC 7748 * @see RFC 8731 */ -public enum MontgomeryCurve implements KeySizeIndicator, OptionalFeature { +public enum MontgomeryCurve implements KeySizeIndicator, CurveSizeIndicator, OptionalFeature { /** * The "magic" bytes below are the beginning of a DER encoding of the ASN.1 of the SubjectPublicKeyInfo as specified @@ -130,10 +130,15 @@ public String getAlgorithm() { } @Override - public int getKeySize() { + public int getByteLength() { return keySize; } + @Override + public int getKeySize() { + return getByteLength() * Byte.SIZE; + } + @Override public boolean isSupported() { return supported && !SecurityUtils.isFipsMode(); @@ -152,13 +157,13 @@ public KeyPair generateKeyPair() { public byte[] encode(PublicKey key) throws InvalidKeyException { // Per the ASN.1 of SubjectPublicKeyInfo, the key must be the last keySize bytes of the X.509 encoding byte[] subjectPublicKeyInfo = key.getEncoded(); - byte[] result = Arrays.copyOfRange(subjectPublicKeyInfo, subjectPublicKeyInfo.length - getKeySize(), + byte[] result = Arrays.copyOfRange(subjectPublicKeyInfo, subjectPublicKeyInfo.length - getByteLength(), subjectPublicKeyInfo.length); return result; } public PublicKey decode(byte[] key) throws InvalidKeySpecException { - int size = getKeySize(); + int size = getByteLength(); int offset = key.length - size; // We're lenient here and accept a key prefixed by a zero byte. if (offset < 0 || offset > 1) { diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/SNTRUP761.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/SNTRUP761.java index 66ef88423..68a20b921 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/kex/SNTRUP761.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/SNTRUP761.java @@ -97,9 +97,14 @@ static class Server implements KeyEncapsulationMethod.Server { super(); } + @Override + public int getPublicKeyLength() { + return SNTRUPrimeParameters.sntrup761.getPublicKeyBytes(); + } + @Override public byte[] init(byte[] publicKey) { - int pkBytes = SNTRUPrimeParameters.sntrup761.getPublicKeyBytes(); + int pkBytes = getPublicKeyLength(); if (publicKey.length < pkBytes) { throw new IllegalArgumentException("KEM public key too short: " + publicKey.length); } diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/XDH.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/XDH.java index ff8eedda7..f3c994b8f 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/kex/XDH.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/XDH.java @@ -29,7 +29,7 @@ * * @see RFC 8731 */ -public abstract class XDH extends AbstractDH { +public abstract class XDH extends AbstractDH implements CurveSizeIndicator { protected final MontgomeryCurve curve; protected final boolean raw; @@ -41,8 +41,9 @@ public XDH(MontgomeryCurve curve, boolean raw) throws Exception { myKeyAgree = curve.createKeyAgreement(); } - public int getKeySize() { - return curve.getKeySize(); + @Override + public int getByteLength() { + return curve.getByteLength(); } @Override diff --git a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java index 2d90fcccf..42dcf906d 100644 --- a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java +++ b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java @@ -27,12 +27,12 @@ import org.apache.sshd.common.SshException; import org.apache.sshd.common.digest.Digest; import org.apache.sshd.common.kex.AbstractDH; +import org.apache.sshd.common.kex.CurveSizeIndicator; import org.apache.sshd.common.kex.DHFactory; import org.apache.sshd.common.kex.KexProposalOption; import org.apache.sshd.common.kex.KeyEncapsulationMethod; import org.apache.sshd.common.kex.KeyExchange; import org.apache.sshd.common.kex.KeyExchangeFactory; -import org.apache.sshd.common.kex.XDH; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.signature.Signature; import org.apache.sshd.common.util.ValidateUtils; @@ -111,17 +111,20 @@ public boolean next(int cmd, Buffer buffer) throws Exception { try { KeyEncapsulationMethod.Server kemServer = kem.getServer(); - byte[] f = kemServer.init(e); - if (dh instanceof XDH) { - if (f.length != ((XDH) dh).getKeySize()) { + if (dh instanceof CurveSizeIndicator) { + int expectedLength = kemServer.getPublicKeyLength() + ((CurveSizeIndicator) dh).getByteLength(); + if (e.length != expectedLength) { throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, - "Wrong E length (should be 1190 bytes): " + e.length); + "Wrong E length (should be " + expectedLength + " bytes): " + e.length); } } else { - throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, - "Key encapsulation only supported for XDH"); + int minLength = kemServer.getPublicKeyLength(); + if (e.length <= minLength) { + throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, + "Strange E length: " + e.length + " <= " + minLength); + } } - dh.setF(f); + dh.setF(kemServer.init(e)); byte[] dhK = dh.getK(); Digest keyHash = dh.getHash(); keyHash.init(); @@ -134,6 +137,7 @@ public boolean next(int cmd, Buffer buffer) throws Exception { System.arraycopy(dh.getE(), 0, newF, l, dh.getE().length); setF(newF); } catch (IllegalArgumentException ex) { + log.error("Key encapsulation error", ex); throw new SshException(SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED, "Key encapsulation error: " + ex.getMessage()); } diff --git a/sshd-core/src/test/java/org/apache/sshd/common/kex/OpenSshMlKemTest.java b/sshd-core/src/test/java/org/apache/sshd/common/kex/OpenSshMlKemTest.java new file mode 100644 index 000000000..054697044 --- /dev/null +++ b/sshd-core/src/test/java/org/apache/sshd/common/kex/OpenSshMlKemTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sshd.common.kex; + +import java.security.Security; +import java.util.Collections; + +import org.apache.sshd.client.ClientBuilder; +import org.apache.sshd.client.SshClient; +import org.apache.sshd.client.future.AuthFuture; +import org.apache.sshd.client.session.ClientSession; +import org.apache.sshd.common.keyprovider.FileKeyPairProvider; +import org.apache.sshd.util.test.BaseTestSupport; +import org.apache.sshd.util.test.CommonTestSupportUtils; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.output.Slf4jLogConsumer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.images.builder.ImageFromDockerfile; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.MountableFile; + +/** + * Test ciphers against OpenSSH. Force resetting ciphers every time to verify that they are res-initialized correctly. + * + * @author Apache MINA SSHD Project + */ +@Tag("ContainerTestCase") +@Testcontainers +class OpenSshMlKemTest extends BaseTestSupport { + + private static final Logger LOG = LoggerFactory.getLogger(OpenSshMlKemTest.class); + + // Re-use an already defined key + private static final String TEST_RESOURCES = "org/apache/sshd/common/kex/extensions/client"; + + @Container + GenericContainer sshdContainer = new GenericContainer<>(new ImageFromDockerfile() + .withDockerfileFromBuilder(builder -> builder.from("alpine:20240807") // + .run("apk --update add openssh-server") // Installs OpenSSH 9.9 + // Enable deprecated ciphers + .run("ssh-keygen -A") // Generate multiple host keys + .run("adduser -D bob") // Add a user + .run("echo 'bob:passwordBob' | chpasswd") // Give it a password to unlock the user + .run("mkdir -p /home/bob/.ssh") // Create the SSH config directory + .entryPoint("/entrypoint.sh") // Sets bob as owner of anything under /home/bob and launches sshd + .build())) // + .withCopyFileToContainer(MountableFile.forClasspathResource(TEST_RESOURCES + "/bob_key.pub"), + "/home/bob/.ssh/authorized_keys") + // entrypoint must be executable. Spotbugs doesn't like 0777, so use hex + .withCopyFileToContainer( + MountableFile.forClasspathResource(TEST_RESOURCES + "/entrypoint.sh", 0x1ff), + "/entrypoint.sh") + .waitingFor(Wait.forLogMessage(".*Server listening on :: port 22.*\\n", 1)).withExposedPorts(22) // + .withLogConsumer(new Slf4jLogConsumer(LOG)); + + @BeforeAll + static void registerBouncyCastleProviderIfNecessary() { + if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) { + Security.addProvider(new BouncyCastleProvider()); + } + } + + @Test + void mlkem768x25519() throws Exception { + Assumptions.assumeTrue(BuiltinDHFactories.mlkem768x25519.isSupported()); + FileKeyPairProvider keyPairProvider = CommonTestSupportUtils.createTestKeyPairProvider(TEST_RESOURCES + "/bob_key"); + SshClient client = setupTestClient(); + client.setKeyIdentityProvider(keyPairProvider); + client.setKeyExchangeFactories( + Collections.singletonList(ClientBuilder.DH2KEX.apply(BuiltinDHFactories.mlkem768x25519))); + client.start(); + + Integer actualPort = sshdContainer.getMappedPort(22); + String actualHost = sshdContainer.getHost(); + try (ClientSession session = client.connect("bob", actualHost, actualPort).verify(CONNECT_TIMEOUT).getSession()) { + AuthFuture authed = session.auth().verify(AUTH_TIMEOUT); + assertTrue(authed.isDone() && authed.isSuccess()); + } finally { + client.stop(); + } + } +} diff --git a/sshd-mina/pom.xml b/sshd-mina/pom.xml index eabe810b8..a1150d64e 100644 --- a/sshd-mina/pom.xml +++ b/sshd-mina/pom.xml @@ -125,6 +125,7 @@ **/SessionReKeyHostKeyExchangeTest.java **/HostBoundPubKeyAuthTest.java **/OpenSshCipherTest.java + **/OpenSshMlKemTest.java **/PortForwardingWithOpenSshTest.java **/StrictKexInteroperabilityTest.java diff --git a/sshd-netty/pom.xml b/sshd-netty/pom.xml index a0facc369..bab8a7d26 100644 --- a/sshd-netty/pom.xml +++ b/sshd-netty/pom.xml @@ -149,6 +149,7 @@ **/SessionReKeyHostKeyExchangeTest.java **/HostBoundPubKeyAuthTest.java **/OpenSshCipherTest.java + **/OpenSshMlKemTest.java **/PortForwardingWithOpenSshTest.java **/StrictKexInteroperabilityTest.java