diff --git a/src/impl/peerconnection.cpp b/src/impl/peerconnection.cpp index c305384aa..f073535c4 100644 --- a/src/impl/peerconnection.cpp +++ b/src/impl/peerconnection.cpp @@ -46,24 +46,22 @@ static LogCounter const string PemBeginCertificateTag = "-----BEGIN CERTIFICATE-----"; -PeerConnection::PeerConnection(Configuration config_) - : config(std::move(config_)) { +PeerConnection::PeerConnection(Configuration config_) : config(std::move(config_)) { PLOG_VERBOSE << "Creating PeerConnection"; - if (config.certificatePemFile && config.keyPemFile) { std::promise cert; cert.set_value(std::make_shared( - config.certificatePemFile->find(PemBeginCertificateTag) != string::npos - ? Certificate::FromString(*config.certificatePemFile, *config.keyPemFile) - : Certificate::FromFile(*config.certificatePemFile, *config.keyPemFile, - config.keyPemPass.value_or("")))); + config.certificatePemFile->find(PemBeginCertificateTag) != string::npos + ? Certificate::FromString(*config.certificatePemFile, *config.keyPemFile) + : Certificate::FromFile(*config.certificatePemFile, *config.keyPemFile, + config.keyPemPass.value_or("")))); mCertificate = cert.get_future(); } else if (!config.certificatePemFile && !config.keyPemFile) { mCertificate = make_certificate(config.certificateType); } else { throw std::invalid_argument( - "Either none or both certificate and key PEM files must be specified"); + "Either none or both certificate and key PEM files must be specified"); } if (config.portRangeEnd && config.portRangeBegin > config.portRangeEnd) @@ -229,13 +227,15 @@ shared_ptr PeerConnection::initDtlsTransport() { PLOG_VERBOSE << "Starting DTLS transport"; - auto fingerprintAlgorithm = CertificateFingerprint::Algorithm::Sha256; - if (auto remote = remoteDescription(); remote && remote->fingerprint()) { - fingerprintAlgorithm = remote->fingerprint()->algorithm; + CertificateFingerprint::Algorithm fingerprintAlgorithm; + { + std::lock_guard lock(mRemoteDescription); + if (mRemoteDescription && mRemoteDescription->fingerprint()) { + mRemoteFingerprintAlgorithm = mRemoteDescription->fingerprint()->algorithm; + } + fingerprintAlgorithm = mRemoteFingerprintAlgorithm; } - mRemoteFingerprintAlgorithm = fingerprintAlgorithm; - auto lower = std::atomic_load(&mIceTransport); if (!lower) throw std::logic_error("No underlying ICE transport for DTLS transport"); @@ -443,23 +443,24 @@ void PeerConnection::rollbackLocalDescription() { bool PeerConnection::checkFingerprint(const std::string &fingerprint) { std::lock_guard lock(mRemoteDescriptionMutex); - if (!mRemoteDescription || !mRemoteDescription->fingerprint()) + mRemoteFingerprint = fingerprint; + + if (!mRemoteDescription || !mRemoteDescription->fingerprint() || mRemoteFingerprintAlgorithm != mRemoteDescription->fingerprint()->algorithm) return false; - if (config.disableFingerprintVerification) { + if (config.disableFingerprintVerification) { PLOG_VERBOSE << "Skipping fingerprint validation"; - mRemoteFingerprint = fingerprint; return true; } auto expectedFingerprint = mRemoteDescription->fingerprint()->value; if (expectedFingerprint == fingerprint) { PLOG_VERBOSE << "Valid fingerprint \"" << fingerprint << "\""; - mRemoteFingerprint = fingerprint; return true; } - PLOG_ERROR << "Invalid fingerprint \"" << fingerprint << "\", expected \"" << expectedFingerprint << "\""; + PLOG_ERROR << "Invalid fingerprint \"" << fingerprint << "\", expected \"" + << expectedFingerprint << "\""; return false; } @@ -555,7 +556,7 @@ void PeerConnection::forwardMedia([[maybe_unused]] message_ptr message) { void PeerConnection::dispatchMedia([[maybe_unused]] message_ptr message) { #if RTC_ENABLE_MEDIA std::shared_lock lock(mTracksMutex); // read-only - if (mTrackLines.size()==1) { + if (mTrackLines.size() == 1) { if (auto track = mTrackLines.front().lock()) track->incoming(message); return; @@ -742,7 +743,7 @@ void PeerConnection::iterateDataChannels( { std::shared_lock lock(mDataChannelsMutex); // read-only locked.reserve(mDataChannels.size()); - for(auto it = mDataChannels.begin(); it != mDataChannels.end(); ++it) { + for (auto it = mDataChannels.begin(); it != mDataChannels.end(); ++it) { auto channel = it->second.lock(); if (channel && !channel->isClosed()) locked.push_back(std::move(channel)); @@ -811,7 +812,7 @@ void PeerConnection::iterateTracks(std::function track)> { std::shared_lock lock(mTracksMutex); // read-only locked.reserve(mTrackLines.size()); - for(auto it = mTrackLines.begin(); it != mTrackLines.end(); ++it) { + for (auto it = mTrackLines.begin(); it != mTrackLines.end(); ++it) { auto track = it->lock(); if (track && !track->isClosed()) locked.push_back(std::move(track)); @@ -1308,6 +1309,7 @@ void PeerConnection::resetCallbacks() { } CertificateFingerprint PeerConnection::remoteFingerprint() { + std::lock_guard lock(mRemoteDescriptionMutex); if (mRemoteFingerprint) return {CertificateFingerprint{mRemoteFingerprintAlgorithm, *mRemoteFingerprint}}; else diff --git a/src/impl/peerconnection.hpp b/src/impl/peerconnection.hpp index 37e07cbd7..a72f58575 100644 --- a/src/impl/peerconnection.hpp +++ b/src/impl/peerconnection.hpp @@ -98,6 +98,7 @@ struct PeerConnection : std::enable_shared_from_this { bool changeSignalingState(SignalingState newState); void resetCallbacks(); + CertificateFingerprint remoteFingerprint(); // Helper method for asynchronous callback invocation @@ -135,12 +136,16 @@ struct PeerConnection : std::enable_shared_from_this { future_certificate_ptr mCertificate; Processor mProcessor; - optional mLocalDescription, mRemoteDescription; + optional mLocalDescription; optional mCurrentLocalDescription; - mutable std::mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; + mutable std::mutex mLocalDescriptionMutex; - shared_ptr mMediaHandler; + optional mRemoteDescription; + CertificateFingerprint::Algorithm mRemoteFingerprintAlgorithm = CertificateFingerprint::Algorithm::Sha256; + optional mRemoteFingerprint; + mutable std::mutex mRemoteDescriptionMutex; + shared_ptr mMediaHandler; mutable std::shared_mutex mMediaHandlerMutex; shared_ptr mIceTransport; @@ -158,9 +163,6 @@ struct PeerConnection : std::enable_shared_from_this { Queue> mPendingDataChannels; Queue> mPendingTracks; - - CertificateFingerprint::Algorithm mRemoteFingerprintAlgorithm = CertificateFingerprint::Algorithm::Sha256; - optional mRemoteFingerprint; }; } // namespace rtc::impl