Skip to content

Commit

Permalink
Upgrade bouncycastle to v2 (#1022)
Browse files Browse the repository at this point in the history
* update bouncycastle 1.9.0 -> 2.2.1

* clean

* fixed DTLS alert handling
  • Loading branch information
camnewnham authored Dec 10, 2023
1 parent aab4fd9 commit 73d24e1
Show file tree
Hide file tree
Showing 11 changed files with 524 additions and 590 deletions.
2 changes: 1 addition & 1 deletion src/SIPSorcery.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Portable.BouncyCastle" Version="1.9.0" />
<PackageReference Include="BouncyCastle.Cryptography" Version="2.2.1" />
<PackageReference Include="DnsClient" Version="1.7.0" />
<PackageReference Include="SIPSorcery.WebSocketSharp" Version="0.0.1" />
<PackageReference Include="SIPSorceryMedia.Abstractions" Version="1.2.0" />
Expand Down
105 changes: 52 additions & 53 deletions src/net/DtlsSrtp/DtlsSrtpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
using System;
using System.Collections;
using Microsoft.Extensions.Logging;
using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Crypto.Tls;
using Org.BouncyCastle.Tls;
using Org.BouncyCastle.Security;
using Org.BouncyCastle.Utilities;
using SIPSorcery.Sys;
using System.Collections.Generic;
using Org.BouncyCastle.Crypto;
using Org.BouncyCastle.Tls.Crypto;

namespace SIPSorcery.Net
{
Expand All @@ -36,15 +38,15 @@ internal DtlsSrtpTlsAuthentication(DtlsSrtpClient client)
this.mContext = client.TlsContext;
}

public virtual void NotifyServerCertificate(Certificate serverCertificate)
public virtual void NotifyServerCertificate(TlsServerCertificate serverCertificate)
{
//Console.WriteLine("DTLS client received server certificate chain of length " + chain.Length);
mClient.ServerCertificate = serverCertificate;
}

public virtual TlsCredentials GetClientCredentials(CertificateRequest certificateRequest)
{
byte[] certificateTypes = certificateRequest.CertificateTypes;
short[] certificateTypes = certificateRequest.CertificateTypes;
if (certificateTypes == null || !Arrays.Contains(certificateTypes, ClientCertificateType.rsa_sign))
{
return null;
Expand All @@ -56,11 +58,6 @@ public virtual TlsCredentials GetClientCredentials(CertificateRequest certificat
mClient.mCertificateChain,
mClient.mPrivateKey);
}

public TlsCredentials GetClientCredentials(TlsContext context, CertificateRequest certificateRequest)
{
return GetClientCredentials(certificateRequest);
}
};

public class DtlsSrtpClient : DefaultTlsClient, IDtlsSrtpPeer
Expand All @@ -72,15 +69,15 @@ public class DtlsSrtpClient : DefaultTlsClient, IDtlsSrtpPeer

internal TlsClientContext TlsContext
{
get { return mContext; }
get { return m_context; }
}

protected internal TlsSession mSession;

public bool ForceUseExtendedMasterSecret { get; set; } = true;

//Received from server
public Certificate ServerCertificate { get; internal set; }
public TlsServerCertificate ServerCertificate { get; internal set; }

public RTCDtlsFingerprint Fingerprint { get; private set; }

Expand All @@ -105,36 +102,37 @@ internal TlsClientContext TlsContext
/// </summary>
public event Action<AlertLevelsEnum, AlertTypesEnum, string> OnAlert;

public DtlsSrtpClient() :
this(null, null, null)
public DtlsSrtpClient(TlsCrypto crypto) :
this(crypto, null, null, null)
{
}

public DtlsSrtpClient(System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) :
this(DtlsUtils.LoadCertificateChain(certificate), DtlsUtils.LoadPrivateKeyResource(certificate))
public DtlsSrtpClient(TlsCrypto crypto, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate) :
this(crypto, DtlsUtils.LoadCertificateChain(crypto, certificate), DtlsUtils.LoadPrivateKeyResource(certificate))
{
}

public DtlsSrtpClient(string certificatePath, string keyPath) :
this(new string[] { certificatePath }, keyPath)
public DtlsSrtpClient(TlsCrypto crypto, string certificatePath, string keyPath) :
this(crypto, new string[] { certificatePath }, keyPath)
{
}

public DtlsSrtpClient(string[] certificatesPath, string keyPath) :
this(DtlsUtils.LoadCertificateChain(certificatesPath), DtlsUtils.LoadPrivateKeyResource(keyPath))
public DtlsSrtpClient(TlsCrypto crypto, string[] certificatesPath, string keyPath) :
this(crypto, DtlsUtils.LoadCertificateChain(crypto, certificatesPath), DtlsUtils.LoadPrivateKeyResource(keyPath))
{
}

public DtlsSrtpClient(Certificate certificateChain, AsymmetricKeyParameter privateKey) :
this(certificateChain, privateKey, null)
public DtlsSrtpClient(TlsCrypto crypto, Certificate certificateChain, Org.BouncyCastle.Crypto.AsymmetricKeyParameter privateKey) :
this(crypto, certificateChain, privateKey, null)
{
}

public DtlsSrtpClient(Certificate certificateChain, AsymmetricKeyParameter privateKey, UseSrtpData clientSrtpData)
public DtlsSrtpClient(TlsCrypto crypto, Certificate certificateChain, Org.BouncyCastle.Crypto.AsymmetricKeyParameter privateKey, UseSrtpData clientSrtpData) : base(crypto)
{

if (certificateChain == null && privateKey == null)
{
(certificateChain, privateKey) = DtlsUtils.CreateSelfSignedTlsCert();
(certificateChain, privateKey) = DtlsUtils.CreateSelfSignedTlsCert(crypto);
}

if (clientSrtpData == null)
Expand All @@ -158,31 +156,33 @@ public DtlsSrtpClient(Certificate certificateChain, AsymmetricKeyParameter priva
Fingerprint = certificate != null ? DtlsUtils.Fingerprint(certificate) : null;
}

public DtlsSrtpClient(UseSrtpData clientSrtpData) : this(null, null, clientSrtpData)
public DtlsSrtpClient(TlsCrypto crypto, UseSrtpData clientSrtpData) : this(crypto, null, null, clientSrtpData)
{ }

public override IDictionary GetClientExtensions()

public override IDictionary<int, byte[]> GetClientExtensions()
{
var clientExtensions = base.GetClientExtensions();
if (TlsSRTPUtils.GetUseSrtpExtension(clientExtensions) == null)
if (TlsSrpUtilities.GetSrpExtension(clientExtensions) == null)
{
if (clientExtensions == null)
{
clientExtensions = new Hashtable();
clientExtensions = new Hashtable() as IDictionary<int, byte[]>;
}

TlsSRTPUtils.AddUseSrtpExtension(clientExtensions, clientSrtpData);
TlsSrtpUtilities.AddUseSrtpExtension(clientExtensions, clientSrtpData);
}
return clientExtensions;
}

public override void ProcessServerExtensions(IDictionary clientExtensions)

public override void ProcessServerExtensions(IDictionary<int, byte[]> serverExtensions)
{
base.ProcessServerExtensions(clientExtensions);
base.ProcessServerExtensions(serverExtensions);

// set to some reasonable default value
int chosenProfile = SrtpProtectionProfile.SRTP_AES128_CM_HMAC_SHA1_80;
UseSrtpData clientSrtpData = TlsSRTPUtils.GetUseSrtpExtension(clientExtensions);
clientSrtpData = TlsSrtpUtilities.GetUseSrtpExtension(serverExtensions);

foreach (int profile in clientSrtpData.ProtectionProfiles)
{
Expand Down Expand Up @@ -244,12 +244,12 @@ public override void NotifyHandshakeComplete()
{
base.NotifyHandshakeComplete();

//Copy master Secret (will be inaccessible after this call)
masterSecret = new byte[mContext.SecurityParameters.MasterSecret != null ? mContext.SecurityParameters.MasterSecret.Length : 0];
Buffer.BlockCopy(mContext.SecurityParameters.MasterSecret, 0, masterSecret, 0, masterSecret.Length);

//Prepare Srtp Keys (we must to it here because master key will be cleared after that)
PrepareSrtpSharedSecret();

//Copy master Secret (will be inaccessible after this call)
masterSecret = new byte[m_context.SecurityParameters.MasterSecret != null ? m_context.SecurityParameters.MasterSecret.Length : 0];
Buffer.BlockCopy(m_context.SecurityParameters.MasterSecret.Extract(), 0, masterSecret, 0, masterSecret.Length);
}

public bool IsClient()
Expand All @@ -269,7 +269,7 @@ protected virtual byte[] GetKeyingMaterial(string asciiLabel, byte[] context_val
throw new ArgumentException("must have length less than 2^16 (or be null)", "context_value");
}

SecurityParameters sp = mContext.SecurityParameters;
SecurityParameters sp = m_context.SecurityParameters;
if (!sp.IsExtendedMasterSecret && RequiresExtendedMasterSecret())
{
/*
Expand Down Expand Up @@ -309,7 +309,7 @@ protected virtual byte[] GetKeyingMaterial(string asciiLabel, byte[] context_val
throw new InvalidOperationException("error in calculation of seed for export");
}

return TlsUtilities.PRF(mContext, sp.MasterSecret, asciiLabel, seed, length);
return TlsUtilities.Prf(sp, sp.MasterSecret, asciiLabel, seed, length).Extract();
}

public override bool RequiresExtendedMasterSecret()
Expand Down Expand Up @@ -371,22 +371,12 @@ protected virtual void PrepareSrtpSharedSecret()
Buffer.BlockCopy(sharedSecret, (2 * keyLen + saltLen), srtpMasterServerSalt, 0, saltLen);
}

public override ProtocolVersion ClientVersion
{
get { return ProtocolVersion.DTLSv12; }
}

public override ProtocolVersion MinimumVersion
{
get { return ProtocolVersion.DTLSv10; }
}

public override TlsSession GetSessionToResume()
{
return this.mSession;
}

public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, string message, Exception cause)
public override void NotifyAlertRaised(short alertLevel, short alertDescription, string message, Exception cause)
{
string description = null;
if (message != null)
Expand All @@ -401,7 +391,7 @@ public override void NotifyAlertRaised(byte alertLevel, byte alertDescription, s
string alertMessage = $"{AlertLevel.GetText(alertLevel)}, {AlertDescription.GetText(alertDescription)}";
alertMessage += !string.IsNullOrEmpty(description) ? $", {description}." : ".";

if (alertDescription == AlertTypesEnum.close_notify.GetHashCode())
if (alertDescription == (byte)AlertTypesEnum.close_notify)
{
logger.LogDebug($"DTLS client raised close notification: {alertMessage}");
}
Expand All @@ -418,22 +408,31 @@ public override void NotifyServerVersion(ProtocolVersion serverVersion)

public Certificate GetRemoteCertificate()
{
return ServerCertificate;
return ServerCertificate.Certificate;
}

protected override ProtocolVersion[] GetSupportedVersions()
{
return new ProtocolVersion[]
{
ProtocolVersion.DTLSv10,
ProtocolVersion.DTLSv12
};
}

public override void NotifyAlertReceived(byte alertLevel, byte alertDescription)
public override void NotifyAlertReceived(short alertLevel, short alertDescription)
{
string description = AlertDescription.GetText(alertDescription);

AlertLevelsEnum level = AlertLevelsEnum.Warning;
AlertTypesEnum alertType = AlertTypesEnum.unknown;

if (Enum.IsDefined(typeof(AlertLevelsEnum), alertLevel))
if (Enum.IsDefined(typeof(AlertLevelsEnum), checked((byte)alertLevel)))
{
level = (AlertLevelsEnum)alertLevel;
}

if (Enum.IsDefined(typeof(AlertTypesEnum), alertDescription))
if (Enum.IsDefined(typeof(AlertTypesEnum), checked((byte)alertDescription)))
{
alertType = (AlertTypesEnum)alertDescription;
}
Expand Down
Loading

0 comments on commit 73d24e1

Please sign in to comment.