From 62741ba0728a0cbdbeaedb589a822b85e56a7c6c Mon Sep 17 00:00:00 2001 From: yngrtc Date: Mon, 18 Mar 2024 12:30:57 -0700 Subject: [PATCH] add init webrtc code to rtc --- rtc-shared/src/error.rs | 382 ++- rtc/CHANGELOG.md | 3 + rtc/Cargo.toml | 44 + rtc/src/api/api_test.rs | 25 + .../interceptor_registry_test.rs | 278 +++ rtc/src/api/interceptor_registry/mod.rs | 171 ++ rtc/src/api/media_engine/media_engine_test.rs | 780 ++++++ rtc/src/api/media_engine/mod.rs | 804 +++++++ rtc/src/api/mod.rs | 236 ++ rtc/src/api/setting_engine/mod.rs | 327 +++ .../api/setting_engine/setting_engine_test.rs | 271 +++ rtc/src/data_channel/data_channel_init.rs | 29 + rtc/src/data_channel/data_channel_message.rs | 11 + .../data_channel/data_channel_parameters.rs | 12 + rtc/src/data_channel/data_channel_state.rs | 113 + rtc/src/data_channel/data_channel_test.rs | 1504 ++++++++++++ rtc/src/data_channel/mod.rs | 555 +++++ rtc/src/dtls_transport/dtls_fingerprint.rs | 15 + rtc/src/dtls_transport/dtls_parameters.rs | 11 + rtc/src/dtls_transport/dtls_role.rs | 170 ++ .../dtls_transport/dtls_transport_state.rs | 117 + rtc/src/dtls_transport/dtls_transport_test.rs | 204 ++ rtc/src/dtls_transport/mod.rs | 616 +++++ rtc/src/ice_transport/ice_candidate.rs | 219 ++ rtc/src/ice_transport/ice_candidate_pair.rs | 34 + rtc/src/ice_transport/ice_candidate_type.rs | 119 + rtc/src/ice_transport/ice_connection_state.rs | 142 ++ rtc/src/ice_transport/ice_credential_type.rs | 74 + rtc/src/ice_transport/ice_gatherer.rs | 409 ++++ rtc/src/ice_transport/ice_gatherer_state.rs | 88 + rtc/src/ice_transport/ice_gathering_state.rs | 83 + rtc/src/ice_transport/ice_parameters.rs | 10 + rtc/src/ice_transport/ice_protocol.rs | 79 + rtc/src/ice_transport/ice_role.rs | 73 + rtc/src/ice_transport/ice_server.rs | 173 ++ rtc/src/ice_transport/ice_transport_state.rs | 185 ++ rtc/src/ice_transport/ice_transport_test.rs | 123 + rtc/src/ice_transport/mod.rs | 355 +++ rtc/src/lib.rs | 30 + rtc/src/mux/endpoint.rs | 73 + rtc/src/mux/mod.rs | 156 ++ rtc/src/mux/mux_func.rs | 63 + rtc/src/mux/mux_test.rs | 137 ++ rtc/src/peer_connection/certificate.rs | 295 +++ rtc/src/peer_connection/configuration.rs | 147 ++ rtc/src/peer_connection/mod.rs | 2108 +++++++++++++++++ .../peer_connection/offer_answer_options.rs | 22 + rtc/src/peer_connection/operation/mod.rs | 139 ++ .../operation/operation_test.rs | 47 + .../peer_connection_internal.rs | 1492 ++++++++++++ .../peer_connection/peer_connection_state.rs | 151 ++ .../peer_connection/peer_connection_test.rs | 423 ++++ .../peer_connection/policy/bundle_policy.rs | 94 + .../policy/ice_transport_policy.rs | 77 + rtc/src/peer_connection/policy/mod.rs | 4 + .../peer_connection/policy/rtcp_mux_policy.rs | 79 + .../peer_connection/policy/sdp_semantics.rs | 112 + rtc/src/peer_connection/sdp/mod.rs | 1096 +++++++++ rtc/src/peer_connection/sdp/sdp_test.rs | 1150 +++++++++ rtc/src/peer_connection/sdp/sdp_type.rs | 101 + .../sdp/session_description.rs | 238 ++ rtc/src/peer_connection/signaling_state.rs | 365 +++ .../fmtp/generic/generic_test.rs | 160 ++ rtc/src/rtp_transceiver/fmtp/generic/mod.rs | 65 + .../rtp_transceiver/fmtp/h264/h264_test.rs | 163 ++ rtc/src/rtp_transceiver/fmtp/h264/mod.rs | 102 + rtc/src/rtp_transceiver/fmtp/mod.rs | 58 + rtc/src/rtp_transceiver/mod.rs | 560 +++++ rtc/src/rtp_transceiver/rtp_codec.rs | 163 ++ rtc/src/rtp_transceiver/rtp_receiver/mod.rs | 861 +++++++ .../rtp_receiver/rtp_receiver_test.rs | 231 ++ rtc/src/rtp_transceiver/rtp_sender/mod.rs | 544 +++++ .../rtp_sender/rtp_sender_test.rs | 336 +++ .../rtp_transceiver_direction.rs | 210 ++ .../rtp_transceiver/rtp_transceiver_test.rs | 356 +++ rtc/src/rtp_transceiver/srtp_writer_future.rs | 289 +++ rtc/src/sctp_transport/mod.rs | 441 ++++ .../sctp_transport_capabilities.rs | 7 + .../sctp_transport/sctp_transport_state.rs | 99 + rtc/src/sctp_transport/sctp_transport_test.rs | 43 + rtc/src/stats/mod.rs | 586 +++++ rtc/src/stats/serialize.rs | 28 + rtc/src/stats/stats_collector.rs | 32 + rtc/src/track/mod.rs | 29 + rtc/src/track/track_local/mod.rs | 163 ++ .../track_local/track_local_static_rtp.rs | 238 ++ .../track_local/track_local_static_sample.rs | 299 +++ .../track_local/track_local_static_test.rs | 434 ++++ rtc/src/track/track_remote/mod.rs | 320 +++ 89 files changed, 24245 insertions(+), 15 deletions(-) create mode 100644 rtc/CHANGELOG.md create mode 100644 rtc/src/api/api_test.rs create mode 100644 rtc/src/api/interceptor_registry/interceptor_registry_test.rs create mode 100644 rtc/src/api/interceptor_registry/mod.rs create mode 100644 rtc/src/api/media_engine/media_engine_test.rs create mode 100644 rtc/src/api/media_engine/mod.rs create mode 100644 rtc/src/api/mod.rs create mode 100644 rtc/src/api/setting_engine/mod.rs create mode 100644 rtc/src/api/setting_engine/setting_engine_test.rs create mode 100644 rtc/src/data_channel/data_channel_init.rs create mode 100644 rtc/src/data_channel/data_channel_message.rs create mode 100644 rtc/src/data_channel/data_channel_parameters.rs create mode 100644 rtc/src/data_channel/data_channel_state.rs create mode 100644 rtc/src/data_channel/data_channel_test.rs create mode 100644 rtc/src/data_channel/mod.rs create mode 100644 rtc/src/dtls_transport/dtls_fingerprint.rs create mode 100644 rtc/src/dtls_transport/dtls_parameters.rs create mode 100644 rtc/src/dtls_transport/dtls_role.rs create mode 100644 rtc/src/dtls_transport/dtls_transport_state.rs create mode 100644 rtc/src/dtls_transport/dtls_transport_test.rs create mode 100644 rtc/src/dtls_transport/mod.rs create mode 100644 rtc/src/ice_transport/ice_candidate.rs create mode 100644 rtc/src/ice_transport/ice_candidate_pair.rs create mode 100644 rtc/src/ice_transport/ice_candidate_type.rs create mode 100644 rtc/src/ice_transport/ice_connection_state.rs create mode 100644 rtc/src/ice_transport/ice_credential_type.rs create mode 100644 rtc/src/ice_transport/ice_gatherer.rs create mode 100644 rtc/src/ice_transport/ice_gatherer_state.rs create mode 100644 rtc/src/ice_transport/ice_gathering_state.rs create mode 100644 rtc/src/ice_transport/ice_parameters.rs create mode 100644 rtc/src/ice_transport/ice_protocol.rs create mode 100644 rtc/src/ice_transport/ice_role.rs create mode 100644 rtc/src/ice_transport/ice_server.rs create mode 100644 rtc/src/ice_transport/ice_transport_state.rs create mode 100644 rtc/src/ice_transport/ice_transport_test.rs create mode 100644 rtc/src/ice_transport/mod.rs create mode 100644 rtc/src/mux/endpoint.rs create mode 100644 rtc/src/mux/mod.rs create mode 100644 rtc/src/mux/mux_func.rs create mode 100644 rtc/src/mux/mux_test.rs create mode 100644 rtc/src/peer_connection/certificate.rs create mode 100644 rtc/src/peer_connection/configuration.rs create mode 100644 rtc/src/peer_connection/mod.rs create mode 100644 rtc/src/peer_connection/offer_answer_options.rs create mode 100644 rtc/src/peer_connection/operation/mod.rs create mode 100644 rtc/src/peer_connection/operation/operation_test.rs create mode 100644 rtc/src/peer_connection/peer_connection_internal.rs create mode 100644 rtc/src/peer_connection/peer_connection_state.rs create mode 100644 rtc/src/peer_connection/peer_connection_test.rs create mode 100644 rtc/src/peer_connection/policy/bundle_policy.rs create mode 100644 rtc/src/peer_connection/policy/ice_transport_policy.rs create mode 100644 rtc/src/peer_connection/policy/mod.rs create mode 100644 rtc/src/peer_connection/policy/rtcp_mux_policy.rs create mode 100644 rtc/src/peer_connection/policy/sdp_semantics.rs create mode 100644 rtc/src/peer_connection/sdp/mod.rs create mode 100644 rtc/src/peer_connection/sdp/sdp_test.rs create mode 100644 rtc/src/peer_connection/sdp/sdp_type.rs create mode 100644 rtc/src/peer_connection/sdp/session_description.rs create mode 100644 rtc/src/peer_connection/signaling_state.rs create mode 100644 rtc/src/rtp_transceiver/fmtp/generic/generic_test.rs create mode 100644 rtc/src/rtp_transceiver/fmtp/generic/mod.rs create mode 100644 rtc/src/rtp_transceiver/fmtp/h264/h264_test.rs create mode 100644 rtc/src/rtp_transceiver/fmtp/h264/mod.rs create mode 100644 rtc/src/rtp_transceiver/fmtp/mod.rs create mode 100644 rtc/src/rtp_transceiver/mod.rs create mode 100644 rtc/src/rtp_transceiver/rtp_codec.rs create mode 100644 rtc/src/rtp_transceiver/rtp_receiver/mod.rs create mode 100644 rtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs create mode 100644 rtc/src/rtp_transceiver/rtp_sender/mod.rs create mode 100644 rtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs create mode 100644 rtc/src/rtp_transceiver/rtp_transceiver_direction.rs create mode 100644 rtc/src/rtp_transceiver/rtp_transceiver_test.rs create mode 100644 rtc/src/rtp_transceiver/srtp_writer_future.rs create mode 100644 rtc/src/sctp_transport/mod.rs create mode 100644 rtc/src/sctp_transport/sctp_transport_capabilities.rs create mode 100644 rtc/src/sctp_transport/sctp_transport_state.rs create mode 100644 rtc/src/sctp_transport/sctp_transport_test.rs create mode 100644 rtc/src/stats/mod.rs create mode 100644 rtc/src/stats/serialize.rs create mode 100644 rtc/src/stats/stats_collector.rs create mode 100644 rtc/src/track/mod.rs create mode 100644 rtc/src/track/track_local/mod.rs create mode 100644 rtc/src/track/track_local/track_local_static_rtp.rs create mode 100644 rtc/src/track/track_local/track_local_static_sample.rs create mode 100644 rtc/src/track/track_local/track_local_static_test.rs create mode 100644 rtc/src/track/track_remote/mod.rs diff --git a/rtc-shared/src/error.rs b/rtc-shared/src/error.rs index 619946c..9bb5cce 100644 --- a/rtc-shared/src/error.rs +++ b/rtc-shared/src/error.rs @@ -850,17 +850,6 @@ pub enum Error { #[error("Client transport is not set yet")] ErrClientTransportNotSet, - #[error("{0}")] - Sec1(#[source] sec1::Error), - #[error("{0}")] - P256(#[source] P256Error), - #[error("{0}")] - RcGen(#[from] rcgen::Error), - - /// Error parsing a given PEM string. - #[error("invalid PEM: {0}")] - InvalidPEM(String), - /// The endpoint can no longer create new connections /// /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled. @@ -1117,13 +1106,366 @@ pub enum Error { #[error("Unknow Protocol")] UnknownProtocol, + //RTC + /// ErrConnectionClosed indicates an operation executed after connection + /// has already been closed. + #[error("connection closed")] + ErrConnectionClosed, + + /// ErrDataChannelNotOpen indicates an operation executed when the data + /// channel is not (yet) open. + #[error("data channel not open")] + ErrDataChannelNotOpen, + + /// ErrCertificateExpired indicates that an x509 certificate has expired. + #[error("x509Cert expired")] + ErrCertificateExpired, + + /// ErrNoTurnCredentials indicates that a TURN server URL was provided + /// without required credentials. + #[error("turn server credentials required")] + ErrNoTurnCredentials, + + /// ErrTurnCredentials indicates that provided TURN credentials are partial + /// or malformed. + #[error("invalid turn server credentials")] + ErrTurnCredentials, + + /// ErrExistingTrack indicates that a track already exists. + #[error("track already exists")] + ErrExistingTrack, + + /// ErrPrivateKeyType indicates that a particular private key encryption + /// chosen to generate a certificate is not supported. + #[error("private key type not supported")] + ErrPrivateKeyType, + + /// ErrModifyingPeerIdentity indicates that an attempt to modify + /// PeerIdentity was made after PeerConnection has been initialized. + #[error("peerIdentity cannot be modified")] + ErrModifyingPeerIdentity, + + /// ErrModifyingCertificates indicates that an attempt to modify + /// Certificates was made after PeerConnection has been initialized. + #[error("certificates cannot be modified")] + ErrModifyingCertificates, + + /// ErrNonCertificate indicates that there is no certificate + #[error("no certificate")] + ErrNonCertificate, + + /// ErrModifyingBundlePolicy indicates that an attempt to modify + /// BundlePolicy was made after PeerConnection has been initialized. + #[error("bundle policy cannot be modified")] + ErrModifyingBundlePolicy, + + /// ErrModifyingRTCPMuxPolicy indicates that an attempt to modify + /// RTCPMuxPolicy was made after PeerConnection has been initialized. + #[error("rtcp mux policy cannot be modified")] + ErrModifyingRTCPMuxPolicy, + + /// ErrModifyingICECandidatePoolSize indicates that an attempt to modify + /// ICECandidatePoolSize was made after PeerConnection has been initialized. + #[error("ice candidate pool size cannot be modified")] + ErrModifyingICECandidatePoolSize, + + /// ErrStringSizeLimit indicates that the character size limit of string is + /// exceeded. The limit is hardcoded to 65535 according to specifications. + #[error("data channel label exceeds size limit")] + ErrStringSizeLimit, + + /// ErrNegotiatedWithoutID indicates that an attempt to create a data channel + /// was made while setting the negotiated option to true without providing + /// the negotiated channel ID. + #[error("negotiated set without channel id")] + ErrNegotiatedWithoutID, + + /// ErrRetransmitsOrPacketLifeTime indicates that an attempt to create a data + /// channel was made with both options max_packet_life_time and max_retransmits + /// set together. Such configuration is not supported by the specification + /// and is mutually exclusive. + #[error("both max_packet_life_time and max_retransmits was set")] + ErrRetransmitsOrPacketLifeTime, + + /// ErrCodecNotFound is returned when a codec search to the Media Engine fails + #[error("codec not found")] + ErrCodecNotFound, + + /// ErrNoRemoteDescription indicates that an operation was rejected because + /// the remote description is not set + #[error("remote description is not set")] + ErrNoRemoteDescription, + + /// ErrIncorrectSDPSemantics indicates that the PeerConnection was configured to + /// generate SDP Answers with different SDP Semantics than the received Offer + #[error("offer SDP semantics does not match configuration")] + ErrIncorrectSDPSemantics, + + /// ErrIncorrectSignalingState indicates that the signaling state of PeerConnection is not correct + #[error("operation can not be run in current signaling state")] + ErrIncorrectSignalingState, + + /// ErrProtocolTooLarge indicates that value given for a DataChannelInit protocol is + /// longer then 65535 bytes + #[error("protocol is larger then 65535 bytes")] + ErrProtocolTooLarge, + + /// ErrSenderNotCreatedByConnection indicates remove_track was called with a RtpSender not created + /// by this PeerConnection + #[error("RtpSender not created by this PeerConnection")] + ErrSenderNotCreatedByConnection, + + /// ErrSenderInitialTrackIdAlreadySet indicates a second call to + /// RtpSender::set_initial_track_id which is not allowed. + #[error("RtpSender's initial_track_id has already been set")] + ErrSenderInitialTrackIdAlreadySet, + + /// ErrSessionDescriptionNoFingerprint indicates set_remote_description was called with a SessionDescription that has no + /// fingerprint + #[error("set_remote_description called with no fingerprint")] + ErrSessionDescriptionNoFingerprint, + + /// ErrSessionDescriptionInvalidFingerprint indicates set_remote_description was called with a SessionDescription that + /// has an invalid fingerprint + #[error("set_remote_description called with an invalid fingerprint")] + ErrSessionDescriptionInvalidFingerprint, + + /// ErrSessionDescriptionConflictingFingerprints indicates set_remote_description was called with a SessionDescription that + /// has an conflicting fingerprints + #[error("set_remote_description called with multiple conflicting fingerprint")] + ErrSessionDescriptionConflictingFingerprints, + + /// ErrSessionDescriptionMissingIceUfrag indicates set_remote_description was called with a SessionDescription that + /// is missing an ice-ufrag value + #[error("set_remote_description called with no ice-ufrag")] + ErrSessionDescriptionMissingIceUfrag, + + /// ErrSessionDescriptionMissingIcePwd indicates set_remote_description was called with a SessionDescription that + /// is missing an ice-pwd value + #[error("set_remote_description called with no ice-pwd")] + ErrSessionDescriptionMissingIcePwd, + + /// ErrSessionDescriptionConflictingIceUfrag indicates set_remote_description was called with a SessionDescription that + /// contains multiple conflicting ice-ufrag values + #[error("set_remote_description called with multiple conflicting ice-ufrag values")] + ErrSessionDescriptionConflictingIceUfrag, + + /// ErrSessionDescriptionConflictingIcePwd indicates set_remote_description was called with a SessionDescription that + /// contains multiple conflicting ice-pwd values + #[error("set_remote_description called with multiple conflicting ice-pwd values")] + ErrSessionDescriptionConflictingIcePwd, + + /// ErrNoSRTPProtectionProfile indicates that the DTLS handshake completed and no SRTP Protection Profile was chosen + #[error("DTLS Handshake completed and no SRTP Protection Profile was chosen")] + ErrNoSRTPProtectionProfile, + + /// ErrFailedToGenerateCertificateFingerprint indicates that we failed to generate the fingerprint used for comparing certificates + #[error("failed to generate certificate fingerprint")] + ErrFailedToGenerateCertificateFingerprint, + + /// ErrNoCodecsAvailable indicates that operation isn't possible because the MediaEngine has no codecs available + #[error("operation failed no codecs are available")] + ErrNoCodecsAvailable, + + /// ErrUnsupportedCodec indicates the remote peer doesn't support the requested codec + #[error("unable to start track, codec is not supported by remote")] + ErrUnsupportedCodec, + + /// ErrSenderWithNoCodecs indicates that a RTPSender was created without any codecs. To send media the MediaEngine needs at + /// least one configured codec. + #[error("unable to populate media section, RTPSender created with no codecs")] + ErrSenderWithNoCodecs, + + /// ErrRTPSenderNewTrackHasIncorrectKind indicates that the new track is of a different kind than the previous/original + #[error("new track must be of the same kind as previous")] + ErrRTPSenderNewTrackHasIncorrectKind, + + /// ErrRTPSenderDataSent indicates that the sequence number transformer tries to be enabled after the data sending began + #[error("Sequence number transformer must be enabled before sending data")] + ErrRTPSenderDataSent, + + /// ErrRTPSenderSeqTransEnabled indicates that the sequence number transformer has been already enabled + #[error("Sequence number transformer has been already enabled")] + ErrRTPSenderSeqTransEnabled, + + /// ErrUnbindFailed indicates that a TrackLocal was not able to be unbind + #[error("failed to unbind TrackLocal from PeerConnection")] + ErrUnbindFailed, + + /// ErrNoPayloaderForCodec indicates that the requested codec does not have a payloader + #[error("the requested codec does not have a payloader")] + ErrNoPayloaderForCodec, + + /// ErrRegisterHeaderExtensionInvalidDirection indicates that a extension was registered with different + /// directions for two different calls. + #[error("a header extension must be registered with the same direction each time")] + ErrRegisterHeaderExtensionInvalidDirection, + + /// ErrRegisterHeaderExtensionNoFreeID indicates that there was no extension ID available which + /// in turn means that all 15 available id(1 through 14) have been used. + #[error("no header extension ID was free to use(this means the maximum of 15 extensions have been registered)")] + ErrRegisterHeaderExtensionNoFreeID, + + /// ErrSimulcastProbeOverflow indicates that too many Simulcast probe streams are in flight and the requested SSRC was ignored + #[error("simulcast probe limit has been reached, new SSRC has been discarded")] + ErrSimulcastProbeOverflow, + + #[error("enable detaching by calling webrtc.DetachDataChannels()")] + ErrDetachNotEnabled, + #[error("datachannel not opened yet, try calling Detach from OnOpen")] + ErrDetachBeforeOpened, + #[error("the DTLS transport has not started yet")] + ErrDtlsTransportNotStarted, + #[error("failed extracting keys from DTLS for SRTP")] + ErrDtlsKeyExtractionFailed, + #[error("failed to start SRTP")] + ErrFailedToStartSRTP, + #[error("failed to start SRTCP")] + ErrFailedToStartSRTCP, + #[error("attempted to start DTLSTransport that is not in new state")] + ErrInvalidDTLSStart, + #[error("peer didn't provide certificate via DTLS")] + ErrNoRemoteCertificate, + #[error("identity provider is not implemented")] + ErrIdentityProviderNotImplemented, + #[error("remote certificate does not match any fingerprint")] + ErrNoMatchingCertificateFingerprint, + #[error("unsupported fingerprint algorithm")] + ErrUnsupportedFingerprintAlgorithm, + #[error("ICE connection not started")] + ErrICEConnectionNotStarted, + #[error("unknown candidate type")] + ErrICECandidateTypeUnknown, + #[error("cannot convert ice.CandidateType into webrtc.ICECandidateType, invalid type")] + ErrICEInvalidConvertCandidateType, + #[error("ICEAgent does not exist")] + ErrICEAgentNotExist, + #[error("unable to convert ICE candidates to ICECandidates")] + ErrICECandidatesConversionFailed, + #[error("unknown ICE Role")] + ErrICERoleUnknown, + #[error("unknown protocol")] + ErrICEProtocolUnknown, + #[error("gatherer not started")] + ErrICEGathererNotStarted, + #[error("unknown network type")] + ErrNetworkTypeUnknown, + #[error("new sdp does not match previous offer")] + ErrSDPDoesNotMatchOffer, + #[error("new sdp does not match previous answer")] + ErrSDPDoesNotMatchAnswer, + #[error("provided value is not a valid enum value of type SDPType")] + ErrPeerConnSDPTypeInvalidValue, + #[error("invalid state change op")] + ErrPeerConnStateChangeInvalid, + #[error("unhandled state change op")] + ErrPeerConnStateChangeUnhandled, + #[error("invalid SDP type supplied to SetLocalDescription()")] + ErrPeerConnSDPTypeInvalidValueSetLocalDescription, + #[error("remoteDescription contained media section without mid value")] + ErrPeerConnRemoteDescriptionWithoutMidValue, + #[error("remoteDescription has not been set yet")] + ErrPeerConnRemoteDescriptionNil, + #[error("single media section has an explicit SSRC")] + ErrPeerConnSingleMediaSectionHasExplicitSSRC, + #[error("could not add transceiver for remote SSRC")] + ErrPeerConnRemoteSSRCAddTransceiver, + #[error("mid RTP Extensions required for Simulcast")] + ErrPeerConnSimulcastMidRTPExtensionRequired, + #[error("stream id RTP Extensions required for Simulcast")] + ErrPeerConnSimulcastStreamIDRTPExtensionRequired, + #[error("incoming SSRC failed Simulcast probing")] + ErrPeerConnSimulcastIncomingSSRCFailed, + #[error("failed collecting stats")] + ErrPeerConnStatsCollectionFailed, + #[error("add_transceiver_from_kind only accepts one RTPTransceiverInit")] + ErrPeerConnAddTransceiverFromKindOnlyAcceptsOne, + #[error("add_transceiver_from_track only accepts one RTPTransceiverInit")] + ErrPeerConnAddTransceiverFromTrackOnlyAcceptsOne, + #[error("add_transceiver_from_kind currently only supports recvonly")] + ErrPeerConnAddTransceiverFromKindSupport, + #[error("add_transceiver_from_track currently only supports sendonly and sendrecv")] + ErrPeerConnAddTransceiverFromTrackSupport, + #[error("TODO set_identity_provider")] + ErrPeerConnSetIdentityProviderNotImplemented, + #[error("write_rtcp failed to open write_stream")] + ErrPeerConnWriteRTCPOpenWriteStream, + #[error("cannot find transceiver with mid")] + ErrPeerConnTransceiverMidNil, + #[error("DTLSTransport must not be nil")] + ErrRTPReceiverDTLSTransportNil, + #[error("Receive has already been called")] + ErrRTPReceiverReceiveAlreadyCalled, + #[error("unable to find stream for Track with SSRC")] + ErrRTPReceiverWithSSRCTrackStreamNotFound, + #[error("no trackStreams found for SSRC")] + ErrRTPReceiverForSSRCTrackStreamNotFound, + #[error("no trackStreams found for RID")] + ErrRTPReceiverForRIDTrackStreamNotFound, + #[error("invalid RTP Receiver transition")] + ErrRTPReceiverStateChangeInvalid, + #[error("Track must not be nil")] + ErrRTPSenderTrackNil, + #[error("RTPSender must not be nil")] + ErrRTPSenderNil, + #[error("RTPReceiver must not be nil")] + ErrRTPReceiverNil, + #[error("DTLSTransport must not be nil")] + ErrRTPSenderDTLSTransportNil, + #[error("Send has already been called")] + ErrRTPSenderSendAlreadyCalled, + #[error("errRTPSenderTrackNil")] + ErrRTPTransceiverCannotChangeMid, + #[error("invalid state change in RTPTransceiver.setSending")] + ErrRTPTransceiverSetSendingInvalidState, + #[error("unsupported codec type by this transceiver")] + ErrRTPTransceiverCodecUnsupported, + #[error("DTLS not established")] + ErrSCTPTransportDTLS, + #[error("add_transceiver_sdp() called with 0 transceivers")] + ErrSDPZeroTransceivers, + #[error("invalid Media Section. Media + DataChannel both enabled")] + ErrSDPMediaSectionMediaDataChanInvalid, + #[error( + "invalid Media Section. Can not have multiple tracks in one MediaSection in UnifiedPlan" + )] + ErrSDPMediaSectionMultipleTrackInvalid, + #[error("set_answering_dtlsrole must DTLSRoleClient or DTLSRoleServer")] + ErrSettingEngineSetAnsweringDTLSRole, + #[error("can't rollback from stable state")] + ErrSignalingStateCannotRollback, + #[error("invalid proposed signaling state transition")] + ErrSignalingStateProposedTransitionInvalid, + #[error("cannot convert to StatsICECandidatePairStateSucceeded invalid ice candidate state")] + ErrStatsICECandidateStateInvalid, + #[error("ICETransport can only be called in ICETransportStateNew")] + ErrICETransportNotInNew, + #[error("bad Certificate PEM format")] + ErrCertificatePEMFormatError, + #[error("SCTP is not established")] + ErrSCTPNotEstablished, + + #[error("DataChannel is not opened")] + ErrClosedPipe, + #[error("Interceptor is not bind")] + ErrInterceptorNotBind, + #[error("excessive retries in CreateOffer")] + ErrExcessiveRetries, + + #[error("not long enough to be a RTP Packet")] + ErrRTPTooShort, + //Third Party Error - //#[error("mpsc send: {0}")] - //MpscSend(String), + #[error("{0}")] + Sec1(#[source] sec1::Error), + #[error("{0}")] + P256(#[source] P256Error), + #[error("{0}")] + RcGen(#[from] rcgen::Error), + #[error("invalid PEM: {0}")] + InvalidPEM(String), #[error("aes gcm: {0}")] AesGcm(#[from] aes_gcm::Error), - //#[error("parse ipnet: {0}")] - //ParseIpnet(#[from] ipnet::AddrParseError), #[error("parse ip: {0}")] ParseIp(#[from] net::AddrParseError), #[error("parse int: {0}")] @@ -1257,3 +1599,13 @@ impl From for Error { Error::Other(e.to_string()) } } + +/// flatten_errs flattens multiple errors into one +pub fn flatten_errs(errs: Vec>) -> Result<()> { + if errs.is_empty() { + Ok(()) + } else { + let errs_strs: Vec = errs.into_iter().map(|e| e.into().to_string()).collect(); + Err(Error::Other(errs_strs.join("\n"))) + } +} diff --git a/rtc/CHANGELOG.md b/rtc/CHANGELOG.md new file mode 100644 index 0000000..dfbb2ba --- /dev/null +++ b/rtc/CHANGELOG.md @@ -0,0 +1,3 @@ +# rtc changelog + +## Unreleased diff --git a/rtc/Cargo.toml b/rtc/Cargo.toml index 8b95d32..3da9cae 100644 --- a/rtc/Cargo.toml +++ b/rtc/Cargo.toml @@ -5,7 +5,51 @@ authors = ["Rain Liu "] edition = "2021" description = "RTC in Rust" license = "MIT/Apache-2.0" +documentation = "https://docs.rs/rtc" +homepage = "https://webrtc.rs" +repository = "https://github.com/webrtc-rs/rtc" +readme = "../README.md" [dependencies] +datachannel = { version = "0.2.0", path = "../rtc-datachannel", package = "rtc-datachannel" } +dtls = { version = "0.2.0", path = "../rtc-dtls", package = "rtc-dtls" } +ice = { version = "0.2.0", path = "../rtc-ice", package = "rtc-ice" } +#TODO: interceptor = { version = "0.0.0", path = "../rtc-interceptor", package = "rtc-interceptor" } +#TODO: mdns = { version = "0.0.0", path = "../mdns", package = "rtc-mdns" } +#TODO: media = { version = "0.0.0", path = "../media", package = "rtc-media" } +rtcp = { version = "0.2.0", path = "../rtc-rtcp", package = "rtc-rtcp" } +rtp = { version = "0.2.0", path = "../rtc-rtp", package = "rtc-rtp" } +sctp = { version = "0.2.0", path = "../rtc-sctp", package = "rtc-sctp" } +sdp = { version = "0.2.0", path = "../rtc-sdp", package = "rtc-sdp" } +shared = { version = "0.2.1", path = "../rtc-shared", package = "rtc-shared", default-features = false, features = ["crypto", "marshal", "replay"] } +srtp = { version = "0.2.0", path = "../rtc-srtp", package = "rtc-srtp" } +stun = { version = "0.2.0", path = "../rtc-stun", package = "rtc-stun" } +turn = { version = "0.1.0", path = "../rtc-turn", package = "rtc-turn" } + +log = "0.4.21" +serde = { version = "1.0.197", features = ["derive"] } +serde_json = "1.0.114" +rand = "0.8.5" +bytes = "1.5.0" +regex = "1.10.3" +url = "2.5.0" +rustls = { version = "0.21", features = ["dangerous_configuration"] } +rcgen = { version = "0.11", features = ["pem", "x509-parser"] } +ring = "0.17.8" +sha2 = "0.10.8" +hex = "0.4.3" +pem = { version = "3.0.3", optional = true } +time = "0.3.34" +cfg-if = "1.0.0" +lazy_static = "1.4.0" +smol_str = { version = "0.2.1", features = ["serde"] } [dev-dependencies] +tokio-test = "0.4.4" +env_logger = "0.10.2" +arc-swap = "1.7.0" +tokio = { version = "1.36.0", features = ["full"] } +waitgroup = "0.1.2" + +[features] +pem = ["dep:pem", "dtls/pem"] diff --git a/rtc/src/api/api_test.rs b/rtc/src/api/api_test.rs new file mode 100644 index 0000000..c6b1aa5 --- /dev/null +++ b/rtc/src/api/api_test.rs @@ -0,0 +1,25 @@ +use super::*; + +#[test] +fn test_new_api() -> Result<()> { + let mut s = SettingEngine::default(); + s.detach_data_channels(); + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + + let api = APIBuilder::new() + .with_setting_engine(s) + .with_media_engine(m) + .build(); + + assert!( + api.setting_engine.detach.data_channels, + "Failed to set settings engine" + ); + assert!( + !api.media_engine.audio_codecs.is_empty(), + "Failed to set media engine" + ); + + Ok(()) +} diff --git a/rtc/src/api/interceptor_registry/interceptor_registry_test.rs b/rtc/src/api/interceptor_registry/interceptor_registry_test.rs new file mode 100644 index 0000000..277bce5 --- /dev/null +++ b/rtc/src/api/interceptor_registry/interceptor_registry_test.rs @@ -0,0 +1,278 @@ +/*TODO: +use super::*; +use crate::api::APIBuilder; +use crate::peer_connection::configuration::RTCConfiguration; + +use bytes::Bytes; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use interceptor::mock::mock_builder::MockBuilder; +use interceptor::mock::mock_interceptor::MockInterceptor; +use interceptor::stream_info::StreamInfo; +use interceptor::{Attributes, Interceptor, RTPWriter, RTPWriterFn}; + +// E2E test of the features of Interceptors +// * Assert an extension can be set on an outbound packet +// * Assert an extension can be read on an outbound packet +// * Assert that attributes set by an interceptor are returned to the Reader +#[tokio::test] +async fn test_peer_connection_interceptor() -> Result<()> { + let create_pc = || async { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + + let mut ir = Registry::new(); + + let BindLocalStreamFn = |info: &StreamInfo, + writer: Arc| + -> Pin< + Box> + Send + Sync>, + > { + let writer2 = Arc::clone(&writer); + Box::pin(async move { + Arc::new(RTPWriterFn(Box::new( + move |in_pkt: &rtp::packet::Packet, + attributes: &Attributes| + -> Pin< + Box< + dyn Future> + + Send + + Sync, + >, + > { + let writer3 = Arc::clone(&writer2); + let a = attributes.clone(); + // set extension on outgoing packet + let mut out_pkt = in_pkt.clone(); + out_pkt.header.extension = true; + out_pkt.header.extension_profile = 0xBEDE; + + Box::pin(async move { + out_pkt + .header + .set_extension(2, Bytes::from_static(b"foo"))?; + //writer3.write(&out_pkt, &a).await + Ok(0) + }) + }, + ))) as Arc + }) + }; + + BindRemoteStreamFn: func(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + if a == nil { + a = interceptor.Attributes{} + } + + a.Set("attribute", "value") + return reader.Read(b, a) + }) + }, + let mock_builder = Box::new(MockBuilder { + build: + Box::new( + |_: &str| -> std::result::Result< + Arc, + interceptor::Error, + > { + Ok(Arc::new(MockInterceptor { + ..Default::default() + })) + }, + ), + }); + let mock_builder = MockBuilder::new( + |_: &str| -> std::result::Result< + Arc, + interceptor::Error, + > { + Ok(Arc::new(MockInterceptor { + ..Default::default() + })) + }, + ); + ir.add(Box::new(mock_builder)); + + let api = APIBuilder::new() + .with_media_engine(m) + .with_interceptor_registry(ir) + .build(); + api.new_peer_connection(RTCConfiguration::default()).await + }; + + let offerer = create_pc().await?; + let answerer = create_pc().await?; + + track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") + assert.NoError(t, err) + + _, err = offerer.AddTrack(track) + assert.NoError(t, err) + + seenRTP, seenRTPCancel := context.WithCancel(context.Background()) + answerer.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { + p, attributes, readErr := track.ReadRTP() + assert.NoError(t, readErr) + + assert.Equal(t, p.Extension, true) + assert.Equal(t, "foo", string(p.GetExtension(2))) + assert.Equal(t, "value", attributes.Get("attribute")) + + seenRTPCancel() + }) + + assert.NoError(t, signalPair(offerer, answerer)) + + func() { + ticker := time.NewTicker(time.Millisecond * 20) + for { + select { + case <-seenRTP.Done(): + return + case <-ticker.C: + assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) + } + } + }() + + closePairNow(t, offerer, answerer) + + Ok(()) +} + +func Test_Interceptor_BindUnbind(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + m := &MediaEngine{} + assert.NoError(t, m.RegisterDefaultCodecs()) + + var ( + cntBindRTCPReader uint32 + cntBindRTCPWriter uint32 + cntBindLocalStream uint32 + cntUnbindLocalStream uint32 + cntBindRemoteStream uint32 + cntUnbindRemoteStream uint32 + cntClose uint32 + ) + mockInterceptor := &mock_interceptor.Interceptor{ + BindRTCPReaderFn: func(reader interceptor.RTCPReader) interceptor.RTCPReader { + atomic.AddUint32(&cntBindRTCPReader, 1) + return reader + }, + BindRTCPWriterFn: func(writer interceptor.RTCPWriter) interceptor.RTCPWriter { + atomic.AddUint32(&cntBindRTCPWriter, 1) + return writer + }, + BindLocalStreamFn: func(i *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + atomic.AddUint32(&cntBindLocalStream, 1) + return writer + }, + UnbindLocalStreamFn: func(i *interceptor.StreamInfo) { + atomic.AddUint32(&cntUnbindLocalStream, 1) + }, + BindRemoteStreamFn: func(i *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + atomic.AddUint32(&cntBindRemoteStream, 1) + return reader + }, + UnbindRemoteStreamFn: func(i *interceptor.StreamInfo) { + atomic.AddUint32(&cntUnbindRemoteStream, 1) + }, + CloseFn: func() error { + atomic.AddUint32(&cntClose, 1) + return nil + }, + } + ir := &interceptor.Registry{} + ir.Add(&mock_interceptor.Factory{ + NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { return mockInterceptor, nil }, + }) + + sender, receiver, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).newPair(Configuration{}) + assert.NoError(t, err) + + track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") + assert.NoError(t, err) + + _, err = sender.AddTrack(track) + assert.NoError(t, err) + + receiverReady, receiverReadyFn := context.WithCancel(context.Background()) + receiver.OnTrack(func(track *TrackRemote, _ *RTPReceiver) { + _, _, readErr := track.ReadRTP() + assert.NoError(t, readErr) + receiverReadyFn() + }) + + assert.NoError(t, signalPair(sender, receiver)) + + ticker := time.NewTicker(time.Millisecond * 20) + defer ticker.Stop() + func() { + for { + select { + case <-receiverReady.Done(): + return + case <-ticker.C: + // Send packet to make receiver track actual creates RTPReceiver. + assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0xAA}, Duration: time.Second})) + } + } + }() + + closePairNow(t, sender, receiver) + + // Bind/UnbindLocal/RemoteStream should be called from one side. + if cnt := atomic.LoadUint32(&cntBindLocalStream); cnt != 1 { + t.Errorf("BindLocalStreamFn is expected to be called once, but called %d times", cnt) + } + if cnt := atomic.LoadUint32(&cntUnbindLocalStream); cnt != 1 { + t.Errorf("UnbindLocalStreamFn is expected to be called once, but called %d times", cnt) + } + if cnt := atomic.LoadUint32(&cntBindRemoteStream); cnt != 1 { + t.Errorf("BindRemoteStreamFn is expected to be called once, but called %d times", cnt) + } + if cnt := atomic.LoadUint32(&cntUnbindRemoteStream); cnt != 1 { + t.Errorf("UnbindRemoteStreamFn is expected to be called once, but called %d times", cnt) + } + + // BindRTCPWriter/Reader and Close should be called from both side. + if cnt := atomic.LoadUint32(&cntBindRTCPWriter); cnt != 2 { + t.Errorf("BindRTCPWriterFn is expected to be called twice, but called %d times", cnt) + } + if cnt := atomic.LoadUint32(&cntBindRTCPReader); cnt != 2 { + t.Errorf("BindRTCPReaderFn is expected to be called twice, but called %d times", cnt) + } + if cnt := atomic.LoadUint32(&cntClose); cnt != 2 { + t.Errorf("CloseFn is expected to be called twice, but called %d times", cnt) + } +} + +func Test_InterceptorRegistry_Build(t *testing.T) { + registryBuildCount := 0 + + ir := &interceptor.Registry{} + ir.Add(&mock_interceptor.Factory{ + NewInterceptorFn: func(_ string) (interceptor.Interceptor, error) { + registryBuildCount++ + return &interceptor.NoOp{}, nil + }, + }) + + peerConnectionA, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + assert.Equal(t, 2, registryBuildCount) + closePairNow(t, peerConnectionA, peerConnectionB) +} +*/ diff --git a/rtc/src/api/interceptor_registry/mod.rs b/rtc/src/api/interceptor_registry/mod.rs new file mode 100644 index 0000000..ca4f549 --- /dev/null +++ b/rtc/src/api/interceptor_registry/mod.rs @@ -0,0 +1,171 @@ +#[cfg(test)] +mod interceptor_registry_test; + +use interceptor::nack::generator::Generator; +use interceptor::nack::responder::Responder; +use interceptor::registry::Registry; +use interceptor::report::receiver::ReceiverReport; +use interceptor::report::sender::SenderReport; +use interceptor::twcc::receiver::Receiver; +use interceptor::twcc::sender::Sender; + +use crate::api::media_engine::MediaEngine; +use crate::error::Result; +use crate::rtp_transceiver::rtp_codec::{RTCRtpHeaderExtensionCapability, RTPCodecType}; +use crate::rtp_transceiver::{RTCPFeedback, TYPE_RTCP_FB_TRANSPORT_CC}; + +/// register_default_interceptors will register some useful interceptors. +/// If you want to customize which interceptors are loaded, you should copy the +/// code from this method and remove unwanted interceptors. +pub fn register_default_interceptors( + mut registry: Registry, + media_engine: &mut MediaEngine, +) -> Result { + registry = configure_nack(registry, media_engine); + + registry = configure_rtcp_reports(registry); + + registry = configure_twcc_receiver_only(registry, media_engine)?; + + Ok(registry) +} + +/// configure_rtcp_reports will setup everything necessary for generating Sender and Receiver Reports +pub fn configure_rtcp_reports(mut registry: Registry) -> Registry { + let receiver = Box::new(ReceiverReport::builder()); + let sender = Box::new(SenderReport::builder()); + registry.add(receiver); + registry.add(sender); + registry +} + +/// configure_nack will setup everything necessary for handling generating/responding to nack messages. +pub fn configure_nack(mut registry: Registry, media_engine: &mut MediaEngine) -> Registry { + media_engine.register_feedback( + RTCPFeedback { + typ: "nack".to_owned(), + parameter: "".to_owned(), + }, + RTPCodecType::Video, + ); + media_engine.register_feedback( + RTCPFeedback { + typ: "nack".to_owned(), + parameter: "pli".to_owned(), + }, + RTPCodecType::Video, + ); + + let generator = Box::new(Generator::builder()); + let responder = Box::new(Responder::builder()); + registry.add(responder); + registry.add(generator); + registry +} + +/// configure_twcc will setup everything necessary for adding +/// a TWCC header extension to outgoing RTP packets and generating TWCC reports. +pub fn configure_twcc(mut registry: Registry, media_engine: &mut MediaEngine) -> Result { + media_engine.register_feedback( + RTCPFeedback { + typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), + ..Default::default() + }, + RTPCodecType::Video, + ); + media_engine.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), + }, + RTPCodecType::Video, + None, + )?; + + media_engine.register_feedback( + RTCPFeedback { + typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), + ..Default::default() + }, + RTPCodecType::Audio, + ); + media_engine.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + + let sender = Box::new(Sender::builder()); + let receiver = Box::new(Receiver::builder()); + registry.add(sender); + registry.add(receiver); + Ok(registry) +} + +/// configure_twcc_sender will setup everything necessary for adding +/// a TWCC header extension to outgoing RTP packets. This will allow the remote peer to generate TWCC reports. +pub fn configure_twcc_sender_only( + mut registry: Registry, + media_engine: &mut MediaEngine, +) -> Result { + media_engine.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), + }, + RTPCodecType::Video, + None, + )?; + + media_engine.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + + let sender = Box::new(Sender::builder()); + registry.add(sender); + Ok(registry) +} + +/// configure_twcc_receiver will setup everything necessary for generating TWCC reports. +pub fn configure_twcc_receiver_only( + mut registry: Registry, + media_engine: &mut MediaEngine, +) -> Result { + media_engine.register_feedback( + RTCPFeedback { + typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), + ..Default::default() + }, + RTPCodecType::Video, + ); + media_engine.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), + }, + RTPCodecType::Video, + None, + )?; + + media_engine.register_feedback( + RTCPFeedback { + typ: TYPE_RTCP_FB_TRANSPORT_CC.to_owned(), + ..Default::default() + }, + RTPCodecType::Audio, + ); + media_engine.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::TRANSPORT_CC_URI.to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + + let receiver = Box::new(Receiver::builder()); + registry.add(receiver); + Ok(registry) +} diff --git a/rtc/src/api/media_engine/media_engine_test.rs b/rtc/src/api/media_engine/media_engine_test.rs new file mode 100644 index 0000000..cbf3b50 --- /dev/null +++ b/rtc/src/api/media_engine/media_engine_test.rs @@ -0,0 +1,780 @@ +use std::io::Cursor; + +use regex::Regex; + +use super::*; +use crate::api::media_engine::MIME_TYPE_OPUS; +use crate::api::APIBuilder; +use crate::peer_connection::configuration::RTCConfiguration; + +#[tokio::test] +async fn test_opus_case() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let pc = api.new_peer_connection(RTCConfiguration::default()).await?; + pc.add_transceiver_from_kind(RTPCodecType::Audio, None) + .await?; + + let offer = pc.create_offer(None).await?; + + let re = Regex::new(r"(?m)^a=rtpmap:\d+ opus/48000/2").unwrap(); + assert!(re.is_match(offer.sdp.as_str())); + + pc.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_video_case() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let pc = api.new_peer_connection(RTCConfiguration::default()).await?; + pc.add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let offer = pc.create_offer(None).await?; + + let re = Regex::new(r"(?m)^a=rtpmap:\d+ H264/90000").unwrap(); + assert!(re.is_match(offer.sdp.as_str())); + let re = Regex::new(r"(?m)^a=rtpmap:\d+ VP8/90000").unwrap(); + assert!(re.is_match(offer.sdp.as_str())); + let re = Regex::new(r"(?m)^a=rtpmap:\d+ VP9/90000").unwrap(); + assert!(re.is_match(offer.sdp.as_str())); + + pc.close().await?; + + Ok(()) +} + +#[tokio::test] +async fn test_media_engine_remote_description() -> Result<()> { + let must_parse = |raw: &str| -> Result { + let mut reader = Cursor::new(raw.as_bytes()); + Ok(SessionDescription::unmarshal(&mut reader)?) + }; + + //"No Media" + { + const NO_MEDIA: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +"; + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + m.update_from_remote_description(&must_parse(NO_MEDIA)?) + .await?; + + assert!(!m.negotiated_video.load(Ordering::SeqCst)); + assert!(!m.negotiated_audio.load(Ordering::SeqCst)); + } + + //"Enable Opus" + { + const OPUS_SAME_PAYLOAD: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=audio 9 UDP/TLS/RTP/SAVPF 111 +a=rtpmap:111 opus/48000/2 +a=fmtp:111 minptime=10; useinbandfec=1 +"; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + m.update_from_remote_description(&must_parse(OPUS_SAME_PAYLOAD)?) + .await?; + + assert!(!m.negotiated_video.load(Ordering::SeqCst)); + assert!(m.negotiated_audio.load(Ordering::SeqCst)); + + let (opus_codec, _) = m.get_codec_by_payload(111).await?; + assert_eq!(opus_codec.capability.mime_type, MIME_TYPE_OPUS); + } + + //"Change Payload Type" + { + const OPUS_SAME_PAYLOAD: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=audio 9 UDP/TLS/RTP/SAVPF 112 +a=rtpmap:112 opus/48000/2 +a=fmtp:112 minptime=10; useinbandfec=1 +"; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + m.update_from_remote_description(&must_parse(OPUS_SAME_PAYLOAD)?) + .await?; + + assert!(!m.negotiated_video.load(Ordering::SeqCst)); + assert!(m.negotiated_audio.load(Ordering::SeqCst)); + + let result = m.get_codec_by_payload(111).await; + assert!(result.is_err()); + + let (opus_codec, _) = m.get_codec_by_payload(112).await?; + assert_eq!(opus_codec.capability.mime_type, MIME_TYPE_OPUS); + } + + //"Case Insensitive" + { + const OPUS_UPCASE: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=audio 9 UDP/TLS/RTP/SAVPF 111 +a=rtpmap:111 OPUS/48000/2 +a=fmtp:111 minptime=10; useinbandfec=1 +"; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + m.update_from_remote_description(&must_parse(OPUS_UPCASE)?) + .await?; + + assert!(!m.negotiated_video.load(Ordering::SeqCst)); + assert!(m.negotiated_audio.load(Ordering::SeqCst)); + + let (opus_codec, _) = m.get_codec_by_payload(111).await?; + assert_eq!(opus_codec.capability.mime_type, "audio/OPUS"); + } + + //"Handle different fmtp" + { + const OPUS_NO_FMTP: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=audio 9 UDP/TLS/RTP/SAVPF 111 +a=rtpmap:111 opus/48000/2 +"; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + m.update_from_remote_description(&must_parse(OPUS_NO_FMTP)?) + .await?; + + assert!(!m.negotiated_video.load(Ordering::SeqCst)); + assert!(m.negotiated_audio.load(Ordering::SeqCst)); + + let (opus_codec, _) = m.get_codec_by_payload(111).await?; + assert_eq!(opus_codec.capability.mime_type, MIME_TYPE_OPUS); + } + + //"Header Extensions" + { + const HEADER_EXTENSIONS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=audio 9 UDP/TLS/RTP/SAVPF 111 +a=extmap:7 urn:ietf:params:rtp-hdrext:sdes:mid +a=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id +a=rtpmap:111 opus/48000/2 +"; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + for extension in [ + "urn:ietf:params:rtp-hdrext:sdes:mid", + "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id", + ] { + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: extension.to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + } + + m.update_from_remote_description(&must_parse(HEADER_EXTENSIONS)?) + .await?; + + assert!(!m.negotiated_video.load(Ordering::SeqCst)); + assert!(m.negotiated_audio.load(Ordering::SeqCst)); + + let (abs_id, abs_audio_enabled, abs_video_enabled) = m + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::ABS_SEND_TIME_URI.to_owned(), + }) + .await; + assert_eq!(abs_id, 0); + assert!(!abs_audio_enabled); + assert!(!abs_video_enabled); + + let (mid_id, mid_audio_enabled, mid_video_enabled) = m + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::SDES_MID_URI.to_owned(), + }) + .await; + assert_eq!(mid_id, 7); + assert!(mid_audio_enabled); + assert!(!mid_video_enabled); + } + + //"Prefers exact codec matches" + { + const PROFILE_LEVELS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=video 60323 UDP/TLS/RTP/SAVPF 96 98 +a=rtpmap:96 H264/90000 +a=fmtp:96 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640c1f +a=rtpmap:98 H264/90000 +a=fmtp:98 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f +"; + let mut m = MediaEngine::default(); + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" + .to_string(), + rtcp_feedback: vec![], + }, + payload_type: 127, + ..Default::default() + }, + RTPCodecType::Video, + )?; + m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) + .await?; + + assert!(m.negotiated_video.load(Ordering::SeqCst)); + assert!(!m.negotiated_audio.load(Ordering::SeqCst)); + + let (supported_h264, _) = m.get_codec_by_payload(98).await?; + assert_eq!(supported_h264.capability.mime_type, MIME_TYPE_H264); + + assert!(m.get_codec_by_payload(96).await.is_err()); + } + + //"Does not match when fmtpline is set and does not match" + { + const PROFILE_LEVELS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=video 60323 UDP/TLS/RTP/SAVPF 96 98 +a=rtpmap:96 H264/90000 +a=fmtp:96 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640c1f +"; + let mut m = MediaEngine::default(); + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" + .to_string(), + rtcp_feedback: vec![], + }, + payload_type: 127, + ..Default::default() + }, + RTPCodecType::Video, + )?; + assert!(m + .update_from_remote_description(&must_parse(PROFILE_LEVELS)?) + .await + .is_err()); + + assert!(m.get_codec_by_payload(96).await.is_err()); + } + + //"Matches when fmtpline is not set in offer, but exists in mediaengine" + { + const PROFILE_LEVELS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=video 60323 UDP/TLS/RTP/SAVPF 96 +a=rtpmap:96 VP9/90000 +"; + let mut m = MediaEngine::default(); + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP9.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "profile-id=0".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 98, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) + .await?; + + assert!(m.negotiated_video.load(Ordering::SeqCst)); + + m.get_codec_by_payload(96).await?; + } + + //"Matches when fmtpline exists in neither" + { + const PROFILE_LEVELS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=video 60323 UDP/TLS/RTP/SAVPF 96 +a=rtpmap:96 VP8/90000 +"; + let mut m = MediaEngine::default(); + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) + .await?; + + assert!(m.negotiated_video.load(Ordering::SeqCst)); + + m.get_codec_by_payload(96).await?; + } + + //"Matches when rtx apt for exact match codec" + { + const PROFILE_LEVELS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=video 60323 UDP/TLS/RTP/SAVPF 94 96 97 +a=rtpmap:94 VP8/90000 +a=rtpmap:96 VP9/90000 +a=fmtp:96 profile-id=2 +a=rtpmap:97 rtx/90000 +a=fmtp:97 apt=96 +"; + let mut m = MediaEngine::default(); + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 94, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP9.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "profile-id=2".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: "video/rtx".to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "apt=96".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 97, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) + .await?; + + assert!(m.negotiated_video.load(Ordering::SeqCst)); + + m.get_codec_by_payload(97).await?; + } + + //"Matches when rtx apt for partial match codec" + { + const PROFILE_LEVELS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=video 60323 UDP/TLS/RTP/SAVPF 94 96 97 +a=rtpmap:94 VP8/90000 +a=rtpmap:96 VP9/90000 +a=fmtp:96 profile-id=2 +a=rtpmap:97 rtx/90000 +a=fmtp:97 apt=96 +"; + let mut m = MediaEngine::default(); + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 94, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP9.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "profile-id=1".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: "video/rtx".to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "apt=96".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 97, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + m.update_from_remote_description(&must_parse(PROFILE_LEVELS)?) + .await?; + + assert!(m.negotiated_video.load(Ordering::SeqCst)); + + if let Err(err) = m.get_codec_by_payload(97).await { + assert_eq!(err, Error::ErrCodecNotFound); + } else { + panic!(); + } + } + + Ok(()) +} + +#[tokio::test] +async fn test_media_engine_header_extension_direction() -> Result<()> { + let register_codec = |m: &mut MediaEngine| -> Result<()> { + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + clock_rate: 48000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }, + RTPCodecType::Audio, + ) + }; + + //"No Direction" + { + let mut m = MediaEngine::default(); + register_codec(&mut m)?; + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: "webrtc-header-test".to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + + let params = + m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Recvonly); + + assert_eq!(params.header_extensions.len(), 1); + } + + //"Same Direction" + { + let mut m = MediaEngine::default(); + register_codec(&mut m)?; + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: "webrtc-header-test".to_owned(), + }, + RTPCodecType::Audio, + Some(RTCRtpTransceiverDirection::Recvonly), + )?; + + let params = + m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Recvonly); + + assert_eq!(params.header_extensions.len(), 1); + } + + //"Different Direction" + { + let mut m = MediaEngine::default(); + register_codec(&mut m)?; + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: "webrtc-header-test".to_owned(), + }, + RTPCodecType::Audio, + Some(RTCRtpTransceiverDirection::Sendonly), + )?; + + let params = + m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Recvonly); + + assert_eq!(params.header_extensions.len(), 0); + } + + //"No direction and inactive" + { + let mut m = MediaEngine::default(); + register_codec(&mut m)?; + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: "webrtc-header-test".to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + + let params = + m.get_rtp_parameters_by_kind(RTPCodecType::Audio, RTCRtpTransceiverDirection::Inactive); + + assert_eq!(params.header_extensions.len(), 1); + } + + Ok(()) +} + +/// If a user attempts to register a codec twice we should just discard duplicate calls +#[tokio::test] +async fn test_media_engine_double_register() -> Result<()> { + let mut m = MediaEngine::default(); + + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + clock_rate: 48000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }, + RTPCodecType::Audio, + )?; + + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + clock_rate: 48000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }, + RTPCodecType::Audio, + )?; + + assert_eq!(m.audio_codecs.len(), 1); + Ok(()) +} + +async fn validate(m: &MediaEngine) -> Result<()> { + m.update_header_extension(2, "test-extension", RTPCodecType::Audio) + .await?; + + let (id, audio_negotiated, video_negotiated) = m + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: "test-extension".to_owned(), + }) + .await; + assert_eq!(id, 2); + assert!(audio_negotiated); + assert!(!video_negotiated); + + Ok(()) +} + +/// The cloned MediaEngine instance should be able to update negotiated header extensions. +#[tokio::test] +async fn test_update_header_extension_to_cloned_media_engine() -> Result<()> { + let mut m = MediaEngine::default(); + + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + clock_rate: 48000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }, + RTPCodecType::Audio, + )?; + + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: "test-extension".to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + + validate(&m).await?; + validate(&m.clone_to()).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_extension_id_collision() -> Result<()> { + let must_parse = |raw: &str| -> Result { + let mut reader = Cursor::new(raw.as_bytes()); + Ok(SessionDescription::unmarshal(&mut reader)?) + }; + + const HEADER_EXTENSIONS: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=audio 9 UDP/TLS/RTP/SAVPF 111 +a=extmap:7 urn:ietf:params:rtp-hdrext:sdes:mid +a=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level +a=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id +a=rtpmap:111 opus/48000/2 +"; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + { + let extension = "urn:3gpp:video-orientation"; + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: extension.to_owned(), + }, + RTPCodecType::Video, + None, + )?; + } + for extension in [ + "urn:ietf:params:rtp-hdrext:ssrc-audio-level", + "urn:ietf:params:rtp-hdrext:sdes:mid", + "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id", + ] { + m.register_header_extension( + RTCRtpHeaderExtensionCapability { + uri: extension.to_owned(), + }, + RTPCodecType::Audio, + None, + )?; + } + + m.update_from_remote_description(&must_parse(HEADER_EXTENSIONS)?) + .await?; + + assert!(!m.negotiated_video.load(Ordering::SeqCst)); + assert!(m.negotiated_audio.load(Ordering::SeqCst)); + + let (abs_id, abs_audio_enabled, abs_video_enabled) = m + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::ABS_SEND_TIME_URI.to_owned(), + }) + .await; + assert_eq!(abs_id, 0); + assert!(!abs_audio_enabled); + assert!(!abs_video_enabled); + + let (mid_id, mid_audio_enabled, mid_video_enabled) = m + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::SDES_MID_URI.to_owned(), + }) + .await; + assert_eq!(mid_id, 7); + assert!(mid_audio_enabled); + assert!(!mid_video_enabled); + + let (mid_id, mid_audio_enabled, mid_video_enabled) = m + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: sdp::extmap::AUDIO_LEVEL_URI.to_owned(), + }) + .await; + assert_eq!(mid_id, 1); + assert!(mid_audio_enabled); + assert!(!mid_video_enabled); + + let params = + m.get_rtp_parameters_by_kind(RTPCodecType::Video, RTCRtpTransceiverDirection::Sendonly); + //dbg!(¶ms); + + let orientation = params + .header_extensions + .iter() + .find(|ext| ext.uri == "urn:3gpp:video-orientation") + .unwrap(); + assert_ne!(orientation.id, 1); + assert_ne!(orientation.id, 7); + assert_ne!(orientation.id, 5); + + Ok(()) +} diff --git a/rtc/src/api/media_engine/mod.rs b/rtc/src/api/media_engine/mod.rs new file mode 100644 index 0000000..8cae27c --- /dev/null +++ b/rtc/src/api/media_engine/mod.rs @@ -0,0 +1,804 @@ +#[cfg(test)] +mod media_engine_test; + +use std::collections::HashMap; +use std::ops::Range; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use sdp::description::session::SessionDescription; +use util::sync::Mutex as SyncMutex; + +use crate::error::{Error, Result}; +use crate::peer_connection::sdp::{ + codecs_from_media_description, rtp_extensions_from_media_description, +}; +use crate::rtp_transceiver::rtp_codec::{ + codec_parameters_fuzzy_search, CodecMatch, RTCRtpCodecCapability, RTCRtpCodecParameters, + RTCRtpHeaderExtensionCapability, RTCRtpHeaderExtensionParameters, RTCRtpParameters, + RTPCodecType, +}; +use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; +use crate::rtp_transceiver::{fmtp, PayloadType, RTCPFeedback}; +use crate::stats::stats_collector::StatsCollector; +use crate::stats::CodecStats; +use crate::stats::StatsReportType::Codec; + +/// MIME_TYPE_H264 H264 MIME type. +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_H264: &str = "video/H264"; +/// MIME_TYPE_OPUS Opus MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_OPUS: &str = "audio/opus"; +/// MIME_TYPE_VP8 VP8 MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_VP8: &str = "video/VP8"; +/// MIME_TYPE_VP9 VP9 MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_VP9: &str = "video/VP9"; +/// MIME_TYPE_AV1 AV1 MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_AV1: &str = "video/AV1"; +/// MIME_TYPE_G722 G722 MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_G722: &str = "audio/G722"; +/// MIME_TYPE_PCMU PCMU MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_PCMU: &str = "audio/PCMU"; +/// MIME_TYPE_PCMA PCMA MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_PCMA: &str = "audio/PCMA"; +/// MIME_TYPE_TELEPHONE_EVENT telephone-event MIME type +/// Note: Matching should be case insensitive. +pub const MIME_TYPE_TELEPHONE_EVENT: &str = "audio/telephone-event"; + +const VALID_EXT_IDS: Range = 1..15; + +#[derive(Default, Clone)] +pub(crate) struct MediaEngineHeaderExtension { + pub(crate) uri: String, + pub(crate) is_audio: bool, + pub(crate) is_video: bool, + pub(crate) allowed_direction: Option, +} + +impl MediaEngineHeaderExtension { + pub fn is_matching_direction(&self, dir: RTCRtpTransceiverDirection) -> bool { + if let Some(allowed_direction) = self.allowed_direction { + use RTCRtpTransceiverDirection::*; + allowed_direction == Inactive && dir == Inactive + || allowed_direction.has_send() && dir.has_send() + || allowed_direction.has_recv() && dir.has_recv() + } else { + // None means all directions matches. + true + } + } +} + +/// A MediaEngine defines the codecs supported by a PeerConnection, and the +/// configuration of those codecs. A MediaEngine must not be shared between +/// PeerConnections. +#[derive(Default)] +pub struct MediaEngine { + // If we have attempted to negotiate a codec type yet. + pub(crate) negotiated_video: AtomicBool, + pub(crate) negotiated_audio: AtomicBool, + + pub(crate) video_codecs: Vec, + pub(crate) audio_codecs: Vec, + pub(crate) negotiated_video_codecs: SyncMutex>, + pub(crate) negotiated_audio_codecs: SyncMutex>, + + header_extensions: Vec, + proposed_header_extensions: SyncMutex>, + pub(crate) negotiated_header_extensions: SyncMutex>, +} + +impl MediaEngine { + /// register_default_codecs registers the default codecs supported by Pion WebRTC. + /// register_default_codecs is not safe for concurrent use. + pub fn register_default_codecs(&mut self) -> Result<()> { + // Default Audio Codecs + for codec in vec![ + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + clock_rate: 48000, + channels: 2, + sdp_fmtp_line: "minptime=10;useinbandfec=1".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_G722.to_owned(), + clock_rate: 8000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 9, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_PCMU.to_owned(), + clock_rate: 8000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 0, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_PCMA.to_owned(), + clock_rate: 8000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 8, + ..Default::default() + }, + ] { + self.register_codec(codec, RTPCodecType::Audio)?; + } + + let video_rtcp_feedback = vec![ + RTCPFeedback { + typ: "goog-remb".to_owned(), + parameter: "".to_owned(), + }, + RTCPFeedback { + typ: "ccm".to_owned(), + parameter: "fir".to_owned(), + }, + RTCPFeedback { + typ: "nack".to_owned(), + parameter: "".to_owned(), + }, + RTCPFeedback { + typ: "nack".to_owned(), + parameter: "pli".to_owned(), + }, + ]; + for codec in vec![ + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 96, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP9.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "profile-id=0".to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 98, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP9.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "profile-id=1".to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 100, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f" + .to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 102, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f" + .to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 127, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f" + .to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 125, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f" + .to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 108, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f" + .to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 127, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=640032" + .to_owned(), + rtcp_feedback: video_rtcp_feedback.clone(), + }, + payload_type: 123, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_AV1.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "profile-id=0".to_owned(), + rtcp_feedback: video_rtcp_feedback, + }, + payload_type: 41, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: "video/ulpfec".to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 116, + ..Default::default() + }, + ] { + self.register_codec(codec, RTPCodecType::Video)?; + } + + Ok(()) + } + + /// add_codec will append codec if it not exists + fn add_codec(codecs: &mut Vec, codec: RTCRtpCodecParameters) { + for c in codecs.iter() { + if c.capability.mime_type == codec.capability.mime_type + && c.payload_type == codec.payload_type + { + return; + } + } + codecs.push(codec); + } + + /// register_codec adds codec to the MediaEngine + /// These are the list of codecs supported by this PeerConnection. + /// register_codec is not safe for concurrent use. + pub fn register_codec( + &mut self, + mut codec: RTCRtpCodecParameters, + typ: RTPCodecType, + ) -> Result<()> { + codec.stats_id = format!( + "RTPCodec-{}", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + ); + match typ { + RTPCodecType::Audio => { + MediaEngine::add_codec(&mut self.audio_codecs, codec); + Ok(()) + } + RTPCodecType::Video => { + MediaEngine::add_codec(&mut self.video_codecs, codec); + Ok(()) + } + _ => Err(Error::ErrUnknownType), + } + } + + /// Adds a header extension to the MediaEngine + /// To determine the negotiated value use [`get_header_extension_id`] after signaling is complete. + /// + /// The `allowed_direction` controls for which transceiver directions the extension matches. If + /// set to `None` it matches all directions. The `SendRecv` direction would match all transceiver + /// directions apart from `Inactive`. Inactive only matches inactive. + pub fn register_header_extension( + &mut self, + extension: RTCRtpHeaderExtensionCapability, + typ: RTPCodecType, + allowed_direction: Option, + ) -> Result<()> { + let ext = { + match self + .header_extensions + .iter_mut() + .find(|ext| ext.uri == extension.uri) + { + Some(ext) => ext, + None => { + // We have registered too many extensions + if self.header_extensions.len() > VALID_EXT_IDS.end as usize { + return Err(Error::ErrRegisterHeaderExtensionNoFreeID); + } + self.header_extensions.push(MediaEngineHeaderExtension { + allowed_direction, + ..Default::default() + }); + + // Unwrap is fine because we just pushed + self.header_extensions.last_mut().unwrap() + } + } + }; + + if typ == RTPCodecType::Audio { + ext.is_audio = true; + } else if typ == RTPCodecType::Video { + ext.is_video = true; + } + + ext.uri = extension.uri; + + if ext.allowed_direction != allowed_direction { + return Err(Error::ErrRegisterHeaderExtensionInvalidDirection); + } + + Ok(()) + } + + /// register_feedback adds feedback mechanism to already registered codecs. + pub fn register_feedback(&mut self, feedback: RTCPFeedback, typ: RTPCodecType) { + match typ { + RTPCodecType::Video => { + for v in &mut self.video_codecs { + v.capability.rtcp_feedback.push(feedback.clone()); + } + } + RTPCodecType::Audio => { + for a in &mut self.audio_codecs { + a.capability.rtcp_feedback.push(feedback.clone()); + } + } + _ => {} + } + } + + /// get_header_extension_id returns the negotiated ID for a header extension. + /// If the Header Extension isn't enabled ok will be false + pub(crate) async fn get_header_extension_id( + &self, + extension: RTCRtpHeaderExtensionCapability, + ) -> (isize, bool, bool) { + let negotiated_header_extensions = self.negotiated_header_extensions.lock(); + if negotiated_header_extensions.is_empty() { + return (0, false, false); + } + + for (id, h) in &*negotiated_header_extensions { + if extension.uri == h.uri { + return (*id, h.is_audio, h.is_video); + } + } + + (0, false, false) + } + + /// clone_to copies any user modifiable state of the MediaEngine + /// all internal state is reset + pub(crate) fn clone_to(&self) -> Self { + MediaEngine { + video_codecs: self.video_codecs.clone(), + audio_codecs: self.audio_codecs.clone(), + header_extensions: self.header_extensions.clone(), + ..Default::default() + } + } + + pub(crate) async fn get_codec_by_payload( + &self, + payload_type: PayloadType, + ) -> Result<(RTCRtpCodecParameters, RTPCodecType)> { + { + let negotiated_video_codecs = self.negotiated_video_codecs.lock(); + for codec in &*negotiated_video_codecs { + if codec.payload_type == payload_type { + return Ok((codec.clone(), RTPCodecType::Video)); + } + } + } + { + let negotiated_audio_codecs = self.negotiated_audio_codecs.lock(); + for codec in &*negotiated_audio_codecs { + if codec.payload_type == payload_type { + return Ok((codec.clone(), RTPCodecType::Audio)); + } + } + } + + Err(Error::ErrCodecNotFound) + } + + pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { + let mut reports = HashMap::new(); + + for codec in &self.video_codecs { + reports.insert(codec.stats_id.clone(), Codec(CodecStats::from(codec))); + } + + for codec in &self.audio_codecs { + reports.insert(codec.stats_id.clone(), Codec(CodecStats::from(codec))); + } + + collector.merge(reports); + } + + /// Look up a codec and enable if it exists + pub(crate) fn match_remote_codec( + &self, + remote_codec: &RTCRtpCodecParameters, + typ: RTPCodecType, + exact_matches: &[RTCRtpCodecParameters], + partial_matches: &[RTCRtpCodecParameters], + ) -> Result { + let codecs = if typ == RTPCodecType::Audio { + &self.audio_codecs + } else { + &self.video_codecs + }; + + let remote_fmtp = fmtp::parse( + &remote_codec.capability.mime_type, + remote_codec.capability.sdp_fmtp_line.as_str(), + ); + if let Some(apt) = remote_fmtp.parameter("apt") { + let payload_type = apt.parse::()?; + + let mut apt_match = CodecMatch::None; + for codec in exact_matches { + if codec.payload_type == payload_type { + apt_match = CodecMatch::Exact; + break; + } + } + + if apt_match == CodecMatch::None { + for codec in partial_matches { + if codec.payload_type == payload_type { + apt_match = CodecMatch::Partial; + break; + } + } + } + + if apt_match == CodecMatch::None { + return Ok(CodecMatch::None); // not an error, we just ignore this codec we don't support + } + + // if apt's media codec is partial match, then apt codec must be partial match too + let (_, mut match_type) = codec_parameters_fuzzy_search(remote_codec, codecs); + if match_type == CodecMatch::Exact && apt_match == CodecMatch::Partial { + match_type = CodecMatch::Partial; + } + return Ok(match_type); + } + + let (_, match_type) = codec_parameters_fuzzy_search(remote_codec, codecs); + Ok(match_type) + } + + /// Look up a header extension and enable if it exists + pub(crate) async fn update_header_extension( + &self, + id: isize, + extension: &str, + typ: RTPCodecType, + ) -> Result<()> { + let mut negotiated_header_extensions = self.negotiated_header_extensions.lock(); + let mut proposed_header_extensions = self.proposed_header_extensions.lock(); + + for local_extension in &self.header_extensions { + if local_extension.uri != extension { + continue; + } + + let negotiated_ext = negotiated_header_extensions + .iter_mut() + .find(|(_, ext)| ext.uri == extension); + + if let Some(n_ext) = negotiated_ext { + if *n_ext.0 == id { + n_ext.1.is_video |= typ == RTPCodecType::Video; + n_ext.1.is_audio |= typ == RTPCodecType::Audio; + } else { + let nid = n_ext.0; + log::warn!("Invalid ext id mapping in update_header_extension. {} was negotiated as {}, but was {} in call", extension, nid, id); + } + } else { + // We either only have a proposal or we have neither proposal nor a negotiated id + // Accept whatevers the peer suggests + + if let Some(prev_ext) = negotiated_header_extensions.get(&id) { + let prev_uri = &prev_ext.uri; + log::warn!("Assigning {} to {} would override previous assignment to {}, no action taken", id, extension, prev_uri); + } else { + let h = MediaEngineHeaderExtension { + uri: extension.to_owned(), + is_audio: local_extension.is_audio && typ == RTPCodecType::Audio, + is_video: local_extension.is_video && typ == RTPCodecType::Video, + allowed_direction: local_extension.allowed_direction, + }; + negotiated_header_extensions.insert(id, h); + } + } + + // Clear any proposals we had for this id + proposed_header_extensions.remove(&id); + } + Ok(()) + } + + pub(crate) async fn push_codecs(&self, codecs: Vec, typ: RTPCodecType) { + for codec in codecs { + if typ == RTPCodecType::Audio { + let mut negotiated_audio_codecs = self.negotiated_audio_codecs.lock(); + MediaEngine::add_codec(&mut negotiated_audio_codecs, codec); + } else if typ == RTPCodecType::Video { + let mut negotiated_video_codecs = self.negotiated_video_codecs.lock(); + MediaEngine::add_codec(&mut negotiated_video_codecs, codec); + } + } + } + + /// Update the MediaEngine from a remote description + pub(crate) async fn update_from_remote_description( + &self, + desc: &SessionDescription, + ) -> Result<()> { + for media in &desc.media_descriptions { + let typ = if !self.negotiated_audio.load(Ordering::SeqCst) + && media.media_name.media.to_lowercase() == "audio" + { + self.negotiated_audio.store(true, Ordering::SeqCst); + RTPCodecType::Audio + } else if !self.negotiated_video.load(Ordering::SeqCst) + && media.media_name.media.to_lowercase() == "video" + { + self.negotiated_video.store(true, Ordering::SeqCst); + RTPCodecType::Video + } else { + continue; + }; + + let codecs = codecs_from_media_description(media)?; + + let mut exact_matches = vec![]; //make([]RTPCodecParameters, 0, len(codecs)) + let mut partial_matches = vec![]; //make([]RTPCodecParameters, 0, len(codecs)) + + for codec in codecs { + let match_type = + self.match_remote_codec(&codec, typ, &exact_matches, &partial_matches)?; + + if match_type == CodecMatch::Exact { + exact_matches.push(codec); + } else if match_type == CodecMatch::Partial { + partial_matches.push(codec); + } + } + + // use exact matches when they exist, otherwise fall back to partial + if !exact_matches.is_empty() { + self.push_codecs(exact_matches, typ).await; + } else if !partial_matches.is_empty() { + self.push_codecs(partial_matches, typ).await; + } else { + // no match, not negotiated + continue; + } + + let extensions = rtp_extensions_from_media_description(media)?; + + for (extension, id) in extensions { + self.update_header_extension(id, &extension, typ).await?; + } + } + + Ok(()) + } + + pub(crate) fn get_codecs_by_kind(&self, typ: RTPCodecType) -> Vec { + if typ == RTPCodecType::Video { + if self.negotiated_video.load(Ordering::SeqCst) { + let negotiated_video_codecs = self.negotiated_video_codecs.lock(); + negotiated_video_codecs.clone() + } else { + self.video_codecs.clone() + } + } else if typ == RTPCodecType::Audio { + if self.negotiated_audio.load(Ordering::SeqCst) { + let negotiated_audio_codecs = self.negotiated_audio_codecs.lock(); + negotiated_audio_codecs.clone() + } else { + self.audio_codecs.clone() + } + } else { + vec![] + } + } + + pub(crate) fn get_rtp_parameters_by_kind( + &self, + typ: RTPCodecType, + direction: RTCRtpTransceiverDirection, + ) -> RTCRtpParameters { + let mut header_extensions = vec![]; + + if self.negotiated_video.load(Ordering::SeqCst) && typ == RTPCodecType::Video + || self.negotiated_audio.load(Ordering::SeqCst) && typ == RTPCodecType::Audio + { + let negotiated_header_extensions = self.negotiated_header_extensions.lock(); + for (id, e) in &*negotiated_header_extensions { + if e.is_matching_direction(direction) + && (e.is_audio && typ == RTPCodecType::Audio + || e.is_video && typ == RTPCodecType::Video) + { + header_extensions.push(RTCRtpHeaderExtensionParameters { + id: *id, + uri: e.uri.clone(), + }); + } + } + } else { + let mut proposed_header_extensions = self.proposed_header_extensions.lock(); + let mut negotiated_header_extensions = self.negotiated_header_extensions.lock(); + + for local_extension in &self.header_extensions { + let relevant = local_extension.is_matching_direction(direction) + && (local_extension.is_audio && typ == RTPCodecType::Audio + || local_extension.is_video && typ == RTPCodecType::Video); + + if !relevant { + continue; + } + + if let Some((id, negotiated_extension)) = negotiated_header_extensions + .iter_mut() + .find(|(_, e)| e.uri == local_extension.uri) + { + // We have previously negotiated this extension, make sure to record it as + // active for the current type + negotiated_extension.is_audio |= typ == RTPCodecType::Audio; + negotiated_extension.is_video |= typ == RTPCodecType::Video; + + header_extensions.push(RTCRtpHeaderExtensionParameters { + id: *id, + uri: negotiated_extension.uri.clone(), + }); + + continue; + } + + if let Some((id, negotiated_extension)) = proposed_header_extensions + .iter_mut() + .find(|(_, e)| e.uri == local_extension.uri) + { + // We have previously proposed this extension, re-use it + header_extensions.push(RTCRtpHeaderExtensionParameters { + id: *id, + uri: negotiated_extension.uri.clone(), + }); + + continue; + } + + // Figure out which (unused id) to propose. + let id = VALID_EXT_IDS.clone().find(|id| { + !negotiated_header_extensions.keys().any(|nid| nid == id) + && !proposed_header_extensions.keys().any(|pid| pid == id) + }); + + if let Some(id) = id { + proposed_header_extensions.insert( + id, + MediaEngineHeaderExtension { + uri: local_extension.uri.clone(), + is_audio: local_extension.is_audio, + is_video: local_extension.is_video, + allowed_direction: local_extension.allowed_direction, + }, + ); + + header_extensions.push(RTCRtpHeaderExtensionParameters { + id, + uri: local_extension.uri.clone(), + }); + } else { + log::warn!("No available RTP extension ID for {}", local_extension.uri); + } + } + } + + RTCRtpParameters { + header_extensions, + codecs: self.get_codecs_by_kind(typ), + } + } + + pub(crate) async fn get_rtp_parameters_by_payload_type( + &self, + payload_type: PayloadType, + ) -> Result { + let (codec, typ) = self.get_codec_by_payload(payload_type).await?; + + let mut header_extensions = vec![]; + { + let negotiated_header_extensions = self.negotiated_header_extensions.lock(); + for (id, e) in &*negotiated_header_extensions { + if e.is_audio && typ == RTPCodecType::Audio + || e.is_video && typ == RTPCodecType::Video + { + header_extensions.push(RTCRtpHeaderExtensionParameters { + uri: e.uri.clone(), + id: *id, + }); + } + } + } + + Ok(RTCRtpParameters { + header_extensions, + codecs: vec![codec], + }) + } +} diff --git a/rtc/src/api/mod.rs b/rtc/src/api/mod.rs new file mode 100644 index 0000000..2ab99c3 --- /dev/null +++ b/rtc/src/api/mod.rs @@ -0,0 +1,236 @@ +#[cfg(test)] +mod api_test; + +pub mod interceptor_registry; +pub mod media_engine; +pub mod setting_engine; + +use std::sync::Arc; +use std::time::SystemTime; + +use interceptor::registry::Registry; +use interceptor::Interceptor; +use media_engine::*; +use rcgen::KeyPair; +use setting_engine::*; + +use crate::data_channel::data_channel_parameters::DataChannelParameters; +use crate::data_channel::RTCDataChannel; +use crate::dtls_transport::RTCDtlsTransport; +use crate::error::{Error, Result}; +use crate::ice_transport::ice_gatherer::{RTCIceGatherOptions, RTCIceGatherer}; +use crate::ice_transport::RTCIceTransport; +use crate::peer_connection::certificate::RTCCertificate; +use crate::peer_connection::configuration::RTCConfiguration; +use crate::peer_connection::RTCPeerConnection; +use crate::rtp_transceiver::rtp_codec::RTPCodecType; +use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use crate::rtp_transceiver::rtp_sender::RTCRtpSender; +use crate::sctp_transport::RTCSctpTransport; +use crate::track::track_local::TrackLocal; + +/// API bundles the global functions of the WebRTC and ORTC API. +/// Some of these functions are also exported globally using the +/// defaultAPI object. Note that the global version of the API +/// may be phased out in the future. +pub struct API { + pub(crate) setting_engine: Arc, + pub(crate) media_engine: Arc, + pub(crate) interceptor_registry: Registry, +} + +impl API { + /// new_peer_connection creates a new PeerConnection with the provided configuration against the received API object + pub async fn new_peer_connection( + &self, + configuration: RTCConfiguration, + ) -> Result { + RTCPeerConnection::new(self, configuration).await + } + + /// new_ice_gatherer creates a new ice gatherer. + /// This constructor is part of the ORTC API. It is not + /// meant to be used together with the basic WebRTC API. + pub fn new_ice_gatherer(&self, opts: RTCIceGatherOptions) -> Result { + let mut validated_servers = vec![]; + if !opts.ice_servers.is_empty() { + for server in &opts.ice_servers { + let url = server.urls()?; + validated_servers.extend(url); + } + } + + Ok(RTCIceGatherer::new( + validated_servers, + opts.ice_gather_policy, + Arc::clone(&self.setting_engine), + )) + } + + /// new_ice_transport creates a new ice transport. + /// This constructor is part of the ORTC API. It is not + /// meant to be used together with the basic WebRTC API. + pub fn new_ice_transport(&self, gatherer: Arc) -> RTCIceTransport { + RTCIceTransport::new(gatherer) + } + + /// new_dtls_transport creates a new dtls_transport transport. + /// This constructor is part of the ORTC API. It is not + /// meant to be used together with the basic WebRTC API. + pub fn new_dtls_transport( + &self, + ice_transport: Arc, + mut certificates: Vec, + ) -> Result { + if !certificates.is_empty() { + let now = SystemTime::now(); + for cert in &certificates { + cert.expires + .duration_since(now) + .map_err(|_| Error::ErrCertificateExpired)?; + } + } else { + let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let cert = RTCCertificate::from_key_pair(kp)?; + certificates = vec![cert]; + }; + + Ok(RTCDtlsTransport::new( + ice_transport, + certificates, + Arc::clone(&self.setting_engine), + )) + } + + /// new_sctp_transport creates a new SCTPTransport. + /// This constructor is part of the ORTC API. It is not + /// meant to be used together with the basic WebRTC API. + pub fn new_sctp_transport( + &self, + dtls_transport: Arc, + ) -> Result { + Ok(RTCSctpTransport::new( + dtls_transport, + Arc::clone(&self.setting_engine), + )) + } + + /// new_data_channel creates a new DataChannel. + /// This constructor is part of the ORTC API. It is not + /// meant to be used together with the basic WebRTC API. + pub async fn new_data_channel( + &self, + sctp_transport: Arc, + params: DataChannelParameters, + ) -> Result { + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #5) + if params.label.len() > 65535 { + return Err(Error::ErrStringSizeLimit); + } + + let d = RTCDataChannel::new(params, Arc::clone(&self.setting_engine)); + d.open(sctp_transport).await?; + + Ok(d) + } + + /// new_rtp_receiver constructs a new RTPReceiver + pub fn new_rtp_receiver( + &self, + kind: RTPCodecType, + transport: Arc, + interceptor: Arc, + ) -> RTCRtpReceiver { + RTCRtpReceiver::new( + self.setting_engine.get_receive_mtu(), + kind, + transport, + Arc::clone(&self.media_engine), + interceptor, + ) + } + + /// new_rtp_sender constructs a new RTPSender + pub async fn new_rtp_sender( + &self, + track: Option>, + transport: Arc, + interceptor: Arc, + ) -> RTCRtpSender { + RTCRtpSender::new( + self.setting_engine.get_receive_mtu(), + track, + transport, + Arc::clone(&self.media_engine), + interceptor, + false, + ) + .await + } + + /// Returns the internal [`SettingEngine`]. + pub fn setting_engine(&self) -> Arc { + Arc::clone(&self.setting_engine) + } + + /// Returns the internal [`MediaEngine`]. + pub fn media_engine(&self) -> Arc { + Arc::clone(&self.media_engine) + } +} + +#[derive(Default)] +pub struct APIBuilder { + setting_engine: Option>, + media_engine: Option>, + interceptor_registry: Option, +} + +impl APIBuilder { + pub fn new() -> Self { + APIBuilder::default() + } + + pub fn build(mut self) -> API { + API { + setting_engine: if let Some(setting_engine) = self.setting_engine.take() { + setting_engine + } else { + Arc::new(SettingEngine::default()) + }, + media_engine: if let Some(media_engine) = self.media_engine.take() { + media_engine + } else { + Arc::new(MediaEngine::default()) + }, + interceptor_registry: if let Some(interceptor_registry) = + self.interceptor_registry.take() + { + interceptor_registry + } else { + Registry::new() + }, + } + } + + /// WithSettingEngine allows providing a SettingEngine to the API. + /// Settings should not be changed after passing the engine to an API. + pub fn with_setting_engine(mut self, setting_engine: SettingEngine) -> Self { + self.setting_engine = Some(Arc::new(setting_engine)); + self + } + + /// WithMediaEngine allows providing a MediaEngine to the API. + /// Settings can be changed after passing the engine to an API. + pub fn with_media_engine(mut self, media_engine: MediaEngine) -> Self { + self.media_engine = Some(Arc::new(media_engine)); + self + } + + /// with_interceptor_registry allows providing Interceptors to the API. + /// Settings should not be changed after passing the registry to an API. + pub fn with_interceptor_registry(mut self, interceptor_registry: Registry) -> Self { + self.interceptor_registry = Some(interceptor_registry); + self + } +} diff --git a/rtc/src/api/setting_engine/mod.rs b/rtc/src/api/setting_engine/mod.rs new file mode 100644 index 0000000..6d909d3 --- /dev/null +++ b/rtc/src/api/setting_engine/mod.rs @@ -0,0 +1,327 @@ +#[cfg(test)] +mod setting_engine_test; + +use std::sync::Arc; + +use dtls::extension::extension_use_srtp::SrtpProtectionProfile; +use ice::agent::agent_config::{InterfaceFilterFn, IpFilterFn}; +use ice::mdns::MulticastDnsMode; +use ice::network_type::NetworkType; +use ice::udp_network::UDPNetwork; +use tokio::time::Duration; +use util::vnet::net::*; + +use crate::dtls_transport::dtls_role::DTLSRole; +use crate::error::{Error, Result}; +use crate::ice_transport::ice_candidate_type::RTCIceCandidateType; +use crate::RECEIVE_MTU; + +#[derive(Default, Clone)] +pub struct Detach { + pub data_channels: bool, +} + +#[derive(Default, Clone)] +pub struct Timeout { + pub ice_disconnected_timeout: Option, + pub ice_failed_timeout: Option, + pub ice_keepalive_interval: Option, + pub ice_host_acceptance_min_wait: Option, + pub ice_srflx_acceptance_min_wait: Option, + pub ice_prflx_acceptance_min_wait: Option, + pub ice_relay_acceptance_min_wait: Option, +} + +#[derive(Default, Clone)] +pub struct Candidates { + pub ice_lite: bool, + pub ice_network_types: Vec, + pub interface_filter: Arc>, + pub ip_filter: Arc>, + pub nat_1to1_ips: Vec, + pub nat_1to1_ip_candidate_type: RTCIceCandidateType, + pub multicast_dns_mode: MulticastDnsMode, + pub multicast_dns_host_name: String, + pub username_fragment: String, + pub password: String, +} + +#[derive(Default, Clone)] +pub struct ReplayProtection { + pub dtls: usize, + pub srtp: usize, + pub srtcp: usize, +} + +/// SettingEngine allows influencing behavior in ways that are not +/// supported by the WebRTC API. This allows us to support additional +/// use-cases without deviating from the WebRTC API elsewhere. +#[derive(Default, Clone)] +pub struct SettingEngine { + pub(crate) detach: Detach, + pub(crate) timeout: Timeout, + pub(crate) candidates: Candidates, + pub(crate) replay_protection: ReplayProtection, + pub(crate) sdp_media_level_fingerprints: bool, + pub(crate) answering_dtls_role: DTLSRole, + pub(crate) disable_certificate_fingerprint_verification: bool, + pub(crate) allow_insecure_verification_algorithm: bool, + pub(crate) disable_srtp_replay_protection: bool, + pub(crate) disable_srtcp_replay_protection: bool, + pub(crate) vnet: Option>, + //BufferFactory :func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser, + //iceTCPMux :ice.TCPMux,? + //iceProxyDialer :proxy.Dialer,? + pub(crate) udp_network: UDPNetwork, + pub(crate) disable_media_engine_copy: bool, + pub(crate) srtp_protection_profiles: Vec, + pub(crate) receive_mtu: usize, + pub(crate) mid_generator: Option String + Send + Sync>>, +} + +impl SettingEngine { + /// get_receive_mtu returns the configured MTU. If SettingEngine's MTU is configured to 0 it returns the default + pub(crate) fn get_receive_mtu(&self) -> usize { + if self.receive_mtu != 0 { + self.receive_mtu + } else { + RECEIVE_MTU + } + } + /// detach_data_channels enables detaching data channels. When enabled + /// data channels have to be detached in the OnOpen callback using the + /// DataChannel.Detach method. + pub fn detach_data_channels(&mut self) { + self.detach.data_channels = true; + } + + /// set_srtp_protection_profiles allows the user to override the default srtp Protection Profiles + /// The default srtp protection profiles are provided by the function `defaultSrtpProtectionProfiles` + pub fn set_srtp_protection_profiles(&mut self, profiles: Vec) { + self.srtp_protection_profiles = profiles + } + + /// set_ice_timeouts sets the behavior around ICE Timeouts + /// * disconnected_timeout is the duration without network activity before a Agent is considered disconnected. Default is 5 Seconds + /// * failed_timeout is the duration without network activity before a Agent is considered failed after disconnected. Default is 25 Seconds + /// * keep_alive_interval is how often the ICE Agent sends extra traffic if there is no activity, if media is flowing no traffic will be sent. Default is 2 seconds + pub fn set_ice_timeouts( + &mut self, + disconnected_timeout: Option, + failed_timeout: Option, + keep_alive_interval: Option, + ) { + self.timeout.ice_disconnected_timeout = disconnected_timeout; + self.timeout.ice_failed_timeout = failed_timeout; + self.timeout.ice_keepalive_interval = keep_alive_interval; + } + + /// set_host_acceptance_min_wait sets the icehost_acceptance_min_wait + pub fn set_host_acceptance_min_wait(&mut self, t: Option) { + self.timeout.ice_host_acceptance_min_wait = t; + } + + /// set_srflx_acceptance_min_wait sets the icesrflx_acceptance_min_wait + pub fn set_srflx_acceptance_min_wait(&mut self, t: Option) { + self.timeout.ice_srflx_acceptance_min_wait = t; + } + + /// set_prflx_acceptance_min_wait sets the iceprflx_acceptance_min_wait + pub fn set_prflx_acceptance_min_wait(&mut self, t: Option) { + self.timeout.ice_prflx_acceptance_min_wait = t; + } + + /// set_relay_acceptance_min_wait sets the icerelay_acceptance_min_wait + pub fn set_relay_acceptance_min_wait(&mut self, t: Option) { + self.timeout.ice_relay_acceptance_min_wait = t; + } + + /// set_udp_network allows ICE traffic to come through Ephemeral or UDPMux. + /// UDPMux drastically simplifying deployments where ports will need to be opened/forwarded. + /// UDPMux should be started prior to creating PeerConnections. + pub fn set_udp_network(&mut self, udp_network: UDPNetwork) { + self.udp_network = udp_network; + } + + /// set_lite configures whether or not the ice agent should be a lite agent + pub fn set_lite(&mut self, lite: bool) { + self.candidates.ice_lite = lite; + } + + /// set_network_types configures what types of candidate networks are supported + /// during local and server reflexive gathering. + pub fn set_network_types(&mut self, candidate_types: Vec) { + self.candidates.ice_network_types = candidate_types; + } + + /// set_interface_filter sets the filtering functions when gathering ICE candidates + /// This can be used to exclude certain network interfaces from ICE. Which may be + /// useful if you know a certain interface will never succeed, or if you wish to reduce + /// the amount of information you wish to expose to the remote peer + pub fn set_interface_filter(&mut self, filter: InterfaceFilterFn) { + self.candidates.interface_filter = Arc::new(Some(filter)); + } + + /// set_ip_filter sets the filtering functions when gathering ICE candidates + /// This can be used to exclude certain ip from ICE. Which may be + /// useful if you know a certain ip will never succeed, or if you wish to reduce + /// the amount of information you wish to expose to the remote peer + pub fn set_ip_filter(&mut self, filter: IpFilterFn) { + self.candidates.ip_filter = Arc::new(Some(filter)); + } + + /// set_nat_1to1_ips sets a list of external IP addresses of 1:1 (D)NAT + /// and a candidate type for which the external IP address is used. + /// This is useful when you are host a server using Pion on an AWS EC2 instance + /// which has a private address, behind a 1:1 DNAT with a public IP (e.g. + /// Elastic IP). In this case, you can give the public IP address so that + /// Pion will use the public IP address in its candidate instead of the private + /// IP address. The second argument, candidate_type, is used to tell Pion which + /// type of candidate should use the given public IP address. + /// Two types of candidates are supported: + /// + /// ICECandidateTypeHost: + /// The public IP address will be used for the host candidate in the SDP. + /// ICECandidateTypeSrflx: + /// A server reflexive candidate with the given public IP address will be added + /// to the SDP. + /// + /// Please note that if you choose ICECandidateTypeHost, then the private IP address + /// won't be advertised with the peer. Also, this option cannot be used along with mDNS. + /// + /// If you choose ICECandidateTypeSrflx, it simply adds a server reflexive candidate + /// with the public IP. The host candidate is still available along with mDNS + /// capabilities unaffected. Also, you cannot give STUN server URL at the same time. + /// It will result in an error otherwise. + pub fn set_nat_1to1_ips(&mut self, ips: Vec, candidate_type: RTCIceCandidateType) { + self.candidates.nat_1to1_ips = ips; + self.candidates.nat_1to1_ip_candidate_type = candidate_type; + } + + /// set_answering_dtls_role sets the dtls_transport role that is selected when offering + /// The dtls_transport role controls if the WebRTC Client as a client or server. This + /// may be useful when interacting with non-compliant clients or debugging issues. + /// + /// DTLSRoleActive: + /// Act as dtls_transport Client, send the ClientHello and starts the handshake + /// DTLSRolePassive: + /// Act as dtls_transport Server, wait for ClientHello + pub fn set_answering_dtls_role(&mut self, role: DTLSRole) -> Result<()> { + if role != DTLSRole::Client && role != DTLSRole::Server { + return Err(Error::ErrSettingEngineSetAnsweringDTLSRole); + } + + self.answering_dtls_role = role; + Ok(()) + } + + /// set_vnet sets the VNet instance that is passed to ice + /// VNet is a virtual network layer, allowing users to simulate + /// different topologies, latency, loss and jitter. This can be useful for + /// learning WebRTC concepts or testing your application in a lab environment + pub fn set_vnet(&mut self, vnet: Option>) { + self.vnet = vnet; + } + + /// set_ice_multicast_dns_mode controls if ice queries and generates mDNS ICE Candidates + pub fn set_ice_multicast_dns_mode(&mut self, multicast_dns_mode: ice::mdns::MulticastDnsMode) { + self.candidates.multicast_dns_mode = multicast_dns_mode + } + + /// set_multicast_dns_host_name sets a static HostName to be used by ice instead of generating one on startup + /// This should only be used for a single PeerConnection. Having multiple PeerConnections with the same HostName will cause + /// undefined behavior + pub fn set_multicast_dns_host_name(&mut self, host_name: String) { + self.candidates.multicast_dns_host_name = host_name; + } + + /// set_ice_credentials sets a staic uFrag/uPwd to be used by ice + /// This is useful if you want to do signalless WebRTC session, or having a reproducible environment with static credentials + pub fn set_ice_credentials(&mut self, username_fragment: String, password: String) { + self.candidates.username_fragment = username_fragment; + self.candidates.password = password; + } + + /// disable_certificate_fingerprint_verification disables fingerprint verification after dtls_transport Handshake has finished + pub fn disable_certificate_fingerprint_verification(&mut self, is_disabled: bool) { + self.disable_certificate_fingerprint_verification = is_disabled; + } + + /// allow_insecure_verification_algorithm allows the usage of certain signature verification + /// algorithm that are known to be vulnerable or deprecated. + pub fn allow_insecure_verification_algorithm(&mut self, is_allowed: bool) { + self.allow_insecure_verification_algorithm = is_allowed; + } + /// set_dtls_replay_protection_window sets a replay attack protection window size of dtls_transport connection. + pub fn set_dtls_replay_protection_window(&mut self, n: usize) { + self.replay_protection.dtls = n; + } + + /// set_srtp_replay_protection_window sets a replay attack protection window size of srtp session. + pub fn set_srtp_replay_protection_window(&mut self, n: usize) { + self.disable_srtp_replay_protection = false; + self.replay_protection.srtp = n; + } + + /// set_srtcp_replay_protection_window sets a replay attack protection window size of srtcp session. + pub fn set_srtcp_replay_protection_window(&mut self, n: usize) { + self.disable_srtcp_replay_protection = false; + self.replay_protection.srtcp = n; + } + + /// disable_srtp_replay_protection disables srtp replay protection. + pub fn disable_srtp_replay_protection(&mut self, is_disabled: bool) { + self.disable_srtp_replay_protection = is_disabled; + } + + /// disable_srtcp_replay_protection disables srtcp replay protection. + pub fn disable_srtcp_replay_protection(&mut self, is_disabled: bool) { + self.disable_srtcp_replay_protection = is_disabled; + } + + /// set_sdp_media_level_fingerprints configures the logic for dtls_transport Fingerprint insertion + /// If true, fingerprints will be inserted in the sdp at the fingerprint + /// level, instead of the session level. This helps with compatibility with + /// some webrtc implementations. + pub fn set_sdp_media_level_fingerprints(&mut self, sdp_media_level_fingerprints: bool) { + self.sdp_media_level_fingerprints = sdp_media_level_fingerprints; + } + + // SetICETCPMux enables ICE-TCP when set to a non-nil value. Make sure that + // NetworkTypeTCP4 or NetworkTypeTCP6 is enabled as well. + //pub fn SetICETCPMux(&mut self, tcpMux ice.TCPMux) { + // self.iceTCPMux = tcpMux + //} + + // SetICEProxyDialer sets the proxy dialer interface based on golang.org/x/net/proxy. + //pub fn SetICEProxyDialer(&mut self, d proxy.Dialer) { + // self.iceProxyDialer = d + //} + + /// disable_media_engine_copy stops the MediaEngine from being copied. This allows a user to modify + /// the MediaEngine after the PeerConnection has been constructed. This is useful if you wish to + /// modify codecs after signaling. Make sure not to share MediaEngines between PeerConnections. + pub fn disable_media_engine_copy(&mut self, is_disabled: bool) { + self.disable_media_engine_copy = is_disabled; + } + + /// set_receive_mtu sets the size of read buffer that copies incoming packets. This is optional. + /// Leave this 0 for the default receive_mtu + pub fn set_receive_mtu(&mut self, receive_mtu: usize) { + self.receive_mtu = receive_mtu; + } + + /// Sets a callback used to generate mid for transceivers created by this side of the RTCPeerconnection. + /// By having separate "naming schemes" for mids generated by either side of a connection, it's + /// possible to reduce complexity when handling SDP offers/answers clashing. + /// + /// The `isize` argument is currently greatest seen _numeric_ mid. Since mids don't need to be numeric + /// this doesn't necessarily indicating anything. + /// + /// Note that the spec says: All MID values MUST be generated in a fashion that does not leak user + /// information, e.g., randomly or using a per-PeerConnection counter, and SHOULD be 3 bytes or less, + /// to allow them to efficiently fit into the RTP header extension + pub fn set_mid_generator(&mut self, f: impl Fn(isize) -> String + Send + Sync + 'static) { + self.mid_generator = Some(Arc::new(f)); + } +} diff --git a/rtc/src/api/setting_engine/setting_engine_test.rs b/rtc/src/api/setting_engine/setting_engine_test.rs new file mode 100644 index 0000000..cb5433f --- /dev/null +++ b/rtc/src/api/setting_engine/setting_engine_test.rs @@ -0,0 +1,271 @@ +use std::sync::atomic::Ordering; + +use super::*; +use crate::api::media_engine::MediaEngine; +use crate::api::APIBuilder; +use crate::peer_connection::peer_connection_test::*; +use crate::rtp_transceiver::rtp_codec::RTPCodecType; + +#[test] +fn test_set_connection_timeout() -> Result<()> { + let mut s = SettingEngine::default(); + + assert_eq!(s.timeout.ice_disconnected_timeout, None); + assert_eq!(s.timeout.ice_failed_timeout, None); + assert_eq!(s.timeout.ice_keepalive_interval, None); + + s.set_ice_timeouts( + Some(Duration::from_secs(1)), + Some(Duration::from_secs(2)), + Some(Duration::from_secs(3)), + ); + assert_eq!( + s.timeout.ice_disconnected_timeout, + Some(Duration::from_secs(1)) + ); + assert_eq!(s.timeout.ice_failed_timeout, Some(Duration::from_secs(2))); + assert_eq!( + s.timeout.ice_keepalive_interval, + Some(Duration::from_secs(3)) + ); + + Ok(()) +} + +#[test] +fn test_detach_data_channels() -> Result<()> { + let mut s = SettingEngine::default(); + + assert!( + !s.detach.data_channels, + "SettingEngine defaults aren't as expected." + ); + + s.detach_data_channels(); + + assert!( + s.detach.data_channels, + "Failed to enable detached data channels." + ); + + Ok(()) +} + +#[test] +fn test_set_nat_1to1_ips() -> Result<()> { + let mut s = SettingEngine::default(); + + assert!( + s.candidates.nat_1to1_ips.is_empty(), + "Invalid default value" + ); + assert!( + s.candidates.nat_1to1_ip_candidate_type == RTCIceCandidateType::Unspecified, + "Invalid default value" + ); + + let ips = vec!["1.2.3.4".to_owned()]; + let typ = RTCIceCandidateType::Host; + s.set_nat_1to1_ips(ips, typ); + assert!( + !(s.candidates.nat_1to1_ips.len() != 1 || s.candidates.nat_1to1_ips[0] != "1.2.3.4"), + "Failed to set NAT1To1IPs" + ); + assert!( + s.candidates.nat_1to1_ip_candidate_type == typ, + "Failed to set NAT1To1IPCandidateType" + ); + + Ok(()) +} + +#[test] +fn test_set_answering_dtls_role() -> Result<()> { + let mut s = SettingEngine::default(); + assert!( + s.set_answering_dtls_role(DTLSRole::Auto).is_err(), + "SetAnsweringDTLSRole can only be called with DTLSRoleClient or DTLSRoleServer" + ); + assert!( + s.set_answering_dtls_role(DTLSRole::Unspecified).is_err(), + "SetAnsweringDTLSRole can only be called with DTLSRoleClient or DTLSRoleServer" + ); + + Ok(()) +} + +#[test] +fn test_set_replay_protection() -> Result<()> { + let mut s = SettingEngine::default(); + + assert!( + !(s.replay_protection.dtls != 0 + || s.replay_protection.srtp != 0 + || s.replay_protection.srtcp != 0), + "SettingEngine defaults aren't as expected." + ); + + s.set_dtls_replay_protection_window(128); + s.set_srtp_replay_protection_window(64); + s.set_srtcp_replay_protection_window(32); + + assert!( + !(s.replay_protection.dtls == 0 || s.replay_protection.dtls != 128), + "Failed to set DTLS replay protection window" + ); + assert!( + !(s.replay_protection.srtp == 0 || s.replay_protection.srtp != 64), + "Failed to set SRTP replay protection window" + ); + assert!( + !(s.replay_protection.srtcp == 0 || s.replay_protection.srtcp != 32), + "Failed to set SRTCP replay protection window" + ); + + Ok(()) +} + +/*TODO:#[test] fn test_setting_engine_set_ice_tcp_mux() ->Result<()> { + + listener, err := net.ListenTCP("tcp", &net.TCPAddr{}) + if err != nil { + panic(err) + } + + defer func() { + _ = listener.Close() + }() + + tcpMux := NewICETCPMux(nil, listener, 8) + + defer func() { + _ = tcpMux.Close() + }() + + let mut s = SettingEngine::default(); + settingEngine.SetICETCPMux(tcpMux) + + assert.Equal(t, tcpMux, settingEngine.iceTCPMux) + + Ok(()) +} +*/ + +#[tokio::test] +async fn test_setting_engine_set_disable_media_engine_copy() -> Result<()> { + //"Copy" + { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut offerer, mut answerer) = new_pair(&api).await?; + + offerer + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + signal_pair(&mut offerer, &mut answerer).await?; + + // Assert that the MediaEngine the user created isn't modified + assert!(!api.media_engine.negotiated_video.load(Ordering::SeqCst)); + { + let negotiated_video_codecs = api.media_engine.negotiated_video_codecs.lock(); + assert!(negotiated_video_codecs.is_empty()); + } + + // Assert that the internal MediaEngine is modified + assert!(offerer + .internal + .media_engine + .negotiated_video + .load(Ordering::SeqCst)); + { + let negotiated_video_codecs = + offerer.internal.media_engine.negotiated_video_codecs.lock(); + assert!(!negotiated_video_codecs.is_empty()); + } + + close_pair_now(&offerer, &answerer).await; + + let (new_offerer, new_answerer) = new_pair(&api).await?; + + // Assert that the first internal MediaEngine hasn't been cleared + assert!(offerer + .internal + .media_engine + .negotiated_video + .load(Ordering::SeqCst)); + { + let negotiated_video_codecs = + offerer.internal.media_engine.negotiated_video_codecs.lock(); + assert!(!negotiated_video_codecs.is_empty()); + } + + // Assert that the new internal MediaEngine isn't modified + assert!(!new_offerer + .internal + .media_engine + .negotiated_video + .load(Ordering::SeqCst)); + { + let negotiated_video_codecs = new_offerer + .internal + .media_engine + .negotiated_video_codecs + .lock(); + assert!(negotiated_video_codecs.is_empty()); + } + + close_pair_now(&new_offerer, &new_answerer).await; + } + + //"No Copy" + { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + + let mut s = SettingEngine::default(); + s.disable_media_engine_copy(true); + + let api = APIBuilder::new() + .with_media_engine(m) + .with_setting_engine(s) + .build(); + + let (mut offerer, mut answerer) = new_pair(&api).await?; + + offerer + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + signal_pair(&mut offerer, &mut answerer).await?; + + // Assert that the user MediaEngine was modified, so no copy happened + assert!(api.media_engine.negotiated_video.load(Ordering::SeqCst)); + { + let negotiated_video_codecs = api.media_engine.negotiated_video_codecs.lock(); + assert!(!negotiated_video_codecs.is_empty()); + } + + close_pair_now(&offerer, &answerer).await; + + let (offerer, answerer) = new_pair(&api).await?; + + // Assert that the new internal MediaEngine was modified, so no copy happened + assert!(offerer + .internal + .media_engine + .negotiated_video + .load(Ordering::SeqCst)); + { + let negotiated_video_codecs = + offerer.internal.media_engine.negotiated_video_codecs.lock(); + assert!(!negotiated_video_codecs.is_empty()); + } + + close_pair_now(&offerer, &answerer).await; + } + + Ok(()) +} diff --git a/rtc/src/data_channel/data_channel_init.rs b/rtc/src/data_channel/data_channel_init.rs new file mode 100644 index 0000000..5adbdb7 --- /dev/null +++ b/rtc/src/data_channel/data_channel_init.rs @@ -0,0 +1,29 @@ +/// DataChannelConfig can be used to configure properties of the underlying +/// channel such as data reliability. +#[derive(Default, Debug, Clone)] +pub struct RTCDataChannelInit { + /// ordered indicates if data is allowed to be delivered out of order. The + /// default value of true, guarantees that data will be delivered in order. + pub ordered: Option, + + /// max_packet_life_time limits the time (in milliseconds) during which the + /// channel will transmit or retransmit data if not acknowledged. This value + /// may be clamped if it exceeds the maximum value supported. + pub max_packet_life_time: Option, + + /// max_retransmits limits the number of times a channel will retransmit data + /// if not successfully delivered. This value may be clamped if it exceeds + /// the maximum value supported. + pub max_retransmits: Option, + + /// protocol describes the subprotocol name used for this channel. + pub protocol: Option, + + /// negotiated describes if the data channel is created by the local peer or + /// the remote peer. The default value of None tells the user agent to + /// announce the channel in-band and instruct the other peer to dispatch a + /// corresponding DataChannel. If set to Some(id), it is up to the application + /// to negotiate the channel and create an DataChannel with the same id + /// at the other peer. + pub negotiated: Option, +} diff --git a/rtc/src/data_channel/data_channel_message.rs b/rtc/src/data_channel/data_channel_message.rs new file mode 100644 index 0000000..a781ec4 --- /dev/null +++ b/rtc/src/data_channel/data_channel_message.rs @@ -0,0 +1,11 @@ +use bytes::Bytes; + +/// DataChannelMessage represents a message received from the +/// data channel. IsString will be set to true if the incoming +/// message is of the string type. Otherwise the message is of +/// a binary type. +#[derive(Default, Debug, Clone)] +pub struct DataChannelMessage { + pub is_string: bool, + pub data: Bytes, +} diff --git a/rtc/src/data_channel/data_channel_parameters.rs b/rtc/src/data_channel/data_channel_parameters.rs new file mode 100644 index 0000000..88f116e --- /dev/null +++ b/rtc/src/data_channel/data_channel_parameters.rs @@ -0,0 +1,12 @@ +use serde::{Deserialize, Serialize}; + +/// DataChannelParameters describes the configuration of the DataChannel. +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct DataChannelParameters { + pub label: String, + pub protocol: String, + pub ordered: bool, + pub max_packet_life_time: u16, + pub max_retransmits: u16, + pub negotiated: Option, +} diff --git a/rtc/src/data_channel/data_channel_state.rs b/rtc/src/data_channel/data_channel_state.rs new file mode 100644 index 0000000..1f52ffe --- /dev/null +++ b/rtc/src/data_channel/data_channel_state.rs @@ -0,0 +1,113 @@ +use std::fmt; + +use serde::Serialize; + +/// DataChannelState indicates the state of a data channel. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize)] +pub enum RTCDataChannelState { + #[serde(rename = "unspecified")] + #[default] + Unspecified = 0, + + /// DataChannelStateConnecting indicates that the data channel is being + /// established. This is the initial state of DataChannel, whether created + /// with create_data_channel, or dispatched as a part of an DataChannelEvent. + #[serde(rename = "connecting")] + Connecting, + + /// DataChannelStateOpen indicates that the underlying data transport is + /// established and communication is possible. + #[serde(rename = "open")] + Open, + + /// DataChannelStateClosing indicates that the procedure to close down the + /// underlying data transport has started. + #[serde(rename = "closing")] + Closing, + + /// DataChannelStateClosed indicates that the underlying data transport + /// has been closed or could not be established. + #[serde(rename = "closed")] + Closed, +} + +const DATA_CHANNEL_STATE_CONNECTING_STR: &str = "connecting"; +const DATA_CHANNEL_STATE_OPEN_STR: &str = "open"; +const DATA_CHANNEL_STATE_CLOSING_STR: &str = "closing"; +const DATA_CHANNEL_STATE_CLOSED_STR: &str = "closed"; + +impl From for RTCDataChannelState { + fn from(v: u8) -> Self { + match v { + 1 => RTCDataChannelState::Connecting, + 2 => RTCDataChannelState::Open, + 3 => RTCDataChannelState::Closing, + 4 => RTCDataChannelState::Closed, + _ => RTCDataChannelState::Unspecified, + } + } +} + +impl From<&str> for RTCDataChannelState { + fn from(raw: &str) -> Self { + match raw { + DATA_CHANNEL_STATE_CONNECTING_STR => RTCDataChannelState::Connecting, + DATA_CHANNEL_STATE_OPEN_STR => RTCDataChannelState::Open, + DATA_CHANNEL_STATE_CLOSING_STR => RTCDataChannelState::Closing, + DATA_CHANNEL_STATE_CLOSED_STR => RTCDataChannelState::Closed, + _ => RTCDataChannelState::Unspecified, + } + } +} + +impl fmt::Display for RTCDataChannelState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCDataChannelState::Connecting => DATA_CHANNEL_STATE_CONNECTING_STR, + RTCDataChannelState::Open => DATA_CHANNEL_STATE_OPEN_STR, + RTCDataChannelState::Closing => DATA_CHANNEL_STATE_CLOSING_STR, + RTCDataChannelState::Closed => DATA_CHANNEL_STATE_CLOSED_STR, + RTCDataChannelState::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_data_channel_state() { + let tests = vec![ + (crate::UNSPECIFIED_STR, RTCDataChannelState::Unspecified), + ("connecting", RTCDataChannelState::Connecting), + ("open", RTCDataChannelState::Open), + ("closing", RTCDataChannelState::Closing), + ("closed", RTCDataChannelState::Closed), + ]; + + for (state_string, expected_state) in tests { + assert_eq!( + RTCDataChannelState::from(state_string), + expected_state, + "testCase: {expected_state}", + ); + } + } + + #[test] + fn test_data_channel_state_string() { + let tests = vec![ + (RTCDataChannelState::Unspecified, crate::UNSPECIFIED_STR), + (RTCDataChannelState::Connecting, "connecting"), + (RTCDataChannelState::Open, "open"), + (RTCDataChannelState::Closing, "closing"), + (RTCDataChannelState::Closed, "closed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string) + } + } +} diff --git a/rtc/src/data_channel/data_channel_test.rs b/rtc/src/data_channel/data_channel_test.rs new file mode 100644 index 0000000..65e8eb9 --- /dev/null +++ b/rtc/src/data_channel/data_channel_test.rs @@ -0,0 +1,1504 @@ +// Silence warning on `for i in 0..vec.len() { … }`: +#![allow(clippy::needless_range_loop)] + +use regex::Regex; +use tokio::sync::mpsc; +use tokio::time::Duration; +use waitgroup::WaitGroup; + +use super::*; +use crate::api::media_engine::MediaEngine; +use crate::api::{APIBuilder, API}; +use crate::data_channel::data_channel_init::RTCDataChannelInit; +//use log::LevelFilter; +//use std::io::Write; +use crate::dtls_transport::dtls_parameters::DTLSParameters; +use crate::dtls_transport::RTCDtlsTransport; +use crate::error::flatten_errs; +use crate::ice_transport::ice_candidate::RTCIceCandidate; +use crate::ice_transport::ice_connection_state::RTCIceConnectionState; +use crate::ice_transport::ice_gatherer::{RTCIceGatherOptions, RTCIceGatherer}; +use crate::ice_transport::ice_parameters::RTCIceParameters; +use crate::ice_transport::ice_role::RTCIceRole; +use crate::ice_transport::RTCIceTransport; +use crate::peer_connection::configuration::RTCConfiguration; +use crate::peer_connection::peer_connection_test::*; +use crate::peer_connection::RTCPeerConnection; +use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; + +// EXPECTED_LABEL represents the label of the data channel we are trying to test. +// Some other channels may have been created during initialization (in the Wasm +// bindings this is a requirement). +const EXPECTED_LABEL: &str = "data"; + +async fn set_up_data_channel_parameters_test( + api: &API, + options: Option, +) -> Result<( + RTCPeerConnection, + RTCPeerConnection, + Arc, + mpsc::Sender<()>, + mpsc::Receiver<()>, +)> { + let (offer_pc, answer_pc) = new_pair(api).await?; + let (done_tx, done_rx) = mpsc::channel(1); + + let dc = offer_pc + .create_data_channel(EXPECTED_LABEL, options) + .await?; + Ok((offer_pc, answer_pc, dc, done_tx, done_rx)) +} + +async fn close_reliability_param_test( + pc1: &mut RTCPeerConnection, + pc2: &mut RTCPeerConnection, + done_rx: mpsc::Receiver<()>, +) -> Result<()> { + signal_pair(pc1, pc2).await?; + + close_pair(pc1, pc2, done_rx).await; + + Ok(()) +} + +/* +TODO: #[tokio::test] async fnBenchmarkDataChannelSend2(b *testing.B) { benchmarkDataChannelSend(b, 2) } +#[tokio::test] async fnBenchmarkDataChannelSend4(b *testing.B) { benchmarkDataChannelSend(b, 4) } +#[tokio::test] async fnBenchmarkDataChannelSend8(b *testing.B) { benchmarkDataChannelSend(b, 8) } +#[tokio::test] async fnBenchmarkDataChannelSend16(b *testing.B) { benchmarkDataChannelSend(b, 16) } +#[tokio::test] async fnBenchmarkDataChannelSend32(b *testing.B) { benchmarkDataChannelSend(b, 32) } + +// See https://github.com/pion/webrtc/issues/1516 +#[tokio::test] async fnbenchmarkDataChannelSend(b *testing.B, numChannels int) { + offerPC, answerPC, err := newPair() + if err != nil { + b.Fatalf("Failed to create a PC pair for testing") + } + + open := make(map[string]chan bool) + answerPC.OnDataChannel(func(d *DataChannel) { + if _, ok := open[d.Label()]; !ok { + // Ignore anything unknown channel label. + return + } + d.OnOpen(func() { open[d.Label()] <- true }) + }) + + var wg sync.WaitGroup + for i := 0; i < numChannels; i++ { + label := fmt.Sprintf("dc-%d", i) + open[label] = make(chan bool) + wg.Add(1) + dc, err := offerPC.CreateDataChannel(label, nil) + assert.NoError(b, err) + + dc.OnOpen(func() { + <-open[label] + for n := 0; n < b.N/numChannels; n++ { + if err := dc.SendText("Ping"); err != nil { + b.Fatalf("Unexpected error sending data (label=%q): %v", label, err) + } + } + wg.Done() + }) + } + + assert.NoError(b, signalPair(offerPC, answerPC)) + wg.Wait() + close_pair_now(b, offerPC, answerPC) +} +*/ + +#[tokio::test] +async fn test_data_channel_open() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + //"handler should be called once" + { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; + + let (done_tx, done_rx) = mpsc::channel(1); + let (open_calls_tx, mut open_calls_rx) = mpsc::channel(2); + + let open_calls_tx = Arc::new(open_calls_tx); + let done_tx = Arc::new(done_tx); + answer_pc.on_data_channel(Box::new(move |d: Arc| { + if d.label() == EXPECTED_LABEL { + let open_calls_tx2 = Arc::clone(&open_calls_tx); + let done_tx2 = Arc::clone(&done_tx); + Box::pin(async move { + d.on_open(Box::new(move || { + Box::pin(async move { + let _ = open_calls_tx2.send(()).await; + }) + })); + d.on_message(Box::new(move |_: DataChannelMessage| { + let done_tx3 = Arc::clone(&done_tx2); + tokio::spawn(async move { + // Wait a little bit to ensure all messages are processed. + tokio::time::sleep(Duration::from_millis(100)).await; + let _ = done_tx3.send(()).await; + }); + Box::pin(async {}) + })); + }) + } else { + Box::pin(async {}) + } + })); + + let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; + + let dc2 = Arc::clone(&dc); + dc.on_open(Box::new(move || { + Box::pin(async move { + let result = dc2.send_text("Ping".to_owned()).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + }) + })); + + signal_pair(&mut offer_pc, &mut answer_pc).await?; + + close_pair(&offer_pc, &answer_pc, done_rx).await; + + let _ = open_calls_rx.recv().await; + } + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_send_before_signaling() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + //"before signaling" + + let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; + + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + Box::pin(async move { + let d2 = Arc::clone(&d); + d.on_message(Box::new(move |_: DataChannelMessage| { + let d3 = Arc::clone(&d2); + Box::pin(async move { + let result = d3.send(&Bytes::from(b"Pong".to_vec())).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + }) + })); + assert!(d.ordered(), "Ordered should be set to true"); + }) + })); + + let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; + + assert!(dc.ordered(), "Ordered should be set to true"); + + let dc2 = Arc::clone(&dc); + dc.on_open(Box::new(move || { + let dc3 = Arc::clone(&dc2); + Box::pin(async move { + let result = dc3.send_text("Ping".to_owned()).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + }) + })); + + let (done_tx, done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + dc.on_message(Box::new(move |_: DataChannelMessage| { + let done_tx2 = Arc::clone(&done_tx); + Box::pin(async move { + let mut done = done_tx2.lock().await; + done.take(); + }) + })); + + signal_pair(&mut offer_pc, &mut answer_pc).await?; + + close_pair(&offer_pc, &answer_pc, done_rx).await; + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_send_after_connected() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; + + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + Box::pin(async move { + let d2 = Arc::clone(&d); + d.on_message(Box::new(move |_: DataChannelMessage| { + let d3 = Arc::clone(&d2); + + Box::pin(async move { + let result = d3.send(&Bytes::from(b"Pong".to_vec())).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + }) + })); + assert!(d.ordered(), "Ordered should be set to true"); + }) + })); + + let dc = offer_pc + .create_data_channel(EXPECTED_LABEL, None) + .await + .expect("Failed to create a PC pair for testing"); + + let (done_tx, done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + + //once := &sync.Once{} + offer_pc.on_ice_connection_state_change(Box::new(move |state: RTCIceConnectionState| { + let done_tx1 = Arc::clone(&done_tx); + let dc1 = Arc::clone(&dc); + Box::pin(async move { + if state == RTCIceConnectionState::Connected + || state == RTCIceConnectionState::Completed + { + // wasm fires completed state multiple times + /*once.Do(func()*/ + { + assert!(dc1.ordered(), "Ordered should be set to true"); + + dc1.on_message(Box::new(move |_: DataChannelMessage| { + let done_tx2 = Arc::clone(&done_tx1); + Box::pin(async move { + let mut done = done_tx2.lock().await; + done.take(); + }) + })); + + if dc1.send_text("Ping".to_owned()).await.is_err() { + // wasm binding doesn't fire OnOpen (we probably already missed it) + let dc2 = Arc::clone(&dc1); + dc1.on_open(Box::new(move || { + let dc3 = Arc::clone(&dc2); + Box::pin(async move { + let result = dc3.send_text("Ping".to_owned()).await; + assert!(result.is_ok(), "Failed to send string on data channel"); + }) + })); + } + } + } + }) + })); + + signal_pair(&mut offer_pc, &mut answer_pc).await?; + + close_pair(&offer_pc, &answer_pc, done_rx).await; + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_close() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + // "Close after PeerConnection Closed" + { + let (offer_pc, answer_pc) = new_pair(&api).await?; + + let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; + + close_pair_now(&offer_pc, &answer_pc).await; + dc.close().await?; + } + + // "Close before connected" + { + let (offer_pc, answer_pc) = new_pair(&api).await?; + + let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; + + dc.close().await?; + close_pair_now(&offer_pc, &answer_pc).await; + } + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_parameters_max_packet_life_time_exchange() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let ordered = true; + let max_packet_life_time = 3u16; + let options = RTCDataChannelInit { + ordered: Some(ordered), + max_packet_life_time: Some(max_packet_life_time), + ..Default::default() + }; + + let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = + set_up_data_channel_parameters_test(&api, Some(options)).await?; + + // Check if parameters are correctly set + assert_eq!( + dc.ordered(), + ordered, + "Ordered should be same value as set in DataChannelInit" + ); + assert_eq!( + dc.max_packet_lifetime(), + max_packet_life_time, + "should match" + ); + + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + answer_pc.on_data_channel(Box::new(move |d: Arc| { + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + // Check if parameters are correctly set + assert_eq!( + d.ordered(), + ordered, + "Ordered should be same value as set in DataChannelInit" + ); + assert_eq!( + d.max_packet_lifetime(), + max_packet_life_time, + "should match" + ); + let done_tx2 = Arc::clone(&done_tx); + Box::pin(async move { + let mut done = done_tx2.lock().await; + done.take(); + }) + })); + + close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_parameters_max_retransmits_exchange() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let ordered = false; + let max_retransmits = 3000u16; + let options = RTCDataChannelInit { + ordered: Some(ordered), + max_retransmits: Some(max_retransmits), + ..Default::default() + }; + + let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = + set_up_data_channel_parameters_test(&api, Some(options)).await?; + + // Check if parameters are correctly set + assert!(!dc.ordered(), "Ordered should be set to false"); + assert_eq!(dc.max_retransmits(), max_retransmits, "should match"); + + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + + // Check if parameters are correctly set + assert!(!d.ordered(), "Ordered should be set to false"); + assert_eq!(max_retransmits, d.max_retransmits(), "should match"); + let done_tx2 = Arc::clone(&done_tx); + Box::pin(async move { + let mut done = done_tx2.lock().await; + done.take(); + }) + })); + + close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_parameters_protocol_exchange() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let protocol = "json".to_owned(); + let options = RTCDataChannelInit { + protocol: Some(protocol.clone()), + ..Default::default() + }; + + let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = + set_up_data_channel_parameters_test(&api, Some(options)).await?; + + // Check if parameters are correctly set + assert_eq!( + protocol, + dc.protocol(), + "Protocol should match DataChannelConfig" + ); + + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + // Check if parameters are correctly set + assert_eq!( + protocol, + d.protocol(), + "Protocol should match what channel creator declared" + ); + + let done_tx2 = Arc::clone(&done_tx); + Box::pin(async move { + let mut done = done_tx2.lock().await; + done.take(); + }) + })); + + close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_parameters_negotiated_exchange() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + const EXPECTED_MESSAGE: &str = "Hello World"; + + let id = 500u16; + let options = RTCDataChannelInit { + negotiated: Some(id), + ..Default::default() + }; + + let (mut offer_pc, mut answer_pc, offer_datachannel, done_tx, done_rx) = + set_up_data_channel_parameters_test(&api, Some(options.clone())).await?; + + let answer_datachannel = answer_pc + .create_data_channel(EXPECTED_LABEL, Some(options)) + .await?; + + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Ignore our default channel, exists to force ICE candidates. See signalPair for more info + if d.label() == "initial_data_channel" { + return Box::pin(async {}); + } + panic!("OnDataChannel must not be fired when negotiated == true"); + })); + + offer_pc.on_data_channel(Box::new(move |_d: Arc| { + panic!("OnDataChannel must not be fired when negotiated == true"); + })); + + let seen_answer_message = Arc::new(AtomicBool::new(false)); + let seen_offer_message = Arc::new(AtomicBool::new(false)); + + let seen_answer_message2 = Arc::clone(&seen_answer_message); + answer_datachannel.on_message(Box::new(move |msg: DataChannelMessage| { + if msg.is_string && msg.data == EXPECTED_MESSAGE { + seen_answer_message2.store(true, Ordering::SeqCst); + } + + Box::pin(async {}) + })); + + let seen_offer_message2 = Arc::clone(&seen_offer_message); + offer_datachannel.on_message(Box::new(move |msg: DataChannelMessage| { + if msg.is_string && msg.data == EXPECTED_MESSAGE { + seen_offer_message2.store(true, Ordering::SeqCst); + } + Box::pin(async {}) + })); + + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + tokio::spawn(async move { + loop { + if seen_answer_message.load(Ordering::SeqCst) + && seen_offer_message.load(Ordering::SeqCst) + { + break; + } + + if offer_datachannel.ready_state() == RTCDataChannelState::Open { + offer_datachannel + .send_text(EXPECTED_MESSAGE.to_owned()) + .await?; + } + if answer_datachannel.ready_state() == RTCDataChannelState::Open { + answer_datachannel + .send_text(EXPECTED_MESSAGE.to_owned()) + .await?; + } + + tokio::time::sleep(Duration::from_millis(50)).await; + } + + let mut done = done_tx.lock().await; + done.take(); + + Result::<()>::Ok(()) + }); + + close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_event_handlers() -> Result<()> { + let api = APIBuilder::new().build(); + + let dc = RTCDataChannel { + setting_engine: Arc::clone(&api.setting_engine), + ..Default::default() + }; + + let (on_open_called_tx, mut on_open_called_rx) = mpsc::channel::<()>(1); + let (on_message_called_tx, mut on_message_called_rx) = mpsc::channel::<()>(1); + + // Verify that the noop case works + dc.do_open(); + + let on_open_called_tx = Arc::new(Mutex::new(Some(on_open_called_tx))); + dc.on_open(Box::new(move || { + let on_open_called_tx2 = Arc::clone(&on_open_called_tx); + Box::pin(async move { + let mut done = on_open_called_tx2.lock().await; + done.take(); + }) + })); + + let on_message_called_tx = Arc::new(Mutex::new(Some(on_message_called_tx))); + dc.on_message(Box::new(move |_: DataChannelMessage| { + let on_message_called_tx2 = Arc::clone(&on_message_called_tx); + Box::pin(async move { + let mut done = on_message_called_tx2.lock().await; + done.take(); + }) + })); + + // Verify that the set handlers are called + dc.do_open(); + dc.do_message(DataChannelMessage { + is_string: false, + data: Bytes::from_static(b"o hai"), + }) + .await; + + // Wait for all handlers to be called + let _ = on_open_called_rx.recv().await; + let _ = on_message_called_rx.recv().await; + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_messages_are_ordered() -> Result<()> { + let api = APIBuilder::new().build(); + + let dc = RTCDataChannel { + setting_engine: Arc::clone(&api.setting_engine), + ..Default::default() + }; + + let m = 16u64; + let (out_tx, mut out_rx) = mpsc::channel::(m as usize); + + let out_tx = Arc::new(out_tx); + + let out_tx1 = Arc::clone(&out_tx); + dc.on_message(Box::new(move |msg: DataChannelMessage| { + let out_tx2 = Arc::clone(&out_tx1); + + Box::pin(async move { + // randomly sleep + let r = rand::random::() % m; + tokio::time::sleep(Duration::from_millis(r)).await; + + let mut buf = [0u8; 8]; + for i in 0..8 { + buf[i] = msg.data[i]; + } + let s = u64::from_be_bytes(buf); + + let _ = out_tx2.send(s).await; + }) + })); + + tokio::spawn(async move { + for j in 1..=m { + let buf = j.to_be_bytes().to_vec(); + + dc.do_message(DataChannelMessage { + is_string: false, + data: Bytes::from(buf), + }) + .await; + // Change the registered handler a couple of times to make sure + // that everything continues to work, we don't lose messages, etc. + if j % 2 == 0 { + let out_tx1 = Arc::clone(&out_tx); + dc.on_message(Box::new(move |msg: DataChannelMessage| { + let out_tx2 = Arc::clone(&out_tx1); + + Box::pin(async move { + // randomly sleep + let r = rand::random::() % m; + tokio::time::sleep(Duration::from_millis(r)).await; + + let mut buf = [0u8; 8]; + for i in 0..8 { + buf[i] = msg.data[i]; + } + let s = u64::from_be_bytes(buf); + + let _ = out_tx2.send(s).await; + }) + })); + } + } + }); + + let mut values = vec![]; + for _ in 1..=m { + if let Some(v) = out_rx.recv().await { + values.push(v); + } else { + break; + } + } + + let mut expected = vec![0u64; m as usize]; + for i in 1..=m as usize { + expected[i - 1] = i as u64; + } + assert_eq!(values, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_parameters_go() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + //"MaxPacketLifeTime exchange" + { + let ordered = true; + let max_packet_life_time = 3u16; + let options = RTCDataChannelInit { + ordered: Some(ordered), + max_packet_life_time: Some(max_packet_life_time), + ..Default::default() + }; + + let (mut offer_pc, mut answer_pc, dc, done_tx, done_rx) = + set_up_data_channel_parameters_test(&api, Some(options)).await?; + + // Check if parameters are correctly set + assert!(dc.ordered(), "Ordered should be set to true"); + assert_eq!( + max_packet_life_time, + dc.max_packet_lifetime(), + "should match" + ); + + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + + // Check if parameters are correctly set + assert!(d.ordered, "Ordered should be set to true"); + assert_eq!( + max_packet_life_time, + d.max_packet_lifetime(), + "should match" + ); + + let done_tx2 = Arc::clone(&done_tx); + Box::pin(async move { + let mut done = done_tx2.lock().await; + done.take(); + }) + })); + + close_reliability_param_test(&mut offer_pc, &mut answer_pc, done_rx).await?; + } + + //"All other property methods" + { + let id = 123u16; + let dc = RTCDataChannel { + id: AtomicU16::new(id), + label: "mylabel".to_owned(), + protocol: "myprotocol".to_owned(), + negotiated: true, + ..Default::default() + }; + + assert_eq!(dc.id.load(Ordering::SeqCst), dc.id(), "should match"); + assert_eq!(dc.label, dc.label(), "should match"); + assert_eq!(dc.protocol, dc.protocol(), "should match"); + assert_eq!(dc.negotiated, dc.negotiated(), "should match"); + assert_eq!(0, dc.buffered_amount().await, "should match"); + dc.set_buffered_amount_low_threshold(1500).await; + assert_eq!( + 1500, + dc.buffered_amount_low_threshold().await, + "should match" + ); + } + + Ok(()) +} + +//use log::LevelFilter; +//use std::io::Write; + +#[tokio::test] +async fn test_data_channel_buffered_amount_set_before_open() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let n_cbs = Arc::new(AtomicU16::new(0)); + let buf = Bytes::from_static(&[0u8; 1000]); + + let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; + + let (done_tx, done_rx) = mpsc::channel::<()>(1); + + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + let n_packets_received = Arc::new(AtomicU16::new(0)); + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + + let done_tx2 = Arc::clone(&done_tx); + let n_packets_received2 = Arc::clone(&n_packets_received); + Box::pin(async move { + d.on_message(Box::new(move |_msg: DataChannelMessage| { + let n = n_packets_received2.fetch_add(1, Ordering::SeqCst); + if n == 9 { + let done_tx3 = Arc::clone(&done_tx2); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut done = done_tx3.lock().await; + done.take(); + }); + } + + Box::pin(async {}) + })); + + assert!(d.ordered(), "Ordered should be set to true"); + }) + })); + + let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; + + assert!(dc.ordered(), "Ordered should be set to true"); + + let dc2 = Arc::clone(&dc); + dc.on_open(Box::new(move || { + let dc3 = Arc::clone(&dc2); + Box::pin(async move { + for _ in 0..10 { + assert!( + dc3.send(&buf).await.is_ok(), + "Failed to send string on data channel" + ); + assert_eq!( + 1500, + dc3.buffered_amount_low_threshold().await, + "value mismatch" + ); + } + }) + })); + + dc.on_message(Box::new(|_msg: DataChannelMessage| Box::pin(async {}))); + + // The value is temporarily stored in the dc object + // until the dc gets opened + dc.set_buffered_amount_low_threshold(1500).await; + // The callback function is temporarily stored in the dc object + // until the dc gets opened + let n_cbs2 = Arc::clone(&n_cbs); + dc.on_buffered_amount_low(Box::new(move || { + n_cbs2.fetch_add(1, Ordering::SeqCst); + Box::pin(async {}) + })) + .await; + + signal_pair(&mut offer_pc, &mut answer_pc).await?; + + close_pair(&offer_pc, &answer_pc, done_rx).await; + + assert!( + n_cbs.load(Ordering::SeqCst) > 0, + "callback should be made at least once" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_data_channel_buffered_amount_set_after_open() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let n_cbs = Arc::new(AtomicU16::new(0)); + let buf = Bytes::from_static(&[0u8; 1000]); + + let (mut offer_pc, mut answer_pc) = new_pair(&api).await?; + + let (done_tx, done_rx) = mpsc::channel::<()>(1); + + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + let n_packets_received = Arc::new(AtomicU16::new(0)); + answer_pc.on_data_channel(Box::new(move |d: Arc| { + // Make sure this is the data channel we were looking for. (Not the one + // created in signalPair). + if d.label() != EXPECTED_LABEL { + return Box::pin(async {}); + } + + let done_tx2 = Arc::clone(&done_tx); + let n_packets_received2 = Arc::clone(&n_packets_received); + Box::pin(async move { + d.on_message(Box::new(move |_msg: DataChannelMessage| { + let n = n_packets_received2.fetch_add(1, Ordering::SeqCst); + if n == 9 { + let done_tx3 = Arc::clone(&done_tx2); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + let mut done = done_tx3.lock().await; + done.take(); + }); + } + + Box::pin(async {}) + })); + + assert!(d.ordered(), "Ordered should be set to true"); + }) + })); + + let dc = offer_pc.create_data_channel(EXPECTED_LABEL, None).await?; + + assert!(dc.ordered(), "Ordered should be set to true"); + + let dc2 = Arc::clone(&dc); + let n_cbs2 = Arc::clone(&n_cbs); + dc.on_open(Box::new(move || { + let dc3 = Arc::clone(&dc2); + Box::pin(async move { + // The value should directly be passed to sctp + dc3.set_buffered_amount_low_threshold(1500).await; + // The callback function should directly be passed to sctp + dc3.on_buffered_amount_low(Box::new(move || { + n_cbs2.fetch_add(1, Ordering::SeqCst); + Box::pin(async {}) + })) + .await; + + for _ in 0..10 { + assert!( + dc3.send(&buf).await.is_ok(), + "Failed to send string on data channel" + ); + assert_eq!( + 1500, + dc3.buffered_amount_low_threshold().await, + "value mismatch" + ); + } + }) + })); + + dc.on_message(Box::new(|_msg: DataChannelMessage| Box::pin(async {}))); + + signal_pair(&mut offer_pc, &mut answer_pc).await?; + + close_pair(&offer_pc, &answer_pc, done_rx).await; + + assert!( + n_cbs.load(Ordering::SeqCst) > 0, + "callback should be made at least once" + ); + + Ok(()) +} + +#[tokio::test] +async fn test_eof_detach() -> Result<()> { + let label: &str = "test-channel"; + let test_data: &'static str = "this is some test data"; + + // Use Detach data channels mode + let mut s = SettingEngine::default(); + s.detach_data_channels(); + let api = APIBuilder::new().with_setting_engine(s).build(); + + // Set up two peer connections. + let mut pca = api.new_peer_connection(RTCConfiguration::default()).await?; + let mut pcb = api.new_peer_connection(RTCConfiguration::default()).await?; + + let wg = WaitGroup::new(); + + let (dc_chan_tx, mut dc_chan_rx) = mpsc::channel(1); + let dc_chan_tx = Arc::new(dc_chan_tx); + pcb.on_data_channel(Box::new(move |dc: Arc| { + if dc.label() != label { + return Box::pin(async {}); + } + log::debug!("OnDataChannel was called"); + let dc_chan_tx2 = Arc::clone(&dc_chan_tx); + let dc2 = Arc::clone(&dc); + Box::pin(async move { + let dc3 = Arc::clone(&dc2); + dc2.on_open(Box::new(move || { + let dc_chan_tx3 = Arc::clone(&dc_chan_tx2); + let dc4 = Arc::clone(&dc3); + Box::pin(async move { + let detached = match dc4.detach().await { + Ok(detached) => detached, + Err(err) => { + log::debug!("Detach failed: {}", err); + panic!(); + } + }; + + let _ = dc_chan_tx3.send(detached).await; + }) + })); + }) + })); + + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + + log::debug!("Waiting for OnDataChannel"); + let dc = dc_chan_rx.recv().await.unwrap(); + log::debug!("data channel opened"); + + log::debug!("Waiting for ping..."); + let mut msg = vec![0u8; 256]; + let n = dc.read(&mut msg).await?; + log::debug!("Received ping! {:?}\n", &msg[..n]); + + assert_eq!(test_data.as_bytes(), &msg[..n]); + log::debug!("Received ping successfully!"); + + dc.close().await?; + + Result::<()>::Ok(()) + }); + + signal_pair(&mut pca, &mut pcb).await?; + + let attached = pca.create_data_channel(label, None).await?; + + log::debug!("Waiting for data channel to open"); + let (open_tx, mut open_rx) = mpsc::channel::<()>(1); + let open_tx = Arc::new(open_tx); + attached.on_open(Box::new(move || { + let open_tx2 = Arc::clone(&open_tx); + Box::pin(async move { + let _ = open_tx2.send(()).await; + }) + })); + + let _ = open_rx.recv().await; + log::debug!("data channel opened"); + + let dc = attached.detach().await?; + + let w = wg.worker(); + tokio::spawn(async move { + let _d = w; + log::debug!("Sending ping..."); + dc.write(&Bytes::from_static(test_data.as_bytes())).await?; + log::debug!("Sent ping"); + + dc.close().await?; + + log::debug!("Waiting for EOF"); + let mut buf = vec![0u8; 256]; + let n = dc.read(&mut buf).await?; + assert_eq!(0, n, "should be empty"); + + Result::<()>::Ok(()) + }); + + wg.wait().await; + + close_pair_now(&pca, &pcb).await; + + Ok(()) +} + +#[tokio::test] +async fn test_eof_no_detach() -> Result<()> { + let label: &str = "test-channel"; + let test_data: &'static [u8] = b"this is some test data"; + + let api = APIBuilder::new().build(); + + // Set up two peer connections. + let mut pca = api.new_peer_connection(RTCConfiguration::default()).await?; + let mut pcb = api.new_peer_connection(RTCConfiguration::default()).await?; + + let (dca_closed_ch_tx, mut dca_closed_ch_rx) = mpsc::channel::<()>(1); + let (dcb_closed_ch_tx, mut dcb_closed_ch_rx) = mpsc::channel::<()>(1); + + let dcb_closed_ch_tx = Arc::new(dcb_closed_ch_tx); + pcb.on_data_channel(Box::new(move |dc: Arc| { + if dc.label() != label { + return Box::pin(async {}); + } + + log::debug!("pcb: new datachannel: {}", dc.label()); + + let dcb_closed_ch_tx2 = Arc::clone(&dcb_closed_ch_tx); + Box::pin(async move { + // Register channel opening handling + dc.on_open(Box::new(move || { + log::debug!("pcb: datachannel opened"); + Box::pin(async {}) + })); + + dc.on_close(Box::new(move || { + // (2) + log::debug!("pcb: data channel closed"); + let dcb_closed_ch_tx3 = Arc::clone(&dcb_closed_ch_tx2); + Box::pin(async move { + let _ = dcb_closed_ch_tx3.send(()).await; + }) + })); + + // Register the OnMessage to handle incoming messages + log::debug!("pcb: registering onMessage callback"); + dc.on_message(Box::new(|dc_msg: DataChannelMessage| { + let test_data: &'static [u8] = b"this is some test data"; + log::debug!("pcb: received ping: {:?}", dc_msg.data); + assert_eq!(&dc_msg.data[..], test_data, "data mismatch"); + Box::pin(async {}) + })); + }) + })); + + let dca = pca.create_data_channel(label, None).await?; + let dca2 = Arc::clone(&dca); + dca.on_open(Box::new(move || { + log::debug!("pca: data channel opened"); + log::debug!("pca: sending {:?}", test_data); + let dca3 = Arc::clone(&dca2); + Box::pin(async move { + let _ = dca3.send(&Bytes::from_static(test_data)).await; + log::debug!("pca: sent ping"); + assert!(dca3.close().await.is_ok(), "should succeed"); // <-- dca closes + }) + })); + + let dca_closed_ch_tx = Arc::new(dca_closed_ch_tx); + dca.on_close(Box::new(move || { + // (1) + log::debug!("pca: data channel closed"); + let dca_closed_ch_tx2 = Arc::clone(&dca_closed_ch_tx); + Box::pin(async move { + let _ = dca_closed_ch_tx2.send(()).await; + }) + })); + + // Register the OnMessage to handle incoming messages + log::debug!("pca: registering onMessage callback"); + dca.on_message(Box::new(move |dc_msg: DataChannelMessage| { + log::debug!("pca: received pong: {:?}", &dc_msg.data[..]); + assert_eq!(&dc_msg.data[..], test_data, "data mismatch"); + Box::pin(async {}) + })); + + signal_pair(&mut pca, &mut pcb).await?; + + // When dca closes the channel, + // (1) dca.Onclose() will fire immediately, then + // (2) dcb.OnClose will also fire + let _ = dca_closed_ch_rx.recv().await; // (1) + let _ = dcb_closed_ch_rx.recv().await; // (2) + + close_pair_now(&pca, &pcb).await; + + Ok(()) +} + +// Assert that a Session Description that doesn't follow +// draft-ietf-mmusic-sctp-sdp is still accepted +#[tokio::test] +async fn test_data_channel_non_standard_session_description() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (offer_pc, answer_pc) = new_pair(&api).await?; + + let _ = offer_pc.create_data_channel("foo", None).await?; + + let (on_data_channel_called_tx, mut on_data_channel_called_rx) = mpsc::channel::<()>(1); + let on_data_channel_called_tx = Arc::new(on_data_channel_called_tx); + answer_pc.on_data_channel(Box::new(move |_: Arc| { + let on_data_channel_called_tx2 = Arc::clone(&on_data_channel_called_tx); + Box::pin(async move { + let _ = on_data_channel_called_tx2.send(()).await; + }) + })); + + let offer = offer_pc.create_offer(None).await?; + + let mut offer_gathering_complete = offer_pc.gathering_complete_promise().await; + offer_pc.set_local_description(offer).await?; + let _ = offer_gathering_complete.recv().await; + + let mut offer = offer_pc.local_description().await.unwrap(); + + // Replace with old values + const OLD_APPLICATION: &str = "m=application 63743 DTLS/SCTP 5000\r"; + const OLD_ATTRIBUTE: &str = "a=sctpmap:5000 webrtc-datachannel 256\r"; + + let re = Regex::new(r"m=application (.*?)\r").unwrap(); + offer.sdp = re + .replace_all(offer.sdp.as_str(), OLD_APPLICATION) + .to_string(); + let re = Regex::new(r"a=sctp-port(.*?)\r").unwrap(); + offer.sdp = re + .replace_all(offer.sdp.as_str(), OLD_ATTRIBUTE) + .to_string(); + + // Assert that replace worked + assert!(offer.sdp.contains(OLD_APPLICATION)); + assert!(offer.sdp.contains(OLD_ATTRIBUTE)); + + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + + let mut answer_gathering_complete = answer_pc.gathering_complete_promise().await; + answer_pc.set_local_description(answer).await?; + let _ = answer_gathering_complete.recv().await; + + let answer = answer_pc.local_description().await.unwrap(); + offer_pc.set_remote_description(answer).await?; + + let _ = on_data_channel_called_rx.recv().await; + + close_pair_now(&offer_pc, &answer_pc).await; + + Ok(()) +} + +struct TestOrtcStack { + //api *API + gatherer: Arc, + ice: Arc, + dtls: Arc, + sctp: Arc, +} + +struct TestOrtcSignal { + ice_candidates: Vec, //`json:"iceCandidates"` + ice_parameters: RTCIceParameters, //`json:"iceParameters"` + dtls_parameters: DTLSParameters, //`json:"dtlsParameters"` + sctp_capabilities: SCTPTransportCapabilities, //`json:"sctpCapabilities"` +} + +impl TestOrtcStack { + async fn new(api: &API) -> Result { + // Create the ICE gatherer + let gatherer = Arc::new(api.new_ice_gatherer(RTCIceGatherOptions::default())?); + + // Construct the ICE transport + let ice = Arc::new(api.new_ice_transport(Arc::clone(&gatherer))); + + // Construct the DTLS transport + let dtls = Arc::new(api.new_dtls_transport(Arc::clone(&ice), vec![])?); + + // Construct the SCTP transport + let sctp = Arc::new(api.new_sctp_transport(Arc::clone(&dtls))?); + + Ok(TestOrtcStack { + gatherer, + ice, + dtls, + sctp, + }) + } + + async fn set_signal(&self, sig: &TestOrtcSignal, is_offer: bool) -> Result<()> { + let ice_role = if is_offer { + RTCIceRole::Controlling + } else { + RTCIceRole::Controlled + }; + + self.ice.set_remote_candidates(&sig.ice_candidates).await?; + + // Start the ICE transport + self.ice.start(&sig.ice_parameters, Some(ice_role)).await?; + + // Start the DTLS transport + self.dtls.start(sig.dtls_parameters.clone()).await?; + + // Start the SCTP transport + self.sctp.start(sig.sctp_capabilities).await?; + + Ok(()) + } + + async fn get_signal(&self) -> Result { + let (gather_finished_tx, mut gather_finished_rx) = mpsc::channel::<()>(1); + let gather_finished_tx = Arc::new(gather_finished_tx); + self.gatherer + .on_local_candidate(Box::new(move |i: Option| { + let gather_finished_tx2 = Arc::clone(&gather_finished_tx); + Box::pin(async move { + if i.is_none() { + let _ = gather_finished_tx2.send(()).await; + } + }) + })); + + self.gatherer.gather().await?; + + let _ = gather_finished_rx.recv().await; + + let ice_candidates = self.gatherer.get_local_candidates().await?; + + let ice_parameters = self.gatherer.get_local_parameters().await?; + + let dtls_parameters = self.dtls.get_local_parameters()?; + + let sctp_capabilities = self.sctp.get_capabilities(); + + Ok(TestOrtcSignal { + ice_candidates, + ice_parameters, + dtls_parameters, + sctp_capabilities, + }) + } + + async fn close(&self) -> Result<()> { + let mut close_errs = vec![]; + + if let Err(err) = self.sctp.stop().await { + close_errs.push(err); + } + + if let Err(err) = self.ice.stop().await { + close_errs.push(err); + } + + flatten_errs(close_errs) + } +} + +async fn new_ortc_pair(api: &API) -> Result<(Arc, Arc)> { + let sa = Arc::new(TestOrtcStack::new(api).await?); + let sb = Arc::new(TestOrtcStack::new(api).await?); + Ok((sa, sb)) +} + +async fn signal_ortc_pair(stack_a: Arc, stack_b: Arc) -> Result<()> { + let sig_a = stack_a.get_signal().await?; + let sig_b = stack_b.get_signal().await?; + + let (a_tx, mut a_rx) = mpsc::channel(1); + let (b_tx, mut b_rx) = mpsc::channel(1); + + tokio::spawn(async move { + let _ = a_tx.send(stack_b.set_signal(&sig_a, false).await).await; + }); + + tokio::spawn(async move { + let _ = b_tx.send(stack_a.set_signal(&sig_b, true).await).await; + }); + + let err_a = a_rx.recv().await.unwrap(); + let err_b = b_rx.recv().await.unwrap(); + + let mut close_errs = vec![]; + if let Err(err) = err_a { + close_errs.push(err); + } + if let Err(err) = err_b { + close_errs.push(err); + } + + flatten_errs(close_errs) +} + +#[tokio::test] +async fn test_data_channel_ortc_e2e() -> Result<()> { + let api = APIBuilder::new().build(); + + let (stack_a, stack_b) = new_ortc_pair(&api).await?; + + let (await_setup_tx, mut await_setup_rx) = mpsc::channel::<()>(1); + let (await_string_tx, mut await_string_rx) = mpsc::channel::<()>(1); + let (await_binary_tx, mut await_binary_rx) = mpsc::channel::<()>(1); + + let await_setup_tx = Arc::new(await_setup_tx); + let await_string_tx = Arc::new(await_string_tx); + let await_binary_tx = Arc::new(await_binary_tx); + stack_b + .sctp + .on_data_channel(Box::new(move |d: Arc| { + let await_setup_tx2 = Arc::clone(&await_setup_tx); + let await_string_tx2 = Arc::clone(&await_string_tx); + let await_binary_tx2 = Arc::clone(&await_binary_tx); + Box::pin(async move { + let _ = await_setup_tx2.send(()).await; + + d.on_message(Box::new(move |msg: DataChannelMessage| { + let await_string_tx3 = Arc::clone(&await_string_tx2); + let await_binary_tx3 = Arc::clone(&await_binary_tx2); + Box::pin(async move { + if msg.is_string { + let _ = await_string_tx3.send(()).await; + } else { + let _ = await_binary_tx3.send(()).await; + } + }) + })); + }) + })); + + signal_ortc_pair(Arc::clone(&stack_a), Arc::clone(&stack_b)).await?; + + let dc_params = DataChannelParameters { + label: "Foo".to_owned(), + negotiated: None, + ..Default::default() + }; + + let channel_a = api + .new_data_channel(Arc::clone(&stack_a.sctp), dc_params) + .await?; + + let _ = await_setup_rx.recv().await; + + channel_a.send_text("ABC".to_owned()).await?; + channel_a.send(&Bytes::from_static(b"ABC")).await?; + + let _ = await_string_rx.recv().await; + let _ = await_binary_rx.recv().await; + + stack_a.close().await?; + stack_b.close().await?; + + // attempt to send when channel is closed + let result = channel_a.send(&Bytes::from_static(b"ABC")).await; + if let Err(err) = result { + assert_eq!( + Error::ErrClosedPipe, + err, + "expected ErrClosedPipe, but got {err}" + ); + } else { + panic!(); + } + + let result = channel_a.send_text("test".to_owned()).await; + if let Err(err) = result { + assert_eq!( + Error::ErrClosedPipe, + err, + "expected ErrClosedPipe, but got {err}" + ); + } else { + panic!(); + } + + let result = channel_a.ensure_open(); + if let Err(err) = result { + assert_eq!( + Error::ErrClosedPipe, + err, + "expected ErrClosedPipe, but got {err}" + ); + } else { + panic!(); + } + + Ok(()) +} diff --git a/rtc/src/data_channel/mod.rs b/rtc/src/data_channel/mod.rs new file mode 100644 index 0000000..d25db3d --- /dev/null +++ b/rtc/src/data_channel/mod.rs @@ -0,0 +1,555 @@ +#[cfg(test)] +mod data_channel_test; + +pub mod data_channel_init; +pub mod data_channel_message; +pub mod data_channel_parameters; +pub mod data_channel_state; + +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU8, AtomicUsize, Ordering}; +use std::sync::{Arc, Weak}; +use std::time::SystemTime; + +use arc_swap::ArcSwapOption; +use bytes::Bytes; +use data::message::message_channel_open::ChannelType; +use data_channel_message::*; +use data_channel_parameters::*; +use data_channel_state::RTCDataChannelState; +use sctp::stream::OnBufferedAmountLowFn; +use tokio::sync::{Mutex, Notify}; +use util::sync::Mutex as SyncMutex; + +use crate::api::setting_engine::SettingEngine; +use crate::error::{Error, OnErrorHdlrFn, Result}; +use crate::sctp_transport::RTCSctpTransport; +use crate::stats::stats_collector::StatsCollector; +use crate::stats::{DataChannelStats, StatsReportType}; + +/// message size limit for Chromium +const DATA_CHANNEL_BUFFER_SIZE: u16 = u16::MAX; + +pub type OnMessageHdlrFn = Box< + dyn (FnMut(DataChannelMessage) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnOpenHdlrFn = + Box Pin + Send + 'static>>) + Send + Sync>; + +pub type OnCloseHdlrFn = + Box Pin + Send + 'static>>) + Send + Sync>; + +/// DataChannel represents a WebRTC DataChannel +/// The DataChannel interface represents a network channel +/// which can be used for bidirectional peer-to-peer transfers of arbitrary data +#[derive(Default)] +pub struct RTCDataChannel { + pub(crate) stats_id: String, + pub(crate) label: String, + pub(crate) ordered: bool, + pub(crate) max_packet_lifetime: u16, + pub(crate) max_retransmits: u16, + pub(crate) protocol: String, + pub(crate) negotiated: bool, + pub(crate) id: AtomicU16, + pub(crate) ready_state: Arc, // DataChannelState + pub(crate) buffered_amount_low_threshold: AtomicUsize, + pub(crate) detach_called: Arc, + + // The binaryType represents attribute MUST, on getting, return the value to + // which it was last set. On setting, if the new value is either the string + // "blob" or the string "arraybuffer", then set the IDL attribute to this + // new value. Otherwise, throw a SyntaxError. When an DataChannel object + // is created, the binaryType attribute MUST be initialized to the string + // "blob". This attribute controls how binary data is exposed to scripts. + // binaryType string + pub(crate) on_message_handler: Arc>>, + pub(crate) on_open_handler: SyncMutex>, + pub(crate) on_close_handler: Arc>>, + pub(crate) on_error_handler: Arc>>, + + pub(crate) on_buffered_amount_low: Mutex>, + + pub(crate) sctp_transport: Mutex>>, + pub(crate) data_channel: Mutex>>, + + pub(crate) notify_tx: Arc, + + // A reference to the associated api object used by this datachannel + pub(crate) setting_engine: Arc, +} + +impl RTCDataChannel { + // create the DataChannel object before the networking is set up. + pub(crate) fn new(params: DataChannelParameters, setting_engine: Arc) -> Self { + // the id value if non-negotiated doesn't matter, since it will be overwritten + // on opening + let id = params.negotiated.unwrap_or(0); + RTCDataChannel { + stats_id: format!( + "DataChannel-{}", + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |d| d.as_nanos()) + ), + label: params.label, + protocol: params.protocol, + negotiated: params.negotiated.is_some(), + id: AtomicU16::new(id), + ordered: params.ordered, + max_packet_lifetime: params.max_packet_life_time, + max_retransmits: params.max_retransmits, + ready_state: Arc::new(AtomicU8::new(RTCDataChannelState::Connecting as u8)), + detach_called: Arc::new(AtomicBool::new(false)), + + notify_tx: Arc::new(Notify::new()), + + setting_engine, + ..Default::default() + } + } + + /// open opens the datachannel over the sctp transport + pub(crate) async fn open(&self, sctp_transport: Arc) -> Result<()> { + if let Some(association) = sctp_transport.association().await { + { + let mut st = self.sctp_transport.lock().await; + if st.is_none() { + *st = Some(Arc::downgrade(&sctp_transport)); + } else { + return Ok(()); + } + } + + let channel_type; + let reliability_parameter; + + if self.max_packet_lifetime == 0 && self.max_retransmits == 0 { + reliability_parameter = 0u32; + if self.ordered { + channel_type = ChannelType::Reliable; + } else { + channel_type = ChannelType::ReliableUnordered; + } + } else if self.max_retransmits != 0 { + reliability_parameter = self.max_retransmits as u32; + if self.ordered { + channel_type = ChannelType::PartialReliableRexmit; + } else { + channel_type = ChannelType::PartialReliableRexmitUnordered; + } + } else { + reliability_parameter = self.max_packet_lifetime as u32; + if self.ordered { + channel_type = ChannelType::PartialReliableTimed; + } else { + channel_type = ChannelType::PartialReliableTimedUnordered; + } + } + + let cfg = data::data_channel::Config { + channel_type, + priority: data::message::message_channel_open::CHANNEL_PRIORITY_NORMAL, + reliability_parameter, + label: self.label.clone(), + protocol: self.protocol.clone(), + negotiated: self.negotiated, + }; + + if !self.negotiated { + self.id.store( + sctp_transport + .generate_and_set_data_channel_id( + sctp_transport.dtls_transport.role().await, + ) + .await?, + Ordering::SeqCst, + ); + } + + let dc = data::data_channel::DataChannel::dial(&association, self.id(), cfg).await?; + + // buffered_amount_low_threshold and on_buffered_amount_low might be set earlier + dc.set_buffered_amount_low_threshold( + self.buffered_amount_low_threshold.load(Ordering::SeqCst), + ); + { + let mut on_buffered_amount_low = self.on_buffered_amount_low.lock().await; + if let Some(f) = on_buffered_amount_low.take() { + dc.on_buffered_amount_low(f); + } + } + + self.handle_open(Arc::new(dc)).await; + + Ok(()) + } else { + Err(Error::ErrSCTPNotEstablished) + } + } + + /// transport returns the SCTPTransport instance the DataChannel is sending over. + pub async fn transport(&self) -> Option> { + let sctp_transport = self.sctp_transport.lock().await; + sctp_transport.clone() + } + + /// on_open sets an event handler which is invoked when + /// the underlying data transport has been established (or re-established). + pub fn on_open(&self, f: OnOpenHdlrFn) { + let _ = self.on_open_handler.lock().replace(f); + + if self.ready_state() == RTCDataChannelState::Open { + self.do_open(); + } + } + + fn do_open(&self) { + let on_open_handler = self.on_open_handler.lock().take(); + if on_open_handler.is_none() { + return; + } + + let detach_data_channels = self.setting_engine.detach.data_channels; + let detach_called = Arc::clone(&self.detach_called); + tokio::spawn(async move { + if let Some(f) = on_open_handler { + f().await; + + // self.check_detach_after_open(); + // After onOpen is complete check that the user called detach + // and provide an error message if the call was missed + if detach_data_channels && !detach_called.load(Ordering::SeqCst) { + log::warn!( + "webrtc.DetachDataChannels() enabled but didn't Detach, call Detach from OnOpen" + ); + } + } + }); + } + + /// on_close sets an event handler which is invoked when + /// the underlying data transport has been closed. + pub fn on_close(&self, f: OnCloseHdlrFn) { + self.on_close_handler.store(Some(Arc::new(Mutex::new(f)))); + } + + /// on_message sets an event handler which is invoked on a binary + /// message arrival over the sctp transport from a remote peer. + /// OnMessage can currently receive messages up to 16384 bytes + /// in size. Check out the detach API if you want to use larger + /// message sizes. Note that browser support for larger messages + /// is also limited. + pub fn on_message(&self, f: OnMessageHdlrFn) { + self.on_message_handler.store(Some(Arc::new(Mutex::new(f)))); + } + + async fn do_message(&self, msg: DataChannelMessage) { + if let Some(handler) = &*self.on_message_handler.load() { + let mut f = handler.lock().await; + f(msg).await; + } + } + + pub(crate) async fn handle_open(&self, dc: Arc) { + { + let mut data_channel = self.data_channel.lock().await; + *data_channel = Some(Arc::clone(&dc)); + } + self.set_ready_state(RTCDataChannelState::Open); + + self.do_open(); + + if !self.setting_engine.detach.data_channels { + let ready_state = Arc::clone(&self.ready_state); + let on_message_handler = Arc::clone(&self.on_message_handler); + let on_close_handler = Arc::clone(&self.on_close_handler); + let on_error_handler = Arc::clone(&self.on_error_handler); + let notify_rx = self.notify_tx.clone(); + tokio::spawn(async move { + RTCDataChannel::read_loop( + notify_rx, + dc, + ready_state, + on_message_handler, + on_close_handler, + on_error_handler, + ) + .await; + }); + } + } + + /// on_error sets an event handler which is invoked when + /// the underlying data transport cannot be read. + pub fn on_error(&self, f: OnErrorHdlrFn) { + self.on_error_handler.store(Some(Arc::new(Mutex::new(f)))); + } + + async fn read_loop( + notify_rx: Arc, + data_channel: Arc, + ready_state: Arc, + on_message_handler: Arc>>, + on_close_handler: Arc>>, + on_error_handler: Arc>>, + ) { + let mut buffer = vec![0u8; DATA_CHANNEL_BUFFER_SIZE as usize]; + loop { + let (n, is_string) = tokio::select! { + _ = notify_rx.notified() => break, + result = data_channel.read_data_channel(&mut buffer) => { + match result{ + // EOF (`data_channel` was either closed or the underlying stream got + // reset by the remote) => close and run `on_close` handler. + Ok((0, _)) => + { + ready_state.store(RTCDataChannelState::Closed as u8, Ordering::SeqCst); + + let on_close_handler2 = Arc::clone(&on_close_handler); + tokio::spawn(async move { + if let Some(handler) = &*on_close_handler2.load() { + let mut f = handler.lock().await; + f().await; + } + }); + + break; + } + Ok((n, is_string)) => (n, is_string), + Err(err) => { + ready_state.store(RTCDataChannelState::Closed as u8, Ordering::SeqCst); + + let on_error_handler2 = Arc::clone(&on_error_handler); + tokio::spawn(async move { + if let Some(handler) = &*on_error_handler2.load() { + let mut f = handler.lock().await; + f(err.into()).await; + } + }); + + let on_close_handler2 = Arc::clone(&on_close_handler); + tokio::spawn(async move { + if let Some(handler) = &*on_close_handler2.load() { + let mut f = handler.lock().await; + f().await; + } + }); + + break; + } + } + } + }; + + if let Some(handler) = &*on_message_handler.load() { + let mut f = handler.lock().await; + f(DataChannelMessage { + is_string, + data: Bytes::from(buffer[..n].to_vec()), + }) + .await; + } + } + } + + /// send sends the binary message to the DataChannel peer + pub async fn send(&self, data: &Bytes) -> Result { + self.ensure_open()?; + + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + Ok(dc.write_data_channel(data, false).await?) + } else { + Err(Error::ErrClosedPipe) + } + } + + /// send_text sends the text message to the DataChannel peer + pub async fn send_text(&self, s: impl Into) -> Result { + self.ensure_open()?; + + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + Ok(dc.write_data_channel(&Bytes::from(s.into()), true).await?) + } else { + Err(Error::ErrClosedPipe) + } + } + + fn ensure_open(&self) -> Result<()> { + if self.ready_state() != RTCDataChannelState::Open { + Err(Error::ErrClosedPipe) + } else { + Ok(()) + } + } + + /// detach allows you to detach the underlying datachannel. This provides + /// an idiomatic API to work with, however it disables the OnMessage callback. + /// Before calling Detach you have to enable this behavior by calling + /// webrtc.DetachDataChannels(). Combining detached and normal data channels + /// is not supported. + /// Please refer to the data-channels-detach example and the + /// pion/datachannel documentation for the correct way to handle the + /// resulting DataChannel object. + pub async fn detach(&self) -> Result> { + if !self.setting_engine.detach.data_channels { + return Err(Error::ErrDetachNotEnabled); + } + + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + self.detach_called.store(true, Ordering::SeqCst); + + Ok(Arc::clone(dc)) + } else { + Err(Error::ErrDetachBeforeOpened) + } + } + + /// Close Closes the DataChannel. It may be called regardless of whether + /// the DataChannel object was created by this peer or the remote peer. + pub async fn close(&self) -> Result<()> { + if self.ready_state() == RTCDataChannelState::Closed { + return Ok(()); + } + + self.set_ready_state(RTCDataChannelState::Closing); + self.notify_tx.notify_waiters(); + + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + Ok(dc.close().await?) + } else { + Ok(()) + } + } + + /// label represents a label that can be used to distinguish this + /// DataChannel object from other DataChannel objects. Scripts are + /// allowed to create multiple DataChannel objects with the same label. + pub fn label(&self) -> &str { + self.label.as_str() + } + + /// Ordered returns true if the DataChannel is ordered, and false if + /// out-of-order delivery is allowed. + pub fn ordered(&self) -> bool { + self.ordered + } + + /// max_packet_lifetime represents the length of the time window (msec) during + /// which transmissions and retransmissions may occur in unreliable mode. + pub fn max_packet_lifetime(&self) -> u16 { + self.max_packet_lifetime + } + + /// max_retransmits represents the maximum number of retransmissions that are + /// attempted in unreliable mode. + pub fn max_retransmits(&self) -> u16 { + self.max_retransmits + } + + /// protocol represents the name of the sub-protocol used with this + /// DataChannel. + pub fn protocol(&self) -> &str { + self.protocol.as_str() + } + + /// negotiated represents whether this DataChannel was negotiated by the + /// application (true), or not (false). + pub fn negotiated(&self) -> bool { + self.negotiated + } + + /// ID represents the ID for this DataChannel. The value is initially + /// null, which is what will be returned if the ID was not provided at + /// channel creation time, and the DTLS role of the SCTP transport has not + /// yet been negotiated. Otherwise, it will return the ID that was either + /// selected by the script or generated. After the ID is set to a non-null + /// value, it will not change. + pub fn id(&self) -> u16 { + self.id.load(Ordering::SeqCst) + } + + /// ready_state represents the state of the DataChannel object. + pub fn ready_state(&self) -> RTCDataChannelState { + self.ready_state.load(Ordering::SeqCst).into() + } + + /// buffered_amount represents the number of bytes of application data + /// (UTF-8 text and binary data) that have been queued using send(). Even + /// though the data transmission can occur in parallel, the returned value + /// MUST NOT be decreased before the current task yielded back to the event + /// loop to prevent race conditions. The value does not include framing + /// overhead incurred by the protocol, or buffering done by the operating + /// system or network hardware. The value of buffered_amount slot will only + /// increase with each call to the send() method as long as the ready_state is + /// open; however, buffered_amount does not reset to zero once the channel + /// closes. + pub async fn buffered_amount(&self) -> usize { + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + dc.buffered_amount() + } else { + 0 + } + } + + /// buffered_amount_low_threshold represents the threshold at which the + /// bufferedAmount is considered to be low. When the bufferedAmount decreases + /// from above this threshold to equal or below it, the bufferedamountlow + /// event fires. buffered_amount_low_threshold is initially zero on each new + /// DataChannel, but the application may change its value at any time. + /// The threshold is set to 0 by default. + pub async fn buffered_amount_low_threshold(&self) -> usize { + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + dc.buffered_amount_low_threshold() + } else { + self.buffered_amount_low_threshold.load(Ordering::SeqCst) + } + } + + /// set_buffered_amount_low_threshold is used to update the threshold. + /// See buffered_amount_low_threshold(). + pub async fn set_buffered_amount_low_threshold(&self, th: usize) { + self.buffered_amount_low_threshold + .store(th, Ordering::SeqCst); + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + dc.set_buffered_amount_low_threshold(th); + } + } + + /// on_buffered_amount_low sets an event handler which is invoked when + /// the number of bytes of outgoing data becomes lower than the + /// buffered_amount_low_threshold. + pub async fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) { + let data_channel = self.data_channel.lock().await; + if let Some(dc) = &*data_channel { + dc.on_buffered_amount_low(f); + } else { + let mut on_buffered_amount_low = self.on_buffered_amount_low.lock().await; + *on_buffered_amount_low = Some(f); + } + } + + pub(crate) fn get_stats_id(&self) -> &str { + self.stats_id.as_str() + } + + pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { + let stats = DataChannelStats::from(self).await; + collector.insert(self.stats_id.clone(), StatsReportType::DataChannel(stats)); + } + + pub(crate) fn set_ready_state(&self, r: RTCDataChannelState) { + self.ready_state.store(r as u8, Ordering::SeqCst); + } +} diff --git a/rtc/src/dtls_transport/dtls_fingerprint.rs b/rtc/src/dtls_transport/dtls_fingerprint.rs new file mode 100644 index 0000000..eaea367 --- /dev/null +++ b/rtc/src/dtls_transport/dtls_fingerprint.rs @@ -0,0 +1,15 @@ +use serde::{Deserialize, Serialize}; + +/// DTLSFingerprint specifies the hash function algorithm and certificate +/// fingerprint as described in . +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct RTCDtlsFingerprint { + /// Algorithm specifies one of the the hash function algorithms defined in + /// the 'Hash function Textual Names' registry. + pub algorithm: String, + + /// Value specifies the value of the certificate fingerprint in lowercase + /// hex string as expressed utilizing the syntax of 'fingerprint' in + /// . + pub value: String, +} diff --git a/rtc/src/dtls_transport/dtls_parameters.rs b/rtc/src/dtls_transport/dtls_parameters.rs new file mode 100644 index 0000000..6aa1447 --- /dev/null +++ b/rtc/src/dtls_transport/dtls_parameters.rs @@ -0,0 +1,11 @@ +use serde::{Deserialize, Serialize}; + +use super::dtls_fingerprint::*; +use super::dtls_role::*; + +/// DTLSParameters holds information relating to DTLS configuration. +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct DTLSParameters { + pub role: DTLSRole, + pub fingerprints: Vec, +} diff --git a/rtc/src/dtls_transport/dtls_role.rs b/rtc/src/dtls_transport/dtls_role.rs new file mode 100644 index 0000000..88730c6 --- /dev/null +++ b/rtc/src/dtls_transport/dtls_role.rs @@ -0,0 +1,170 @@ +use std::fmt; + +use sdp::description::session::SessionDescription; +use sdp::util::ConnectionRole; +use serde::{Deserialize, Serialize}; + +/// DtlsRole indicates the role of the DTLS transport. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DTLSRole { + #[default] + Unspecified = 0, + + /// DTLSRoleAuto defines the DTLS role is determined based on + /// the resolved ICE role: the ICE controlled role acts as the DTLS + /// client and the ICE controlling role acts as the DTLS server. + #[serde(rename = "auto")] + Auto = 1, + + /// DTLSRoleClient defines the DTLS client role. + #[serde(rename = "client")] + Client = 2, + + /// DTLSRoleServer defines the DTLS server role. + #[serde(rename = "server")] + Server = 3, +} + +/// +/// The answerer MUST use either a +/// setup attribute value of setup:active or setup:passive. Note that +/// if the answerer uses setup:passive, then the DTLS handshake will +/// not begin until the answerer is received, which adds additional +/// latency. setup:active allows the answer and the DTLS handshake to +/// occur in parallel. Thus, setup:active is RECOMMENDED. +pub(crate) const DEFAULT_DTLS_ROLE_ANSWER: DTLSRole = DTLSRole::Client; + +/// The endpoint that is the offerer MUST use the setup attribute +/// value of setup:actpass and be prepared to receive a client_hello +/// before it receives the answer. +pub(crate) const DEFAULT_DTLS_ROLE_OFFER: DTLSRole = DTLSRole::Auto; + +impl fmt::Display for DTLSRole { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DTLSRole::Auto => write!(f, "auto"), + DTLSRole::Client => write!(f, "client"), + DTLSRole::Server => write!(f, "server"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +/// Iterate a SessionDescription from a remote to determine if an explicit +/// role can been determined from it. The decision is made from the first role we we parse. +/// If no role can be found we return DTLSRoleAuto +impl From<&SessionDescription> for DTLSRole { + fn from(session_description: &SessionDescription) -> Self { + for media_section in &session_description.media_descriptions { + for attribute in &media_section.attributes { + if attribute.key == "setup" { + if let Some(value) = &attribute.value { + match value.as_str() { + "active" => return DTLSRole::Client, + "passive" => return DTLSRole::Server, + _ => return DTLSRole::Auto, + }; + } else { + return DTLSRole::Auto; + } + } + } + } + + DTLSRole::Auto + } +} + +impl DTLSRole { + pub(crate) fn to_connection_role(self) -> ConnectionRole { + match self { + DTLSRole::Client => ConnectionRole::Active, + DTLSRole::Server => ConnectionRole::Passive, + DTLSRole::Auto => ConnectionRole::Actpass, + _ => ConnectionRole::Unspecified, + } + } +} + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use super::*; + use crate::error::Result; + + #[test] + fn test_dtls_role_string() { + let tests = vec![ + (DTLSRole::Unspecified, "Unspecified"), + (DTLSRole::Auto, "auto"), + (DTLSRole::Client, "client"), + (DTLSRole::Server, "server"), + ]; + + for (role, expected_string) in tests { + assert_eq!(role.to_string(), expected_string) + } + } + + #[test] + fn test_dtls_role_from_remote_sdp() -> Result<()> { + const NO_MEDIA: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +"; + + const MEDIA_NO_SETUP: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=application 47299 DTLS/SCTP 5000 +c=IN IP4 192.168.20.129 +"; + + const MEDIA_SETUP_DECLARED: &str = "v=0 +o=- 4596489990601351948 2 IN IP4 127.0.0.1 +s=- +t=0 0 +m=application 47299 DTLS/SCTP 5000 +c=IN IP4 192.168.20.129 +a=setup:"; + + let tests = vec![ + ("No MediaDescriptions", NO_MEDIA.to_owned(), DTLSRole::Auto), + ( + "MediaDescription, no setup", + MEDIA_NO_SETUP.to_owned(), + DTLSRole::Auto, + ), + ( + "MediaDescription, setup:actpass", + format!("{}{}\n", MEDIA_SETUP_DECLARED, "actpass"), + DTLSRole::Auto, + ), + ( + "MediaDescription, setup:passive", + format!("{}{}\n", MEDIA_SETUP_DECLARED, "passive"), + DTLSRole::Server, + ), + ( + "MediaDescription, setup:active", + format!("{}{}\n", MEDIA_SETUP_DECLARED, "active"), + DTLSRole::Client, + ), + ]; + + for (name, session_description_str, expected_role) in tests { + let mut reader = Cursor::new(session_description_str.as_bytes()); + let session_description = SessionDescription::unmarshal(&mut reader)?; + assert_eq!( + DTLSRole::from(&session_description), + expected_role, + "{name} failed" + ); + } + + Ok(()) + } +} diff --git a/rtc/src/dtls_transport/dtls_transport_state.rs b/rtc/src/dtls_transport/dtls_transport_state.rs new file mode 100644 index 0000000..3700d8b --- /dev/null +++ b/rtc/src/dtls_transport/dtls_transport_state.rs @@ -0,0 +1,117 @@ +use std::fmt; + +/// DTLSTransportState indicates the DTLS transport establishment state. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCDtlsTransportState { + #[default] + Unspecified = 0, + + /// DTLSTransportStateNew indicates that DTLS has not started negotiating + /// yet. + New = 1, + + /// DTLSTransportStateConnecting indicates that DTLS is in the process of + /// negotiating a secure connection and verifying the remote fingerprint. + Connecting = 2, + + /// DTLSTransportStateConnected indicates that DTLS has completed + /// negotiation of a secure connection and verified the remote fingerprint. + Connected = 3, + + /// DTLSTransportStateClosed indicates that the transport has been closed + /// intentionally as the result of receipt of a close_notify alert, or + /// calling close(). + Closed = 4, + + /// DTLSTransportStateFailed indicates that the transport has failed as + /// the result of an error (such as receipt of an error alert or failure to + /// validate the remote fingerprint). + Failed = 5, +} + +const DTLS_TRANSPORT_STATE_NEW_STR: &str = "new"; +const DTLS_TRANSPORT_STATE_CONNECTING_STR: &str = "connecting"; +const DTLS_TRANSPORT_STATE_CONNECTED_STR: &str = "connected"; +const DTLS_TRANSPORT_STATE_CLOSED_STR: &str = "closed"; +const DTLS_TRANSPORT_STATE_FAILED_STR: &str = "failed"; + +impl From<&str> for RTCDtlsTransportState { + fn from(raw: &str) -> Self { + match raw { + DTLS_TRANSPORT_STATE_NEW_STR => RTCDtlsTransportState::New, + DTLS_TRANSPORT_STATE_CONNECTING_STR => RTCDtlsTransportState::Connecting, + DTLS_TRANSPORT_STATE_CONNECTED_STR => RTCDtlsTransportState::Connected, + DTLS_TRANSPORT_STATE_CLOSED_STR => RTCDtlsTransportState::Closed, + DTLS_TRANSPORT_STATE_FAILED_STR => RTCDtlsTransportState::Failed, + _ => RTCDtlsTransportState::Unspecified, + } + } +} + +impl From for RTCDtlsTransportState { + fn from(v: u8) -> Self { + match v { + 1 => RTCDtlsTransportState::New, + 2 => RTCDtlsTransportState::Connecting, + 3 => RTCDtlsTransportState::Connected, + 4 => RTCDtlsTransportState::Closed, + 5 => RTCDtlsTransportState::Failed, + _ => RTCDtlsTransportState::Unspecified, + } + } +} + +impl fmt::Display for RTCDtlsTransportState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCDtlsTransportState::New => DTLS_TRANSPORT_STATE_NEW_STR, + RTCDtlsTransportState::Connecting => DTLS_TRANSPORT_STATE_CONNECTING_STR, + RTCDtlsTransportState::Connected => DTLS_TRANSPORT_STATE_CONNECTED_STR, + RTCDtlsTransportState::Closed => DTLS_TRANSPORT_STATE_CLOSED_STR, + RTCDtlsTransportState::Failed => DTLS_TRANSPORT_STATE_FAILED_STR, + RTCDtlsTransportState::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_dtls_transport_state() { + let tests = vec![ + (crate::UNSPECIFIED_STR, RTCDtlsTransportState::Unspecified), + ("new", RTCDtlsTransportState::New), + ("connecting", RTCDtlsTransportState::Connecting), + ("connected", RTCDtlsTransportState::Connected), + ("closed", RTCDtlsTransportState::Closed), + ("failed", RTCDtlsTransportState::Failed), + ]; + + for (state_string, expected_state) in tests { + assert_eq!( + RTCDtlsTransportState::from(state_string), + expected_state, + "testCase: {expected_state}", + ); + } + } + + #[test] + fn test_dtls_transport_state_string() { + let tests = vec![ + (RTCDtlsTransportState::Unspecified, crate::UNSPECIFIED_STR), + (RTCDtlsTransportState::New, "new"), + (RTCDtlsTransportState::Connecting, "connecting"), + (RTCDtlsTransportState::Connected, "connected"), + (RTCDtlsTransportState::Closed, "closed"), + (RTCDtlsTransportState::Failed, "failed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string) + } + } +} diff --git a/rtc/src/dtls_transport/dtls_transport_test.rs b/rtc/src/dtls_transport/dtls_transport_test.rs new file mode 100644 index 0000000..a9e71ac --- /dev/null +++ b/rtc/src/dtls_transport/dtls_transport_test.rs @@ -0,0 +1,204 @@ +use ice::mdns::MulticastDnsMode; +use ice::network_type::NetworkType; +use regex::Regex; +use tokio::time::Duration; +use waitgroup::WaitGroup; + +use super::*; +use crate::api::media_engine::MediaEngine; +use crate::api::APIBuilder; +use crate::data_channel::RTCDataChannel; +use crate::ice_transport::ice_candidate::RTCIceCandidate; +use crate::peer_connection::configuration::RTCConfiguration; +use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; +use crate::peer_connection::peer_connection_test::{ + close_pair_now, new_pair, signal_pair, until_connection_state, +}; + +//use log::LevelFilter; +//use std::io::Write; + +// An invalid fingerprint MUST cause PeerConnectionState to go to PeerConnectionStateFailed +#[tokio::test] +async fn test_invalid_fingerprint_causes_failed() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; + + pc_answer.on_data_channel(Box::new(|_: Arc| { + panic!("A DataChannel must not be created when Fingerprint verification fails"); + })); + + let (offer_chan_tx, mut offer_chan_rx) = mpsc::channel::<()>(1); + + let offer_chan_tx = Arc::new(offer_chan_tx); + pc_offer.on_ice_candidate(Box::new(move |candidate: Option| { + let offer_chan_tx2 = Arc::clone(&offer_chan_tx); + Box::pin(async move { + if candidate.is_none() { + let _ = offer_chan_tx2.send(()).await; + } + }) + })); + + let offer_connection_has_failed = WaitGroup::new(); + until_connection_state( + &mut pc_offer, + &offer_connection_has_failed, + RTCPeerConnectionState::Failed, + ) + .await; + let answer_connection_has_failed = WaitGroup::new(); + until_connection_state( + &mut pc_answer, + &answer_connection_has_failed, + RTCPeerConnectionState::Failed, + ) + .await; + + let _ = pc_offer + .create_data_channel("unusedDataChannel", None) + .await?; + + let offer = pc_offer.create_offer(None).await?; + pc_offer.set_local_description(offer).await?; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + tokio::pin!(timeout); + + tokio::select! { + _ = offer_chan_rx.recv() =>{ + let mut offer = pc_offer.pending_local_description().await.unwrap(); + + log::trace!("receiving pending local desc: {:?}", offer); + + // Replace with invalid fingerprint + let re = Regex::new(r"sha-256 (.*?)\r").unwrap(); + offer.sdp = re.replace_all(offer.sdp.as_str(), "sha-256 AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA\r").to_string(); + + pc_answer.set_remote_description(offer).await?; + + let mut answer = pc_answer.create_answer(None).await?; + + pc_answer.set_local_description(answer.clone()).await?; + + answer.sdp = re.replace_all(answer.sdp.as_str(), "sha-256 AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA:AA\r").to_string(); + + pc_offer.set_remote_description(answer).await?; + } + _ = timeout.as_mut() =>{ + panic!("timed out waiting to receive offer"); + } + } + + log::trace!("offer_connection_has_failed wait begin"); + + offer_connection_has_failed.wait().await; + answer_connection_has_failed.wait().await; + + log::trace!("offer_connection_has_failed wait end"); + { + let transport = pc_offer.sctp().transport(); + assert_eq!(transport.state(), RTCDtlsTransportState::Failed); + assert!(transport.conn().await.is_none()); + } + + { + let transport = pc_answer.sctp().transport(); + assert_eq!(transport.state(), RTCDtlsTransportState::Failed); + assert!(transport.conn().await.is_none()); + } + + close_pair_now(&pc_offer, &pc_answer).await; + + Ok(()) +} + +async fn run_test(r: DTLSRole) -> Result<()> { + let mut offer_s = SettingEngine::default(); + offer_s.set_answering_dtls_role(r)?; + offer_s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled); + offer_s.set_network_types(vec![NetworkType::Udp4]); + let mut offer_pc = APIBuilder::new() + .with_setting_engine(offer_s) + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + + let mut answer_s = SettingEngine::default(); + answer_s.set_answering_dtls_role(r)?; + answer_s.set_ice_multicast_dns_mode(MulticastDnsMode::Disabled); + answer_s.set_network_types(vec![NetworkType::Udp4]); + let mut answer_pc = APIBuilder::new() + .with_setting_engine(answer_s) + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + + signal_pair(&mut offer_pc, &mut answer_pc).await?; + + let wg = WaitGroup::new(); + until_connection_state(&mut answer_pc, &wg, RTCPeerConnectionState::Connected).await; + wg.wait().await; + + close_pair_now(&offer_pc, &answer_pc).await; + + Ok(()) +} + +#[tokio::test] +async fn test_peer_connection_dtls_role_setting_engine_server() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + run_test(DTLSRole::Server).await +} + +#[tokio::test] +async fn test_peer_connection_dtls_role_setting_engine_client() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + run_test(DTLSRole::Client).await +} diff --git a/rtc/src/dtls_transport/mod.rs b/rtc/src/dtls_transport/mod.rs new file mode 100644 index 0000000..8d6bcf3 --- /dev/null +++ b/rtc/src/dtls_transport/mod.rs @@ -0,0 +1,616 @@ +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use bytes::Bytes; +use dtls::config::ClientAuthType; +use dtls::conn::DTLSConn; +use dtls::extension::extension_use_srtp::SrtpProtectionProfile; +use dtls_role::*; +use interceptor::stream_info::StreamInfo; +use interceptor::{Interceptor, RTCPReader, RTPReader}; +use sha2::{Digest, Sha256}; +use srtp::protection_profile::ProtectionProfile; +use srtp::session::Session; +use srtp::stream::Stream; +use tokio::sync::{mpsc, Mutex}; +use util::Conn; + +use crate::api::setting_engine::SettingEngine; +use crate::dtls_transport::dtls_parameters::DTLSParameters; +use crate::dtls_transport::dtls_transport_state::RTCDtlsTransportState; +use crate::error::{flatten_errs, Error, Result}; +use crate::ice_transport::ice_role::RTCIceRole; +use crate::ice_transport::ice_transport_state::RTCIceTransportState; +use crate::ice_transport::RTCIceTransport; +use crate::mux::endpoint::Endpoint; +use crate::mux::mux_func::{match_dtls, match_srtcp, match_srtp, MatchFunc}; +use crate::peer_connection::certificate::RTCCertificate; +use crate::rtp_transceiver::SSRC; +use crate::stats::stats_collector::StatsCollector; + +#[cfg(test)] +mod dtls_transport_test; + +pub mod dtls_fingerprint; +pub mod dtls_parameters; +pub mod dtls_role; +pub mod dtls_transport_state; + +pub(crate) fn default_srtp_protection_profiles() -> Vec { + vec![ + SrtpProtectionProfile::Srtp_Aead_Aes_128_Gcm, + SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80, + ] +} + +pub type OnDTLSTransportStateChangeHdlrFn = Box< + dyn (FnMut(RTCDtlsTransportState) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +/// DTLSTransport allows an application access to information about the DTLS +/// transport over which RTP and RTCP packets are sent and received by +/// RTPSender and RTPReceiver, as well other data such as SCTP packets sent +/// and received by data channels. +#[derive(Default)] +pub struct RTCDtlsTransport { + pub(crate) ice_transport: Arc, + pub(crate) certificates: Vec, + pub(crate) setting_engine: Arc, + + pub(crate) remote_parameters: Mutex, + pub(crate) remote_certificate: Mutex, + pub(crate) state: AtomicU8, //DTLSTransportState, + pub(crate) srtp_protection_profile: Mutex, + pub(crate) on_state_change_handler: ArcSwapOption>, + pub(crate) conn: Mutex>>, + + pub(crate) srtp_session: Mutex>>, + pub(crate) srtcp_session: Mutex>>, + pub(crate) srtp_endpoint: Mutex>>, + pub(crate) srtcp_endpoint: Mutex>>, + + pub(crate) simulcast_streams: Mutex>>, + + pub(crate) srtp_ready_signal: Arc, + pub(crate) srtp_ready_tx: Mutex>>, + pub(crate) srtp_ready_rx: Mutex>>, + + pub(crate) dtls_matcher: Option, +} + +impl RTCDtlsTransport { + pub(crate) fn new( + ice_transport: Arc, + certificates: Vec, + setting_engine: Arc, + ) -> Self { + let (srtp_ready_tx, srtp_ready_rx) = mpsc::channel(1); + RTCDtlsTransport { + ice_transport, + certificates, + setting_engine, + srtp_ready_signal: Arc::new(AtomicBool::new(false)), + srtp_ready_tx: Mutex::new(Some(srtp_ready_tx)), + srtp_ready_rx: Mutex::new(Some(srtp_ready_rx)), + state: AtomicU8::new(RTCDtlsTransportState::New as u8), + dtls_matcher: Some(Box::new(match_dtls)), + ..Default::default() + } + } + + pub(crate) async fn conn(&self) -> Option> { + let conn = self.conn.lock().await; + conn.clone() + } + + /// returns the currently-configured ICETransport or None + /// if one has not been configured + pub fn ice_transport(&self) -> &RTCIceTransport { + &self.ice_transport + } + + /// state_change requires the caller holds the lock + async fn state_change(&self, state: RTCDtlsTransportState) { + self.state.store(state as u8, Ordering::SeqCst); + if let Some(handler) = &*self.on_state_change_handler.load() { + let mut f = handler.lock().await; + f(state).await; + } + } + + /// on_state_change sets a handler that is fired when the DTLS + /// connection state changes. + pub fn on_state_change(&self, f: OnDTLSTransportStateChangeHdlrFn) { + self.on_state_change_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// state returns the current dtls_transport transport state. + pub fn state(&self) -> RTCDtlsTransportState { + self.state.load(Ordering::SeqCst).into() + } + + /// write_rtcp sends a user provided RTCP packet to the connected peer. If no peer is connected the + /// packet is discarded. + pub async fn write_rtcp( + &self, + pkts: &[Box], + ) -> Result { + let srtcp_session = self.srtcp_session.lock().await; + if let Some(srtcp_session) = &*srtcp_session { + let raw = rtcp::packet::marshal(pkts)?; + Ok(srtcp_session.write(&raw, false).await?) + } else { + Ok(0) + } + } + + /// get_local_parameters returns the DTLS parameters of the local DTLSTransport upon construction. + pub fn get_local_parameters(&self) -> Result { + let mut fingerprints = vec![]; + + for c in &self.certificates { + fingerprints.extend(c.get_fingerprints()); + } + + Ok(DTLSParameters { + role: DTLSRole::Auto, // always returns the default role + fingerprints, + }) + } + + /// get_remote_certificate returns the certificate chain in use by the remote side + /// returns an empty list prior to selection of the remote certificate + pub async fn get_remote_certificate(&self) -> Bytes { + let remote_certificate = self.remote_certificate.lock().await; + remote_certificate.clone() + } + + pub(crate) async fn start_srtp(&self) -> Result<()> { + let profile = { + let srtp_protection_profile = self.srtp_protection_profile.lock().await; + *srtp_protection_profile + }; + + let mut srtp_config = srtp::config::Config { + profile, + ..Default::default() + }; + + if self.setting_engine.replay_protection.srtp != 0 { + srtp_config.remote_rtp_options = Some(srtp::option::srtp_replay_protection( + self.setting_engine.replay_protection.srtp, + )); + } else if self.setting_engine.disable_srtp_replay_protection { + srtp_config.remote_rtp_options = Some(srtp::option::srtp_no_replay_protection()); + } + + if let Some(conn) = self.conn().await { + let conn_state = conn.connection_state().await; + srtp_config + .extract_session_keys_from_dtls(conn_state, self.role().await == DTLSRole::Client) + .await?; + } else { + return Err(Error::ErrDtlsTransportNotStarted); + } + + { + let mut srtp_session = self.srtp_session.lock().await; + *srtp_session = { + let se = self.srtp_endpoint.lock().await; + if let Some(srtp_endpoint) = &*se { + Some(Arc::new( + Session::new( + Arc::clone(srtp_endpoint) as Arc, + srtp_config, + true, + ) + .await?, + )) + } else { + None + } + }; + } + + let mut srtcp_config = srtp::config::Config { + profile, + ..Default::default() + }; + if self.setting_engine.replay_protection.srtcp != 0 { + srtcp_config.remote_rtcp_options = Some(srtp::option::srtcp_replay_protection( + self.setting_engine.replay_protection.srtcp, + )); + } else if self.setting_engine.disable_srtcp_replay_protection { + srtcp_config.remote_rtcp_options = Some(srtp::option::srtcp_no_replay_protection()); + } + + if let Some(conn) = self.conn().await { + let conn_state = conn.connection_state().await; + srtcp_config + .extract_session_keys_from_dtls(conn_state, self.role().await == DTLSRole::Client) + .await?; + } else { + return Err(Error::ErrDtlsTransportNotStarted); + } + + { + let mut srtcp_session = self.srtcp_session.lock().await; + *srtcp_session = { + let se = self.srtcp_endpoint.lock().await; + if let Some(srtcp_endpoint) = &*se { + Some(Arc::new( + Session::new( + Arc::clone(srtcp_endpoint) as Arc, + srtcp_config, + false, + ) + .await?, + )) + } else { + None + } + }; + } + + { + let mut srtp_ready_tx = self.srtp_ready_tx.lock().await; + srtp_ready_tx.take(); + if srtp_ready_tx.is_none() { + self.srtp_ready_signal.store(true, Ordering::SeqCst); + } + } + + Ok(()) + } + + pub(crate) async fn get_srtp_session(&self) -> Option> { + let srtp_session = self.srtp_session.lock().await; + srtp_session.clone() + } + + pub(crate) async fn get_srtcp_session(&self) -> Option> { + let srtcp_session = self.srtcp_session.lock().await; + srtcp_session.clone() + } + + pub(crate) async fn role(&self) -> DTLSRole { + // If remote has an explicit role use the inverse + { + let remote_parameters = self.remote_parameters.lock().await; + match remote_parameters.role { + DTLSRole::Client => return DTLSRole::Server, + DTLSRole::Server => return DTLSRole::Client, + _ => {} + }; + } + + // If SettingEngine has an explicit role + match self.setting_engine.answering_dtls_role { + DTLSRole::Server => return DTLSRole::Server, + DTLSRole::Client => return DTLSRole::Client, + _ => {} + }; + + // Remote was auto and no explicit role was configured via SettingEngine + if self.ice_transport.role().await == RTCIceRole::Controlling { + return DTLSRole::Server; + } + + DEFAULT_DTLS_ROLE_ANSWER + } + + pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { + for cert in &self.certificates { + cert.collect_stats(collector).await; + } + } + + async fn prepare_transport( + &self, + remote_parameters: DTLSParameters, + ) -> Result<(DTLSRole, dtls::config::Config)> { + self.ensure_ice_conn()?; + + if self.state() != RTCDtlsTransportState::New { + return Err(Error::ErrInvalidDTLSStart); + } + + { + let mut srtp_endpoint = self.srtp_endpoint.lock().await; + *srtp_endpoint = self.ice_transport.new_endpoint(Box::new(match_srtp)).await; + } + { + let mut srtcp_endpoint = self.srtcp_endpoint.lock().await; + *srtcp_endpoint = self.ice_transport.new_endpoint(Box::new(match_srtcp)).await; + } + { + let mut rp = self.remote_parameters.lock().await; + *rp = remote_parameters; + } + + let certificate = if let Some(cert) = self.certificates.first() { + cert.dtls_certificate.clone() + } else { + return Err(Error::ErrNonCertificate); + }; + self.state_change(RTCDtlsTransportState::Connecting).await; + + Ok(( + self.role().await, + dtls::config::Config { + certificates: vec![certificate], + srtp_protection_profiles: if !self + .setting_engine + .srtp_protection_profiles + .is_empty() + { + self.setting_engine.srtp_protection_profiles.clone() + } else { + default_srtp_protection_profiles() + }, + client_auth: ClientAuthType::RequireAnyClientCert, + insecure_skip_verify: true, + insecure_verification: self.setting_engine.allow_insecure_verification_algorithm, + ..Default::default() + }, + )) + } + + /// start DTLS transport negotiation with the parameters of the remote DTLS transport + pub async fn start(&self, remote_parameters: DTLSParameters) -> Result<()> { + let dtls_conn_result = if let Some(dtls_endpoint) = + self.ice_transport.new_endpoint(Box::new(match_dtls)).await + { + let (role, mut dtls_config) = self.prepare_transport(remote_parameters).await?; + if self.setting_engine.replay_protection.dtls != 0 { + dtls_config.replay_protection_window = self.setting_engine.replay_protection.dtls; + } + + // Connect as DTLS Client/Server, function is blocking and we + // must not hold the DTLSTransport lock + if role == DTLSRole::Client { + dtls::conn::DTLSConn::new( + dtls_endpoint as Arc, + dtls_config, + true, + None, + ) + .await + } else { + dtls::conn::DTLSConn::new( + dtls_endpoint as Arc, + dtls_config, + false, + None, + ) + .await + } + } else { + Err(dtls::Error::Other( + "ice_transport.new_endpoint failed".to_owned(), + )) + }; + + let dtls_conn = match dtls_conn_result { + Ok(dtls_conn) => dtls_conn, + Err(err) => { + self.state_change(RTCDtlsTransportState::Failed).await; + return Err(err.into()); + } + }; + + let srtp_profile = dtls_conn.selected_srtpprotection_profile(); + { + let mut srtp_protection_profile = self.srtp_protection_profile.lock().await; + *srtp_protection_profile = match srtp_profile { + dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aead_Aes_128_Gcm => { + srtp::protection_profile::ProtectionProfile::AeadAes128Gcm + } + dtls::extension::extension_use_srtp::SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80 => { + srtp::protection_profile::ProtectionProfile::Aes128CmHmacSha1_80 + } + _ => { + if let Err(err) = dtls_conn.close().await { + log::error!("{}", err); + } + + self.state_change(RTCDtlsTransportState::Failed).await; + return Err(Error::ErrNoSRTPProtectionProfile); + } + }; + } + + // Check the fingerprint if a certificate was exchanged + let remote_certs = &dtls_conn.connection_state().await.peer_certificates; + if remote_certs.is_empty() { + if let Err(err) = dtls_conn.close().await { + log::error!("{}", err); + } + + self.state_change(RTCDtlsTransportState::Failed).await; + return Err(Error::ErrNoRemoteCertificate); + } + + { + let mut remote_certificate = self.remote_certificate.lock().await; + *remote_certificate = Bytes::from(remote_certs[0].clone()); + } + + if !self + .setting_engine + .disable_certificate_fingerprint_verification + { + if let Err(err) = self.validate_fingerprint(&remote_certs[0]).await { + if let Err(close_err) = dtls_conn.close().await { + log::error!("{}", close_err); + } + + self.state_change(RTCDtlsTransportState::Failed).await; + return Err(err); + } + } + + { + let mut conn = self.conn.lock().await; + *conn = Some(Arc::new(dtls_conn)); + } + self.state_change(RTCDtlsTransportState::Connected).await; + + self.start_srtp().await + } + + /// stops and closes the DTLSTransport object. + pub async fn stop(&self) -> Result<()> { + // Try closing everything and collect the errors + let mut close_errs: Vec = vec![]; + { + let srtp_session = { + let mut srtp_session = self.srtp_session.lock().await; + srtp_session.take() + }; + if let Some(srtp_session) = srtp_session { + match srtp_session.close().await { + Ok(_) => {} + Err(err) => { + close_errs.push(err.into()); + } + }; + } + } + + { + let srtcp_session = { + let mut srtcp_session = self.srtcp_session.lock().await; + srtcp_session.take() + }; + if let Some(srtcp_session) = srtcp_session { + match srtcp_session.close().await { + Ok(_) => {} + Err(err) => { + close_errs.push(err.into()); + } + }; + } + } + + { + let simulcast_streams: Vec> = { + let mut simulcast_streams = self.simulcast_streams.lock().await; + simulcast_streams.drain().map(|(_, v)| v).collect() + }; + for ss in simulcast_streams { + match ss.close().await { + Ok(_) => {} + Err(err) => { + close_errs.push(Error::new(format!( + "simulcast_streams ssrc={}: {}", + ss.get_ssrc(), + err + ))); + } + }; + } + } + + if let Some(conn) = self.conn().await { + // dtls_transport connection may be closed on sctp close. + match conn.close().await { + Ok(_) => {} + Err(err) => { + if err.to_string() != dtls::Error::ErrConnClosed.to_string() { + close_errs.push(err.into()); + } + } + } + } + + self.state_change(RTCDtlsTransportState::Closed).await; + + flatten_errs(close_errs) + } + + pub(crate) async fn validate_fingerprint(&self, remote_cert: &[u8]) -> Result<()> { + let remote_parameters = self.remote_parameters.lock().await; + for fp in &remote_parameters.fingerprints { + if fp.algorithm != "sha-256" { + return Err(Error::ErrUnsupportedFingerprintAlgorithm); + } + + let mut h = Sha256::new(); + h.update(remote_cert); + let hashed = h.finalize(); + let values: Vec = hashed.iter().map(|x| format! {"{x:02x}"}).collect(); + let remote_value = values.join(":").to_lowercase(); + + if remote_value == fp.value.to_lowercase() { + return Ok(()); + } + } + + Err(Error::ErrNoMatchingCertificateFingerprint) + } + + pub(crate) fn ensure_ice_conn(&self) -> Result<()> { + if self.ice_transport.state() == RTCIceTransportState::New { + Err(Error::ErrICEConnectionNotStarted) + } else { + Ok(()) + } + } + + pub(crate) async fn store_simulcast_stream(&self, ssrc: SSRC, stream: Arc) { + let mut simulcast_streams = self.simulcast_streams.lock().await; + simulcast_streams.insert(ssrc, stream); + } + + pub(crate) async fn remove_simulcast_stream(&self, ssrc: SSRC) { + let mut simulcast_streams = self.simulcast_streams.lock().await; + simulcast_streams.remove(&ssrc); + } + + pub(crate) async fn streams_for_ssrc( + &self, + ssrc: SSRC, + stream_info: &StreamInfo, + interceptor: &Arc, + ) -> Result<( + Arc, + Arc, + Arc, + Arc, + )> { + let srtp_session = self + .get_srtp_session() + .await + .ok_or(Error::ErrDtlsTransportNotStarted)?; + //log::debug!("streams_for_ssrc: srtp_session.listen ssrc={}", ssrc); + let rtp_read_stream = srtp_session.open(ssrc).await; + let rtp_stream_reader = Arc::clone(&rtp_read_stream) as Arc; + let rtp_interceptor = interceptor + .bind_remote_stream(stream_info, rtp_stream_reader) + .await; + + let srtcp_session = self + .get_srtcp_session() + .await + .ok_or(Error::ErrDtlsTransportNotStarted)?; + //log::debug!("streams_for_ssrc: srtcp_session.listen ssrc={}", ssrc); + let rtcp_read_stream = srtcp_session.open(ssrc).await; + let rtcp_stream_reader = Arc::clone(&rtcp_read_stream) as Arc; + let rtcp_interceptor = interceptor.bind_rtcp_reader(rtcp_stream_reader).await; + + Ok(( + rtp_read_stream, + rtp_interceptor, + rtcp_read_stream, + rtcp_interceptor, + )) + } +} diff --git a/rtc/src/ice_transport/ice_candidate.rs b/rtc/src/ice_transport/ice_candidate.rs new file mode 100644 index 0000000..d9cb74b --- /dev/null +++ b/rtc/src/ice_transport/ice_candidate.rs @@ -0,0 +1,219 @@ +use std::fmt; +use std::sync::Arc; + +use ice::candidate::candidate_base::CandidateBaseConfig; +use ice::candidate::candidate_host::CandidateHostConfig; +use ice::candidate::candidate_peer_reflexive::CandidatePeerReflexiveConfig; +use ice::candidate::candidate_relay::CandidateRelayConfig; +use ice::candidate::candidate_server_reflexive::CandidateServerReflexiveConfig; +use ice::candidate::Candidate; +use serde::{Deserialize, Serialize}; + +use crate::error::{Error, Result}; +use crate::ice_transport::ice_candidate_type::RTCIceCandidateType; +use crate::ice_transport::ice_protocol::RTCIceProtocol; + +/// ICECandidate represents a ice candidate +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RTCIceCandidate { + pub stats_id: String, + pub foundation: String, + pub priority: u32, + pub address: String, + pub protocol: RTCIceProtocol, + pub port: u16, + pub typ: RTCIceCandidateType, + pub component: u16, + pub related_address: String, + pub related_port: u16, + pub tcp_type: String, +} + +/// Conversion for ice_candidates +pub(crate) fn rtc_ice_candidates_from_ice_candidates( + ice_candidates: &[Arc], +) -> Vec { + ice_candidates.iter().map(|c| c.into()).collect() +} + +impl From<&Arc> for RTCIceCandidate { + fn from(c: &Arc) -> Self { + let typ: RTCIceCandidateType = c.candidate_type().into(); + let protocol = RTCIceProtocol::from(c.network_type().network_short().as_str()); + let (related_address, related_port) = if let Some(ra) = c.related_address() { + (ra.address, ra.port) + } else { + (String::new(), 0) + }; + + RTCIceCandidate { + stats_id: c.id(), + foundation: c.foundation(), + priority: c.priority(), + address: c.address(), + protocol, + port: c.port(), + component: c.component(), + typ, + tcp_type: c.tcp_type().to_string(), + related_address, + related_port, + } + } +} + +impl RTCIceCandidate { + pub(crate) fn to_ice(&self) -> Result { + let candidate_id = self.stats_id.clone(); + let c = match self.typ { + RTCIceCandidateType::Host => { + let config = CandidateHostConfig { + base_config: CandidateBaseConfig { + candidate_id, + network: self.protocol.to_string(), + address: self.address.clone(), + port: self.port, + component: self.component, + //tcp_type: ice.NewTCPType(c.TCPType), + foundation: self.foundation.clone(), + priority: self.priority, + ..Default::default() + }, + ..Default::default() + }; + config.new_candidate_host()? + } + RTCIceCandidateType::Srflx => { + let config = CandidateServerReflexiveConfig { + base_config: CandidateBaseConfig { + candidate_id, + network: self.protocol.to_string(), + address: self.address.clone(), + port: self.port, + component: self.component, + foundation: self.foundation.clone(), + priority: self.priority, + ..Default::default() + }, + rel_addr: self.related_address.clone(), + rel_port: self.related_port, + }; + config.new_candidate_server_reflexive()? + } + RTCIceCandidateType::Prflx => { + let config = CandidatePeerReflexiveConfig { + base_config: CandidateBaseConfig { + candidate_id, + network: self.protocol.to_string(), + address: self.address.clone(), + port: self.port, + component: self.component, + foundation: self.foundation.clone(), + priority: self.priority, + ..Default::default() + }, + rel_addr: self.related_address.clone(), + rel_port: self.related_port, + }; + config.new_candidate_peer_reflexive()? + } + RTCIceCandidateType::Relay => { + let config = CandidateRelayConfig { + base_config: CandidateBaseConfig { + candidate_id, + network: self.protocol.to_string(), + address: self.address.clone(), + port: self.port, + component: self.component, + foundation: self.foundation.clone(), + priority: self.priority, + ..Default::default() + }, + rel_addr: self.related_address.clone(), + rel_port: self.related_port, + relay_client: None, //TODO? + }; + config.new_candidate_relay()? + } + _ => return Err(Error::ErrICECandidateTypeUnknown), + }; + + Ok(c) + } + + /// to_json returns an ICECandidateInit + /// as indicated by the spec + pub fn to_json(&self) -> Result { + let candidate = self.to_ice()?; + + Ok(RTCIceCandidateInit { + candidate: format!("candidate:{}", candidate.marshal()), + sdp_mid: Some("".to_owned()), + sdp_mline_index: Some(0u16), + username_fragment: None, + }) + } +} + +impl fmt::Display for RTCIceCandidate { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} {} {}:{}{}", + self.protocol, self.typ, self.address, self.port, self.related_address, + ) + } +} + +/// ICECandidateInit is used to serialize ice candidates +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RTCIceCandidateInit { + pub candidate: String, + pub sdp_mid: Option, + #[serde(rename = "sdpMLineIndex")] + pub sdp_mline_index: Option, + pub username_fragment: Option, +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ice_candidate_serialization() { + let tests = vec![ + ( + RTCIceCandidateInit { + candidate: "candidate:abc123".to_string(), + sdp_mid: Some("0".to_string()), + sdp_mline_index: Some(0), + username_fragment: Some("def".to_string()), + }, + r#"{"candidate":"candidate:abc123","sdpMid":"0","sdpMLineIndex":0,"usernameFragment":"def"}"#, + ), + ( + RTCIceCandidateInit { + candidate: "candidate:abc123".to_string(), + sdp_mid: None, + sdp_mline_index: None, + username_fragment: None, + }, + r#"{"candidate":"candidate:abc123","sdpMid":null,"sdpMLineIndex":null,"usernameFragment":null}"#, + ), + ]; + + for (candidate_init, expected_string) in tests { + let result = serde_json::to_string(&candidate_init); + assert!(result.is_ok(), "testCase: marshal err: {result:?}"); + let candidate_data = result.unwrap(); + assert_eq!(candidate_data, expected_string, "string is not expected"); + + let result = serde_json::from_str::(&candidate_data); + assert!(result.is_ok(), "testCase: unmarshal err: {result:?}"); + if let Ok(actual_candidate_init) = result { + assert_eq!(actual_candidate_init, candidate_init); + } + } + } +} diff --git a/rtc/src/ice_transport/ice_candidate_pair.rs b/rtc/src/ice_transport/ice_candidate_pair.rs new file mode 100644 index 0000000..8c0684c --- /dev/null +++ b/rtc/src/ice_transport/ice_candidate_pair.rs @@ -0,0 +1,34 @@ +use std::fmt; + +use crate::ice_transport::ice_candidate::*; + +/// ICECandidatePair represents an ICE Candidate pair +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct RTCIceCandidatePair { + stats_id: String, + local: RTCIceCandidate, + remote: RTCIceCandidate, +} + +impl fmt::Display for RTCIceCandidatePair { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(local) {} <-> (remote) {}", self.local, self.remote) + } +} + +impl RTCIceCandidatePair { + fn stats_id(local_id: &str, remote_id: &str) -> String { + format!("{local_id}-{remote_id}") + } + + /// returns an initialized ICECandidatePair + /// for the given pair of ICECandidate instances + pub fn new(local: RTCIceCandidate, remote: RTCIceCandidate) -> Self { + let stats_id = Self::stats_id(&local.stats_id, &remote.stats_id); + RTCIceCandidatePair { + stats_id, + local, + remote, + } + } +} diff --git a/rtc/src/ice_transport/ice_candidate_type.rs b/rtc/src/ice_transport/ice_candidate_type.rs new file mode 100644 index 0000000..ec44328 --- /dev/null +++ b/rtc/src/ice_transport/ice_candidate_type.rs @@ -0,0 +1,119 @@ +use std::fmt; + +use ice::candidate::CandidateType; +use serde::{Deserialize, Serialize}; + +/// ICECandidateType represents the type of the ICE candidate used. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum RTCIceCandidateType { + #[default] + Unspecified, + + /// ICECandidateTypeHost indicates that the candidate is of Host type as + /// described in . A + /// candidate obtained by binding to a specific port from an IP address on + /// the host. This includes IP addresses on physical interfaces and logical + /// ones, such as ones obtained through VPNs. + #[serde(rename = "host")] + Host, + + /// ICECandidateTypeSrflx indicates the the candidate is of Server + /// Reflexive type as described + /// . A candidate type + /// whose IP address and port are a binding allocated by a NAT for an ICE + /// agent after it sends a packet through the NAT to a server, such as a + /// STUN server. + #[serde(rename = "srflx")] + Srflx, + + /// ICECandidateTypePrflx indicates that the candidate is of Peer + /// Reflexive type. A candidate type whose IP address and port are a binding + /// allocated by a NAT for an ICE agent after it sends a packet through the + /// NAT to its peer. + #[serde(rename = "prflx")] + Prflx, + + /// ICECandidateTypeRelay indicates the the candidate is of Relay type as + /// described in . A + /// candidate type obtained from a relay server, such as a TURN server. + #[serde(rename = "relay")] + Relay, +} + +const ICE_CANDIDATE_TYPE_HOST_STR: &str = "host"; +const ICE_CANDIDATE_TYPE_SRFLX_STR: &str = "srflx"; +const ICE_CANDIDATE_TYPE_PRFLX_STR: &str = "prflx"; +const ICE_CANDIDATE_TYPE_RELAY_STR: &str = "relay"; + +/// takes a string and converts it into ICECandidateType +impl From<&str> for RTCIceCandidateType { + fn from(raw: &str) -> Self { + match raw { + ICE_CANDIDATE_TYPE_HOST_STR => RTCIceCandidateType::Host, + ICE_CANDIDATE_TYPE_SRFLX_STR => RTCIceCandidateType::Srflx, + ICE_CANDIDATE_TYPE_PRFLX_STR => RTCIceCandidateType::Prflx, + ICE_CANDIDATE_TYPE_RELAY_STR => RTCIceCandidateType::Relay, + _ => RTCIceCandidateType::Unspecified, + } + } +} + +impl From for RTCIceCandidateType { + fn from(candidate_type: CandidateType) -> Self { + match candidate_type { + CandidateType::Host => RTCIceCandidateType::Host, + CandidateType::ServerReflexive => RTCIceCandidateType::Srflx, + CandidateType::PeerReflexive => RTCIceCandidateType::Prflx, + CandidateType::Relay => RTCIceCandidateType::Relay, + _ => RTCIceCandidateType::Unspecified, + } + } +} + +impl fmt::Display for RTCIceCandidateType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCIceCandidateType::Host => write!(f, "{ICE_CANDIDATE_TYPE_HOST_STR}"), + RTCIceCandidateType::Srflx => write!(f, "{ICE_CANDIDATE_TYPE_SRFLX_STR}"), + RTCIceCandidateType::Prflx => write!(f, "{ICE_CANDIDATE_TYPE_PRFLX_STR}"), + RTCIceCandidateType::Relay => write!(f, "{ICE_CANDIDATE_TYPE_RELAY_STR}"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ice_candidate_type() { + let tests = vec![ + ("Unspecified", RTCIceCandidateType::Unspecified), + ("host", RTCIceCandidateType::Host), + ("srflx", RTCIceCandidateType::Srflx), + ("prflx", RTCIceCandidateType::Prflx), + ("relay", RTCIceCandidateType::Relay), + ]; + + for (type_string, expected_type) in tests { + let actual = RTCIceCandidateType::from(type_string); + assert_eq!(actual, expected_type); + } + } + + #[test] + fn test_ice_candidate_type_string() { + let tests = vec![ + (RTCIceCandidateType::Unspecified, "Unspecified"), + (RTCIceCandidateType::Host, "host"), + (RTCIceCandidateType::Srflx, "srflx"), + (RTCIceCandidateType::Prflx, "prflx"), + (RTCIceCandidateType::Relay, "relay"), + ]; + + for (ctype, expected_string) in tests { + assert_eq!(ctype.to_string(), expected_string); + } + } +} diff --git a/rtc/src/ice_transport/ice_connection_state.rs b/rtc/src/ice_transport/ice_connection_state.rs new file mode 100644 index 0000000..d019854 --- /dev/null +++ b/rtc/src/ice_transport/ice_connection_state.rs @@ -0,0 +1,142 @@ +use std::fmt; + +/// RTCIceConnectionState indicates signaling state of the ICE Connection. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCIceConnectionState { + #[default] + Unspecified, + + /// ICEConnectionStateNew indicates that any of the ICETransports are + /// in the "new" state and none of them are in the "checking", "disconnected" + /// or "failed" state, or all ICETransports are in the "closed" state, or + /// there are no transports. + New, + + /// ICEConnectionStateChecking indicates that any of the ICETransports + /// are in the "checking" state and none of them are in the "disconnected" + /// or "failed" state. + Checking, + + /// ICEConnectionStateConnected indicates that all ICETransports are + /// in the "connected", "completed" or "closed" state and at least one of + /// them is in the "connected" state. + Connected, + + /// ICEConnectionStateCompleted indicates that all ICETransports are + /// in the "completed" or "closed" state and at least one of them is in the + /// "completed" state. + Completed, + + /// ICEConnectionStateDisconnected indicates that any of the + /// ICETransports are in the "disconnected" state and none of them are + /// in the "failed" state. + Disconnected, + + /// ICEConnectionStateFailed indicates that any of the ICETransports + /// are in the "failed" state. + Failed, + + /// ICEConnectionStateClosed indicates that the PeerConnection's + /// isClosed is true. + Closed, +} + +const ICE_CONNECTION_STATE_NEW_STR: &str = "new"; +const ICE_CONNECTION_STATE_CHECKING_STR: &str = "checking"; +const ICE_CONNECTION_STATE_CONNECTED_STR: &str = "connected"; +const ICE_CONNECTION_STATE_COMPLETED_STR: &str = "completed"; +const ICE_CONNECTION_STATE_DISCONNECTED_STR: &str = "disconnected"; +const ICE_CONNECTION_STATE_FAILED_STR: &str = "failed"; +const ICE_CONNECTION_STATE_CLOSED_STR: &str = "closed"; + +/// takes a string and converts it to iceconnection_state +impl From<&str> for RTCIceConnectionState { + fn from(raw: &str) -> Self { + match raw { + ICE_CONNECTION_STATE_NEW_STR => RTCIceConnectionState::New, + ICE_CONNECTION_STATE_CHECKING_STR => RTCIceConnectionState::Checking, + ICE_CONNECTION_STATE_CONNECTED_STR => RTCIceConnectionState::Connected, + ICE_CONNECTION_STATE_COMPLETED_STR => RTCIceConnectionState::Completed, + ICE_CONNECTION_STATE_DISCONNECTED_STR => RTCIceConnectionState::Disconnected, + ICE_CONNECTION_STATE_FAILED_STR => RTCIceConnectionState::Failed, + ICE_CONNECTION_STATE_CLOSED_STR => RTCIceConnectionState::Closed, + _ => RTCIceConnectionState::Unspecified, + } + } +} + +impl From for RTCIceConnectionState { + fn from(v: u8) -> Self { + match v { + 1 => RTCIceConnectionState::New, + 2 => RTCIceConnectionState::Checking, + 3 => RTCIceConnectionState::Connected, + 4 => RTCIceConnectionState::Completed, + 5 => RTCIceConnectionState::Disconnected, + 6 => RTCIceConnectionState::Failed, + 7 => RTCIceConnectionState::Closed, + _ => RTCIceConnectionState::Unspecified, + } + } +} + +impl fmt::Display for RTCIceConnectionState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCIceConnectionState::New => ICE_CONNECTION_STATE_NEW_STR, + RTCIceConnectionState::Checking => ICE_CONNECTION_STATE_CHECKING_STR, + RTCIceConnectionState::Connected => ICE_CONNECTION_STATE_CONNECTED_STR, + RTCIceConnectionState::Completed => ICE_CONNECTION_STATE_COMPLETED_STR, + RTCIceConnectionState::Disconnected => ICE_CONNECTION_STATE_DISCONNECTED_STR, + RTCIceConnectionState::Failed => ICE_CONNECTION_STATE_FAILED_STR, + RTCIceConnectionState::Closed => ICE_CONNECTION_STATE_CLOSED_STR, + RTCIceConnectionState::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_ice_connection_state() { + let tests = vec![ + (crate::UNSPECIFIED_STR, RTCIceConnectionState::Unspecified), + ("new", RTCIceConnectionState::New), + ("checking", RTCIceConnectionState::Checking), + ("connected", RTCIceConnectionState::Connected), + ("completed", RTCIceConnectionState::Completed), + ("disconnected", RTCIceConnectionState::Disconnected), + ("failed", RTCIceConnectionState::Failed), + ("closed", RTCIceConnectionState::Closed), + ]; + + for (state_string, expected_state) in tests { + assert_eq!( + RTCIceConnectionState::from(state_string), + expected_state, + "testCase: {expected_state}", + ); + } + } + + #[test] + fn test_ice_connection_state_string() { + let tests = vec![ + (RTCIceConnectionState::Unspecified, crate::UNSPECIFIED_STR), + (RTCIceConnectionState::New, "new"), + (RTCIceConnectionState::Checking, "checking"), + (RTCIceConnectionState::Connected, "connected"), + (RTCIceConnectionState::Completed, "completed"), + (RTCIceConnectionState::Disconnected, "disconnected"), + (RTCIceConnectionState::Failed, "failed"), + (RTCIceConnectionState::Closed, "closed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string) + } + } +} diff --git a/rtc/src/ice_transport/ice_credential_type.rs b/rtc/src/ice_transport/ice_credential_type.rs new file mode 100644 index 0000000..5db6c6c --- /dev/null +++ b/rtc/src/ice_transport/ice_credential_type.rs @@ -0,0 +1,74 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// ICECredentialType indicates the type of credentials used to connect to +/// an ICE server. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] +pub enum RTCIceCredentialType { + #[default] + Unspecified, + + /// ICECredential::Password describes username and password based + /// credentials as described in . + Password, + + /// ICECredential::Oauth describes token based credential as described + /// in . + /// Not supported in WebRTC 1.0 spec + Oauth, +} + +const ICE_CREDENTIAL_TYPE_PASSWORD_STR: &str = "password"; +const ICE_CREDENTIAL_TYPE_OAUTH_STR: &str = "oauth"; + +impl From<&str> for RTCIceCredentialType { + fn from(raw: &str) -> Self { + match raw { + ICE_CREDENTIAL_TYPE_PASSWORD_STR => RTCIceCredentialType::Password, + ICE_CREDENTIAL_TYPE_OAUTH_STR => RTCIceCredentialType::Oauth, + _ => RTCIceCredentialType::Unspecified, + } + } +} + +impl fmt::Display for RTCIceCredentialType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCIceCredentialType::Password => write!(f, "{ICE_CREDENTIAL_TYPE_PASSWORD_STR}"), + RTCIceCredentialType::Oauth => write!(f, "{ICE_CREDENTIAL_TYPE_OAUTH_STR}"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_ice_credential_type() { + let tests = vec![ + ("Unspecified", RTCIceCredentialType::Unspecified), + ("password", RTCIceCredentialType::Password), + ("oauth", RTCIceCredentialType::Oauth), + ]; + + for (ct_str, expected_ct) in tests { + assert_eq!(RTCIceCredentialType::from(ct_str), expected_ct); + } + } + + #[test] + fn test_ice_credential_type_string() { + let tests = vec![ + (RTCIceCredentialType::Unspecified, "Unspecified"), + (RTCIceCredentialType::Password, "password"), + (RTCIceCredentialType::Oauth, "oauth"), + ]; + + for (ct, expected_string) in tests { + assert_eq!(ct.to_string(), expected_string); + } + } +} diff --git a/rtc/src/ice_transport/ice_gatherer.rs b/rtc/src/ice_transport/ice_gatherer.rs new file mode 100644 index 0000000..eda930b --- /dev/null +++ b/rtc/src/ice_transport/ice_gatherer.rs @@ -0,0 +1,409 @@ +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use ice::agent::Agent; +use ice::candidate::{Candidate, CandidateType}; +use ice::url::Url; +use tokio::sync::Mutex; + +use crate::api::setting_engine::SettingEngine; +use crate::error::{Error, Result}; +use crate::ice_transport::ice_candidate::*; +use crate::ice_transport::ice_candidate_type::RTCIceCandidateType; +use crate::ice_transport::ice_gatherer_state::RTCIceGathererState; +use crate::ice_transport::ice_parameters::RTCIceParameters; +use crate::ice_transport::ice_server::RTCIceServer; +use crate::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy; +use crate::stats::stats_collector::StatsCollector; +use crate::stats::SourceStatsType::*; +use crate::stats::{ICECandidatePairStats, StatsReportType}; + +/// ICEGatherOptions provides options relating to the gathering of ICE candidates. +#[derive(Default, Debug, Clone)] +pub struct RTCIceGatherOptions { + pub ice_servers: Vec, + pub ice_gather_policy: RTCIceTransportPolicy, +} + +pub type OnLocalCandidateHdlrFn = Box< + dyn (FnMut(Option) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnICEGathererStateChangeHdlrFn = Box< + dyn (FnMut(RTCIceGathererState) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnGatheringCompleteHdlrFn = + Box Pin + Send + 'static>>) + Send + Sync>; + +/// ICEGatherer gathers local host, server reflexive and relay +/// candidates, as well as enabling the retrieval of local Interactive +/// Connectivity Establishment (ICE) parameters which can be +/// exchanged in signaling. +#[derive(Default)] +pub struct RTCIceGatherer { + pub(crate) validated_servers: Vec, + pub(crate) gather_policy: RTCIceTransportPolicy, + pub(crate) setting_engine: Arc, + + pub(crate) state: Arc, //ICEGathererState, + pub(crate) agent: Mutex>>, + + pub(crate) on_local_candidate_handler: Arc>>, + pub(crate) on_state_change_handler: Arc>>, + + // Used for gathering_complete_promise + pub(crate) on_gathering_complete_handler: Arc>>, +} + +impl RTCIceGatherer { + pub(crate) fn new( + validated_servers: Vec, + gather_policy: RTCIceTransportPolicy, + setting_engine: Arc, + ) -> Self { + RTCIceGatherer { + gather_policy, + validated_servers, + setting_engine, + state: Arc::new(AtomicU8::new(RTCIceGathererState::New as u8)), + ..Default::default() + } + } + + pub(crate) async fn create_agent(&self) -> Result<()> { + // NOTE: A lock is held for the duration of this function in order to + // avoid potential double-agent creations. Care should be taken to + // ensure we do not do anything expensive other than the actual agent + // creation in this function. + let mut agent = self.agent.lock().await; + + if agent.is_some() || self.state() != RTCIceGathererState::New { + return Ok(()); + } + + let mut candidate_types = vec![]; + if self.setting_engine.candidates.ice_lite { + candidate_types.push(ice::candidate::CandidateType::Host); + } else if self.gather_policy == RTCIceTransportPolicy::Relay { + candidate_types.push(ice::candidate::CandidateType::Relay); + } + + let nat_1to1_cand_type = match self.setting_engine.candidates.nat_1to1_ip_candidate_type { + RTCIceCandidateType::Host => CandidateType::Host, + RTCIceCandidateType::Srflx => CandidateType::ServerReflexive, + _ => CandidateType::Unspecified, + }; + + let mdns_mode = self.setting_engine.candidates.multicast_dns_mode; + + let mut config = ice::agent::agent_config::AgentConfig { + udp_network: self.setting_engine.udp_network.clone(), + lite: self.setting_engine.candidates.ice_lite, + urls: self.validated_servers.clone(), + disconnected_timeout: self.setting_engine.timeout.ice_disconnected_timeout, + failed_timeout: self.setting_engine.timeout.ice_failed_timeout, + keepalive_interval: self.setting_engine.timeout.ice_keepalive_interval, + candidate_types, + host_acceptance_min_wait: self.setting_engine.timeout.ice_host_acceptance_min_wait, + srflx_acceptance_min_wait: self.setting_engine.timeout.ice_srflx_acceptance_min_wait, + prflx_acceptance_min_wait: self.setting_engine.timeout.ice_prflx_acceptance_min_wait, + relay_acceptance_min_wait: self.setting_engine.timeout.ice_relay_acceptance_min_wait, + interface_filter: self.setting_engine.candidates.interface_filter.clone(), + ip_filter: self.setting_engine.candidates.ip_filter.clone(), + nat_1to1_ips: self.setting_engine.candidates.nat_1to1_ips.clone(), + nat_1to1_ip_candidate_type: nat_1to1_cand_type, + net: self.setting_engine.vnet.clone(), + multicast_dns_mode: mdns_mode, + multicast_dns_host_name: self + .setting_engine + .candidates + .multicast_dns_host_name + .clone(), + local_ufrag: self.setting_engine.candidates.username_fragment.clone(), + local_pwd: self.setting_engine.candidates.password.clone(), + //TODO: TCPMux: self.setting_engine.iceTCPMux, + //TODO: ProxyDialer: self.setting_engine.iceProxyDialer, + ..Default::default() + }; + + let requested_network_types = if self.setting_engine.candidates.ice_network_types.is_empty() + { + ice::network_type::supported_network_types() + } else { + self.setting_engine.candidates.ice_network_types.clone() + }; + + config.network_types.extend(requested_network_types); + + *agent = Some(Arc::new(ice::agent::Agent::new(config).await?)); + + Ok(()) + } + + /// Gather ICE candidates. + pub async fn gather(&self) -> Result<()> { + self.create_agent().await?; + self.set_state(RTCIceGathererState::Gathering).await; + + if let Some(agent) = self.get_agent().await { + let state = Arc::clone(&self.state); + let on_local_candidate_handler = Arc::clone(&self.on_local_candidate_handler); + let on_state_change_handler = Arc::clone(&self.on_state_change_handler); + let on_gathering_complete_handler = Arc::clone(&self.on_gathering_complete_handler); + + agent.on_candidate(Box::new( + move |candidate: Option>| { + let state_clone = Arc::clone(&state); + let on_local_candidate_handler_clone = Arc::clone(&on_local_candidate_handler); + let on_state_change_handler_clone = Arc::clone(&on_state_change_handler); + let on_gathering_complete_handler_clone = + Arc::clone(&on_gathering_complete_handler); + + Box::pin(async move { + if let Some(cand) = candidate { + if let Some(handler) = &*on_local_candidate_handler_clone.load() { + let mut f = handler.lock().await; + f(Some(RTCIceCandidate::from(&cand))).await; + } + } else { + state_clone + .store(RTCIceGathererState::Complete as u8, Ordering::SeqCst); + + if let Some(handler) = &*on_state_change_handler_clone.load() { + let mut f = handler.lock().await; + f(RTCIceGathererState::Complete).await; + } + + if let Some(handler) = &*on_gathering_complete_handler_clone.load() { + let mut f = handler.lock().await; + f().await; + } + + if let Some(handler) = &*on_local_candidate_handler_clone.load() { + let mut f = handler.lock().await; + f(None).await; + } + } + }) + }, + )); + + agent.gather_candidates()?; + } + + Ok(()) + } + + /// Close prunes all local candidates, and closes the ports. + pub async fn close(&self) -> Result<()> { + self.set_state(RTCIceGathererState::Closed).await; + + let agent = { + let mut agent_opt = self.agent.lock().await; + agent_opt.take() + }; + + if let Some(agent) = agent { + agent.close().await?; + } + + Ok(()) + } + + /// get_local_parameters returns the ICE parameters of the ICEGatherer. + pub async fn get_local_parameters(&self) -> Result { + self.create_agent().await?; + + let (frag, pwd) = if let Some(agent) = self.get_agent().await { + agent.get_local_user_credentials().await + } else { + return Err(Error::ErrICEAgentNotExist); + }; + + Ok(RTCIceParameters { + username_fragment: frag, + password: pwd, + ice_lite: false, + }) + } + + /// get_local_candidates returns the sequence of valid local candidates associated with the ICEGatherer. + pub async fn get_local_candidates(&self) -> Result> { + self.create_agent().await?; + + let ice_candidates = if let Some(agent) = self.get_agent().await { + agent.get_local_candidates().await? + } else { + return Err(Error::ErrICEAgentNotExist); + }; + + Ok(rtc_ice_candidates_from_ice_candidates(&ice_candidates)) + } + + /// on_local_candidate sets an event handler which fires when a new local ICE candidate is available + /// Take note that the handler is gonna be called with a nil pointer when gathering is finished. + pub fn on_local_candidate(&self, f: OnLocalCandidateHdlrFn) { + self.on_local_candidate_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// on_state_change sets an event handler which fires any time the ICEGatherer changes + pub fn on_state_change(&self, f: OnICEGathererStateChangeHdlrFn) { + self.on_state_change_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// on_gathering_complete sets an event handler which fires any time the ICEGatherer changes + pub fn on_gathering_complete(&self, f: OnGatheringCompleteHdlrFn) { + self.on_gathering_complete_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// State indicates the current state of the ICE gatherer. + pub fn state(&self) -> RTCIceGathererState { + self.state.load(Ordering::SeqCst).into() + } + + pub async fn set_state(&self, s: RTCIceGathererState) { + self.state.store(s as u8, Ordering::SeqCst); + + if let Some(handler) = &*self.on_state_change_handler.load() { + let mut f = handler.lock().await; + f(s).await; + } + } + + pub(crate) async fn get_agent(&self) -> Option> { + let agent = self.agent.lock().await; + agent.clone() + } + + pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { + if let Some(agent) = self.get_agent().await { + let mut reports = HashMap::new(); + + for stats in agent.get_candidate_pairs_stats().await { + let stats: ICECandidatePairStats = stats.into(); + reports.insert(stats.id.clone(), StatsReportType::CandidatePair(stats)); + } + + for stats in agent.get_local_candidates_stats().await { + reports.insert( + stats.id.clone(), + StatsReportType::from(LocalCandidate(stats)), + ); + } + + for stats in agent.get_remote_candidates_stats().await { + reports.insert( + stats.id.clone(), + StatsReportType::from(RemoteCandidate(stats)), + ); + } + + collector.merge(reports); + } + } +} + +#[cfg(test)] +mod test { + use tokio::sync::mpsc; + + use super::*; + use crate::api::APIBuilder; + use crate::ice_transport::ice_gatherer::RTCIceGatherOptions; + use crate::ice_transport::ice_server::RTCIceServer; + + #[tokio::test] + async fn test_new_ice_gatherer_success() -> Result<()> { + let opts = RTCIceGatherOptions { + ice_servers: vec![RTCIceServer { + urls: vec!["stun:stun.l.google.com:19302".to_owned()], + ..Default::default() + }], + ..Default::default() + }; + + let gatherer = APIBuilder::new().build().new_ice_gatherer(opts)?; + + assert_eq!( + gatherer.state(), + RTCIceGathererState::New, + "Expected gathering state new" + ); + + let (gather_finished_tx, mut gather_finished_rx) = mpsc::channel::<()>(1); + let gather_finished_tx = Arc::new(Mutex::new(Some(gather_finished_tx))); + gatherer.on_local_candidate(Box::new(move |c: Option| { + let gather_finished_tx_clone = Arc::clone(&gather_finished_tx); + Box::pin(async move { + if c.is_none() { + let mut tx = gather_finished_tx_clone.lock().await; + tx.take(); + } + }) + })); + + gatherer.gather().await?; + + let _ = gather_finished_rx.recv().await; + + let params = gatherer.get_local_parameters().await?; + + assert!( + !params.username_fragment.is_empty() && !params.password.is_empty(), + "Empty local username or password frag" + ); + + let candidates = gatherer.get_local_candidates().await?; + + assert!(!candidates.is_empty(), "No candidates gathered"); + + gatherer.close().await?; + + Ok(()) + } + + #[tokio::test] + async fn test_ice_gather_mdns_candidate_gathering() -> Result<()> { + let mut s = SettingEngine::default(); + s.set_ice_multicast_dns_mode(ice::mdns::MulticastDnsMode::QueryAndGather); + + let gatherer = APIBuilder::new() + .with_setting_engine(s) + .build() + .new_ice_gatherer(RTCIceGatherOptions::default())?; + + let (done_tx, mut done_rx) = mpsc::channel::<()>(1); + let done_tx = Arc::new(Mutex::new(Some(done_tx))); + gatherer.on_local_candidate(Box::new(move |c: Option| { + let done_tx_clone = Arc::clone(&done_tx); + Box::pin(async move { + if let Some(c) = c { + if c.address.ends_with(".local") { + let mut tx = done_tx_clone.lock().await; + tx.take(); + } + } + }) + })); + + gatherer.gather().await?; + + let _ = done_rx.recv().await; + + gatherer.close().await?; + + Ok(()) + } +} diff --git a/rtc/src/ice_transport/ice_gatherer_state.rs b/rtc/src/ice_transport/ice_gatherer_state.rs new file mode 100644 index 0000000..7b24e99 --- /dev/null +++ b/rtc/src/ice_transport/ice_gatherer_state.rs @@ -0,0 +1,88 @@ +use std::fmt; + +/// ICEGathererState represents the current state of the ICE gatherer. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCIceGathererState { + #[default] + Unspecified, + + /// ICEGathererStateNew indicates object has been created but + /// gather() has not been called. + New, + + /// ICEGathererStateGathering indicates gather() has been called, + /// and the ICEGatherer is in the process of gathering candidates. + Gathering, + + /// ICEGathererStateComplete indicates the ICEGatherer has completed gathering. + Complete, + + /// ICEGathererStateClosed indicates the closed state can only be entered + /// when the ICEGatherer has been closed intentionally by calling close(). + Closed, +} + +const ICE_GATHERED_STATE_NEW_STR: &str = "new"; +const ICE_GATHERED_STATE_GATHERING_STR: &str = "gathering"; +const ICE_GATHERED_STATE_COMPLETE_STR: &str = "complete"; +const ICE_GATHERED_STATE_CLOSED_STR: &str = "closed"; + +impl From<&str> for RTCIceGathererState { + fn from(raw: &str) -> Self { + match raw { + ICE_GATHERED_STATE_NEW_STR => RTCIceGathererState::New, + ICE_GATHERED_STATE_GATHERING_STR => RTCIceGathererState::Gathering, + ICE_GATHERED_STATE_COMPLETE_STR => RTCIceGathererState::Complete, + ICE_GATHERED_STATE_CLOSED_STR => RTCIceGathererState::Closed, + _ => RTCIceGathererState::Unspecified, + } + } +} + +impl From for RTCIceGathererState { + fn from(v: u8) -> Self { + match v { + 1 => RTCIceGathererState::New, + 2 => RTCIceGathererState::Gathering, + 3 => RTCIceGathererState::Complete, + 4 => RTCIceGathererState::Closed, + _ => RTCIceGathererState::Unspecified, + } + } +} + +impl fmt::Display for RTCIceGathererState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCIceGathererState::New => write!(f, "{ICE_GATHERED_STATE_NEW_STR}"), + RTCIceGathererState::Gathering => write!(f, "{ICE_GATHERED_STATE_GATHERING_STR}"), + RTCIceGathererState::Complete => { + write!(f, "{ICE_GATHERED_STATE_COMPLETE_STR}") + } + RTCIceGathererState::Closed => { + write!(f, "{ICE_GATHERED_STATE_CLOSED_STR}") + } + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ice_gatherer_state_string() { + let tests = vec![ + (RTCIceGathererState::Unspecified, "Unspecified"), + (RTCIceGathererState::New, "new"), + (RTCIceGathererState::Gathering, "gathering"), + (RTCIceGathererState::Complete, "complete"), + (RTCIceGathererState::Closed, "closed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string); + } + } +} diff --git a/rtc/src/ice_transport/ice_gathering_state.rs b/rtc/src/ice_transport/ice_gathering_state.rs new file mode 100644 index 0000000..fa04331 --- /dev/null +++ b/rtc/src/ice_transport/ice_gathering_state.rs @@ -0,0 +1,83 @@ +use std::fmt; + +/// ICEGatheringState describes the state of the candidate gathering process. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCIceGatheringState { + #[default] + Unspecified, + + /// ICEGatheringStateNew indicates that any of the ICETransports are + /// in the "new" gathering state and none of the transports are in the + /// "gathering" state, or there are no transports. + New, + + /// ICEGatheringStateGathering indicates that any of the ICETransports + /// are in the "gathering" state. + Gathering, + + /// ICEGatheringStateComplete indicates that at least one ICETransport + /// exists, and all ICETransports are in the "completed" gathering state. + Complete, +} + +const ICE_GATHERING_STATE_NEW_STR: &str = "new"; +const ICE_GATHERING_STATE_GATHERING_STR: &str = "gathering"; +const ICE_GATHERING_STATE_COMPLETE_STR: &str = "complete"; + +/// takes a string and converts it to ICEGatheringState +impl From<&str> for RTCIceGatheringState { + fn from(raw: &str) -> Self { + match raw { + ICE_GATHERING_STATE_NEW_STR => RTCIceGatheringState::New, + ICE_GATHERING_STATE_GATHERING_STR => RTCIceGatheringState::Gathering, + ICE_GATHERING_STATE_COMPLETE_STR => RTCIceGatheringState::Complete, + _ => RTCIceGatheringState::Unspecified, + } + } +} + +impl fmt::Display for RTCIceGatheringState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCIceGatheringState::New => write!(f, "{ICE_GATHERING_STATE_NEW_STR}"), + RTCIceGatheringState::Gathering => write!(f, "{ICE_GATHERING_STATE_GATHERING_STR}"), + RTCIceGatheringState::Complete => { + write!(f, "{ICE_GATHERING_STATE_COMPLETE_STR}") + } + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_ice_gathering_state() { + let tests = vec![ + ("Unspecified", RTCIceGatheringState::Unspecified), + ("new", RTCIceGatheringState::New), + ("gathering", RTCIceGatheringState::Gathering), + ("complete", RTCIceGatheringState::Complete), + ]; + + for (state_string, expected_state) in tests { + assert_eq!(RTCIceGatheringState::from(state_string), expected_state); + } + } + + #[test] + fn test_ice_gathering_state_string() { + let tests = vec![ + (RTCIceGatheringState::Unspecified, "Unspecified"), + (RTCIceGatheringState::New, "new"), + (RTCIceGatheringState::Gathering, "gathering"), + (RTCIceGatheringState::Complete, "complete"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string); + } + } +} diff --git a/rtc/src/ice_transport/ice_parameters.rs b/rtc/src/ice_transport/ice_parameters.rs new file mode 100644 index 0000000..048e359 --- /dev/null +++ b/rtc/src/ice_transport/ice_parameters.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +/// ICEParameters includes the ICE username fragment +/// and password and other ICE-related parameters. +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RTCIceParameters { + pub username_fragment: String, + pub password: String, + pub ice_lite: bool, +} diff --git a/rtc/src/ice_transport/ice_protocol.rs b/rtc/src/ice_transport/ice_protocol.rs new file mode 100644 index 0000000..3089125 --- /dev/null +++ b/rtc/src/ice_transport/ice_protocol.rs @@ -0,0 +1,79 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// ICEProtocol indicates the transport protocol type that is used in the +/// ice.URL structure. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum RTCIceProtocol { + #[default] + Unspecified, + + /// UDP indicates the URL uses a UDP transport. + #[serde(rename = "udp")] + Udp, + + /// TCP indicates the URL uses a TCP transport. + #[serde(rename = "tcp")] + Tcp, +} + +const ICE_PROTOCOL_UDP_STR: &str = "udp"; +const ICE_PROTOCOL_TCP_STR: &str = "tcp"; + +/// takes a string and converts it to ICEProtocol +impl From<&str> for RTCIceProtocol { + fn from(raw: &str) -> Self { + if raw.to_uppercase() == ICE_PROTOCOL_UDP_STR.to_uppercase() { + RTCIceProtocol::Udp + } else if raw.to_uppercase() == ICE_PROTOCOL_TCP_STR.to_uppercase() { + RTCIceProtocol::Tcp + } else { + RTCIceProtocol::Unspecified + } + } +} + +impl fmt::Display for RTCIceProtocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCIceProtocol::Udp => write!(f, "{ICE_PROTOCOL_UDP_STR}"), + RTCIceProtocol::Tcp => write!(f, "{ICE_PROTOCOL_TCP_STR}"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_ice_protocol() { + let tests = vec![ + ("Unspecified", RTCIceProtocol::Unspecified), + ("udp", RTCIceProtocol::Udp), + ("tcp", RTCIceProtocol::Tcp), + ("UDP", RTCIceProtocol::Udp), + ("TCP", RTCIceProtocol::Tcp), + ]; + + for (proto_string, expected_proto) in tests { + let actual = RTCIceProtocol::from(proto_string); + assert_eq!(actual, expected_proto); + } + } + + #[test] + fn test_ice_protocol_string() { + let tests = vec![ + (RTCIceProtocol::Unspecified, "Unspecified"), + (RTCIceProtocol::Udp, "udp"), + (RTCIceProtocol::Tcp, "tcp"), + ]; + + for (proto, expected_string) in tests { + assert_eq!(proto.to_string(), expected_string); + } + } +} diff --git a/rtc/src/ice_transport/ice_role.rs b/rtc/src/ice_transport/ice_role.rs new file mode 100644 index 0000000..699dd4b --- /dev/null +++ b/rtc/src/ice_transport/ice_role.rs @@ -0,0 +1,73 @@ +use std::fmt; + +/// ICERole describes the role ice.Agent is playing in selecting the +/// preferred the candidate pair. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCIceRole { + #[default] + Unspecified, + + /// ICERoleControlling indicates that the ICE agent that is responsible + /// for selecting the final choice of candidate pairs and signaling them + /// through STUN and an updated offer, if needed. In any session, one agent + /// is always controlling. The other is the controlled agent. + Controlling, + + /// ICERoleControlled indicates that an ICE agent that waits for the + /// controlling agent to select the final choice of candidate pairs. + Controlled, +} + +const ICE_ROLE_CONTROLLING_STR: &str = "controlling"; +const ICE_ROLE_CONTROLLED_STR: &str = "controlled"; + +impl From<&str> for RTCIceRole { + fn from(raw: &str) -> Self { + match raw { + ICE_ROLE_CONTROLLING_STR => RTCIceRole::Controlling, + ICE_ROLE_CONTROLLED_STR => RTCIceRole::Controlled, + _ => RTCIceRole::Unspecified, + } + } +} + +impl fmt::Display for RTCIceRole { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCIceRole::Controlling => write!(f, "{ICE_ROLE_CONTROLLING_STR}"), + RTCIceRole::Controlled => write!(f, "{ICE_ROLE_CONTROLLED_STR}"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_ice_role() { + let tests = vec![ + ("Unspecified", RTCIceRole::Unspecified), + ("controlling", RTCIceRole::Controlling), + ("controlled", RTCIceRole::Controlled), + ]; + + for (role_string, expected_role) in tests { + assert_eq!(RTCIceRole::from(role_string), expected_role); + } + } + + #[test] + fn test_ice_role_string() { + let tests = vec![ + (RTCIceRole::Unspecified, "Unspecified"), + (RTCIceRole::Controlling, "controlling"), + (RTCIceRole::Controlled, "controlled"), + ]; + + for (proto, expected_string) in tests { + assert_eq!(proto.to_string(), expected_string); + } + } +} diff --git a/rtc/src/ice_transport/ice_server.rs b/rtc/src/ice_transport/ice_server.rs new file mode 100644 index 0000000..3cfc66f --- /dev/null +++ b/rtc/src/ice_transport/ice_server.rs @@ -0,0 +1,173 @@ +use serde::{Deserialize, Serialize}; + +use crate::error::{Error, Result}; +use crate::ice_transport::ice_credential_type::RTCIceCredentialType; + +/// ICEServer describes a single STUN and TURN server that can be used by +/// the ICEAgent to establish a connection with a peer. +#[derive(Default, Debug, Clone, Serialize, Deserialize, Hash)] +pub struct RTCIceServer { + pub urls: Vec, + pub username: String, + pub credential: String, + pub credential_type: RTCIceCredentialType, +} + +impl RTCIceServer { + pub(crate) fn parse_url(&self, url_str: &str) -> Result { + Ok(ice::url::Url::parse_url(url_str)?) + } + + pub(crate) fn validate(&self) -> Result<()> { + self.urls()?; + Ok(()) + } + + pub(crate) fn urls(&self) -> Result> { + let mut urls = vec![]; + + for url_str in &self.urls { + let mut url = self.parse_url(url_str)?; + if url.scheme == ice::url::SchemeType::Turn || url.scheme == ice::url::SchemeType::Turns + { + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.2) + if self.username.is_empty() || self.credential.is_empty() { + return Err(Error::ErrNoTurnCredentials); + } + url.username = self.username.clone(); + + match self.credential_type { + RTCIceCredentialType::Password => { + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.3) + url.password = self.credential.clone(); + } + RTCIceCredentialType::Oauth => { + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3.4) + /*if _, ok: = s.Credential.(OAuthCredential); !ok { + return nil, + &rtcerr.InvalidAccessError{Err: ErrTurnCredentials + } + }*/ + } + _ => return Err(Error::ErrTurnCredentials), + }; + } + + urls.push(url); + } + + Ok(urls) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ice_server_validate_success() { + let tests = vec![ + ( + RTCIceServer { + urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], + username: "unittest".to_owned(), + credential: "placeholder".to_owned(), + credential_type: RTCIceCredentialType::Password, + }, + true, + ), + ( + RTCIceServer { + urls: vec!["turn:[2001:db8:1234:5678::1]?transport=udp".to_owned()], + username: "unittest".to_owned(), + credential: "placeholder".to_owned(), + credential_type: RTCIceCredentialType::Password, + }, + true, + ), + /*TODO:(ICEServer{ + URLs: []string{"turn:192.158.29.39?transport=udp"}, + Username: "unittest".to_owned(), + Credential: OAuthCredential{ + MACKey: "WmtzanB3ZW9peFhtdm42NzUzNG0=", + AccessToken: "AAwg3kPHWPfvk9bDFL936wYvkoctMADzQ5VhNDgeMR3+ZlZ35byg972fW8QjpEl7bx91YLBPFsIhsxloWcXPhA==", + }, + CredentialType: ICECredentialTypeOauth, + }, true),*/ + ]; + + for (ice_server, expected_validate) in tests { + let result = ice_server.urls(); + assert_eq!(result.is_ok(), expected_validate); + } + } + + #[test] + fn test_ice_server_validate_failure() { + let tests = vec![ + ( + RTCIceServer { + urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], + ..Default::default() + }, + Error::ErrNoTurnCredentials, + ), + ( + RTCIceServer { + urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], + username: "unittest".to_owned(), + credential: String::new(), + credential_type: RTCIceCredentialType::Password, + }, + Error::ErrNoTurnCredentials, + ), + ( + RTCIceServer { + urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], + username: "unittest".to_owned(), + credential: String::new(), + credential_type: RTCIceCredentialType::Oauth, + }, + Error::ErrNoTurnCredentials, + ), + ( + RTCIceServer { + urls: vec!["turn:192.158.29.39?transport=udp".to_owned()], + username: "unittest".to_owned(), + credential: String::new(), + credential_type: RTCIceCredentialType::Unspecified, + }, + Error::ErrNoTurnCredentials, + ), + ]; + + for (ice_server, expected_err) in tests { + if let Err(err) = ice_server.urls() { + assert_eq!(err, expected_err, "{ice_server:?} with err {err:?}"); + } else { + panic!("expected error, but got ok"); + } + } + } + + #[test] + fn test_ice_server_validate_failure_err_stun_query() { + let tests = vec![( + RTCIceServer { + urls: vec!["stun:google.de?transport=udp".to_owned()], + username: "unittest".to_owned(), + credential: String::new(), + credential_type: RTCIceCredentialType::Oauth, + }, + ice::Error::ErrStunQuery, + )]; + + for (ice_server, expected_err) in tests { + if let Err(err) = ice_server.urls() { + assert_eq!(err, expected_err, "{ice_server:?} with err {err:?}"); + } else { + panic!("expected error, but got ok"); + } + } + } +} diff --git a/rtc/src/ice_transport/ice_transport_state.rs b/rtc/src/ice_transport/ice_transport_state.rs new file mode 100644 index 0000000..2abdede --- /dev/null +++ b/rtc/src/ice_transport/ice_transport_state.rs @@ -0,0 +1,185 @@ +use std::fmt; + +use ice::state::ConnectionState; + +/// ICETransportState represents the current state of the ICE transport. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCIceTransportState { + #[default] + Unspecified, + + /// ICETransportStateNew indicates the ICETransport is waiting + /// for remote candidates to be supplied. + New, + + /// ICETransportStateChecking indicates the ICETransport has + /// received at least one remote candidate, and a local and remote + /// ICECandidateComplete dictionary was not added as the last candidate. + Checking, + + /// ICETransportStateConnected indicates the ICETransport has + /// received a response to an outgoing connectivity check, or has + /// received incoming DTLS/media after a successful response to an + /// incoming connectivity check, but is still checking other candidate + /// pairs to see if there is a better connection. + Connected, + + /// ICETransportStateCompleted indicates the ICETransport tested + /// all appropriate candidate pairs and at least one functioning + /// candidate pair has been found. + Completed, + + /// ICETransportStateFailed indicates the ICETransport the last + /// candidate was added and all appropriate candidate pairs have either + /// failed connectivity checks or have lost consent. + Failed, + + /// ICETransportStateDisconnected indicates the ICETransport has received + /// at least one local and remote candidate, but the final candidate was + /// received yet and all appropriate candidate pairs thus far have been + /// tested and failed. + Disconnected, + + /// ICETransportStateClosed indicates the ICETransport has shut down + /// and is no longer responding to STUN requests. + Closed, +} + +const ICE_TRANSPORT_STATE_NEW_STR: &str = "new"; +const ICE_TRANSPORT_STATE_CHECKING_STR: &str = "checking"; +const ICE_TRANSPORT_STATE_CONNECTED_STR: &str = "connected"; +const ICE_TRANSPORT_STATE_COMPLETED_STR: &str = "completed"; +const ICE_TRANSPORT_STATE_FAILED_STR: &str = "failed"; +const ICE_TRANSPORT_STATE_DISCONNECTED_STR: &str = "disconnected"; +const ICE_TRANSPORT_STATE_CLOSED_STR: &str = "closed"; + +impl From<&str> for RTCIceTransportState { + fn from(raw: &str) -> Self { + match raw { + ICE_TRANSPORT_STATE_NEW_STR => RTCIceTransportState::New, + ICE_TRANSPORT_STATE_CHECKING_STR => RTCIceTransportState::Checking, + ICE_TRANSPORT_STATE_CONNECTED_STR => RTCIceTransportState::Connected, + ICE_TRANSPORT_STATE_COMPLETED_STR => RTCIceTransportState::Completed, + ICE_TRANSPORT_STATE_FAILED_STR => RTCIceTransportState::Failed, + ICE_TRANSPORT_STATE_DISCONNECTED_STR => RTCIceTransportState::Disconnected, + ICE_TRANSPORT_STATE_CLOSED_STR => RTCIceTransportState::Closed, + _ => RTCIceTransportState::Unspecified, + } + } +} + +impl From for RTCIceTransportState { + fn from(v: u8) -> Self { + match v { + 1 => Self::New, + 2 => Self::Checking, + 3 => Self::Connected, + 4 => Self::Completed, + 5 => Self::Failed, + 6 => Self::Disconnected, + 7 => Self::Closed, + _ => Self::Unspecified, + } + } +} + +impl fmt::Display for RTCIceTransportState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCIceTransportState::New => write!(f, "{ICE_TRANSPORT_STATE_NEW_STR}"), + RTCIceTransportState::Checking => write!(f, "{ICE_TRANSPORT_STATE_CHECKING_STR}"), + RTCIceTransportState::Connected => { + write!(f, "{ICE_TRANSPORT_STATE_CONNECTED_STR}") + } + RTCIceTransportState::Completed => write!(f, "{ICE_TRANSPORT_STATE_COMPLETED_STR}"), + RTCIceTransportState::Failed => { + write!(f, "{ICE_TRANSPORT_STATE_FAILED_STR}") + } + RTCIceTransportState::Disconnected => { + write!(f, "{ICE_TRANSPORT_STATE_DISCONNECTED_STR}") + } + RTCIceTransportState::Closed => { + write!(f, "{ICE_TRANSPORT_STATE_CLOSED_STR}") + } + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +impl From for RTCIceTransportState { + fn from(raw: ConnectionState) -> Self { + match raw { + ConnectionState::New => RTCIceTransportState::New, + ConnectionState::Checking => RTCIceTransportState::Checking, + ConnectionState::Connected => RTCIceTransportState::Connected, + ConnectionState::Completed => RTCIceTransportState::Completed, + ConnectionState::Failed => RTCIceTransportState::Failed, + ConnectionState::Disconnected => RTCIceTransportState::Disconnected, + ConnectionState::Closed => RTCIceTransportState::Closed, + _ => RTCIceTransportState::Unspecified, + } + } +} + +impl RTCIceTransportState { + pub(crate) fn to_ice(self) -> ConnectionState { + match self { + RTCIceTransportState::New => ConnectionState::New, + RTCIceTransportState::Checking => ConnectionState::Checking, + RTCIceTransportState::Connected => ConnectionState::Connected, + RTCIceTransportState::Completed => ConnectionState::Completed, + RTCIceTransportState::Failed => ConnectionState::Failed, + RTCIceTransportState::Disconnected => ConnectionState::Disconnected, + RTCIceTransportState::Closed => ConnectionState::Closed, + _ => ConnectionState::Unspecified, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_ice_transport_state_string() { + let tests = vec![ + (RTCIceTransportState::Unspecified, "Unspecified"), + (RTCIceTransportState::New, "new"), + (RTCIceTransportState::Checking, "checking"), + (RTCIceTransportState::Connected, "connected"), + (RTCIceTransportState::Completed, "completed"), + (RTCIceTransportState::Failed, "failed"), + (RTCIceTransportState::Disconnected, "disconnected"), + (RTCIceTransportState::Closed, "closed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string); + } + } + + #[test] + fn test_ice_transport_state_convert() { + let tests = vec![ + ( + RTCIceTransportState::Unspecified, + ConnectionState::Unspecified, + ), + (RTCIceTransportState::New, ConnectionState::New), + (RTCIceTransportState::Checking, ConnectionState::Checking), + (RTCIceTransportState::Connected, ConnectionState::Connected), + (RTCIceTransportState::Completed, ConnectionState::Completed), + (RTCIceTransportState::Failed, ConnectionState::Failed), + ( + RTCIceTransportState::Disconnected, + ConnectionState::Disconnected, + ), + (RTCIceTransportState::Closed, ConnectionState::Closed), + ]; + + for (native, ice_state) in tests { + assert_eq!(native.to_ice(), ice_state); + assert_eq!(native, RTCIceTransportState::from(ice_state)); + } + } +} diff --git a/rtc/src/ice_transport/ice_transport_test.rs b/rtc/src/ice_transport/ice_transport_test.rs new file mode 100644 index 0000000..866f78a --- /dev/null +++ b/rtc/src/ice_transport/ice_transport_test.rs @@ -0,0 +1,123 @@ +use std::sync::atomic::AtomicU32; + +use tokio::time::Duration; +use waitgroup::WaitGroup; + +use super::*; +use crate::api::media_engine::MediaEngine; +use crate::api::APIBuilder; +use crate::error::Result; +use crate::ice_transport::ice_connection_state::RTCIceConnectionState; +use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; +use crate::peer_connection::peer_connection_test::{ + close_pair_now, new_pair, signal_pair, until_connection_state, +}; + +#[tokio::test] +async fn test_ice_transport_on_selected_candidate_pair_change() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; + + let (ice_complete_tx, mut ice_complete_rx) = mpsc::channel::<()>(1); + let ice_complete_tx = Arc::new(Mutex::new(Some(ice_complete_tx))); + pc_answer.on_ice_connection_state_change(Box::new(move |ice_state: RTCIceConnectionState| { + let ice_complete_tx2 = Arc::clone(&ice_complete_tx); + Box::pin(async move { + if ice_state == RTCIceConnectionState::Connected { + tokio::time::sleep(Duration::from_secs(1)).await; + let mut done = ice_complete_tx2.lock().await; + done.take(); + } + }) + })); + + let sender_called_candidate_change = Arc::new(AtomicU32::new(0)); + let sender_called_candidate_change2 = Arc::clone(&sender_called_candidate_change); + pc_offer + .sctp() + .transport() + .ice_transport() + .on_selected_candidate_pair_change(Box::new(move |_: RTCIceCandidatePair| { + sender_called_candidate_change2.store(1, Ordering::SeqCst); + Box::pin(async {}) + })); + + signal_pair(&mut pc_offer, &mut pc_answer).await?; + + let _ = ice_complete_rx.recv().await; + assert_eq!( + sender_called_candidate_change.load(Ordering::SeqCst), + 1, + "Sender ICETransport OnSelectedCandidateChange was never called" + ); + + close_pair_now(&pc_offer, &pc_answer).await; + + Ok(()) +} + +#[tokio::test] +async fn test_ice_transport_get_selected_candidate_pair() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut offerer, mut answerer) = new_pair(&api).await?; + + let peer_connection_connected = WaitGroup::new(); + until_connection_state( + &mut offerer, + &peer_connection_connected, + RTCPeerConnectionState::Connected, + ) + .await; + until_connection_state( + &mut answerer, + &peer_connection_connected, + RTCPeerConnectionState::Connected, + ) + .await; + + let offerer_selected_pair = offerer + .sctp() + .transport() + .ice_transport() + .get_selected_candidate_pair() + .await; + assert!(offerer_selected_pair.is_none()); + + let answerer_selected_pair = answerer + .sctp() + .transport() + .ice_transport() + .get_selected_candidate_pair() + .await; + assert!(answerer_selected_pair.is_none()); + + signal_pair(&mut offerer, &mut answerer).await?; + + peer_connection_connected.wait().await; + + let offerer_selected_pair = offerer + .sctp() + .transport() + .ice_transport() + .get_selected_candidate_pair() + .await; + assert!(offerer_selected_pair.is_some()); + + let answerer_selected_pair = answerer + .sctp() + .transport() + .ice_transport() + .get_selected_candidate_pair() + .await; + assert!(answerer_selected_pair.is_some()); + + close_pair_now(&offerer, &answerer).await; + + Ok(()) +} diff --git a/rtc/src/ice_transport/mod.rs b/rtc/src/ice_transport/mod.rs new file mode 100644 index 0000000..6bf09c0 --- /dev/null +++ b/rtc/src/ice_transport/mod.rs @@ -0,0 +1,355 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use ice::candidate::Candidate; +use ice::state::ConnectionState; +use ice_candidate::RTCIceCandidate; +use ice_candidate_pair::RTCIceCandidatePair; +use ice_gatherer::RTCIceGatherer; +use ice_role::RTCIceRole; +use tokio::sync::{mpsc, Mutex}; +use util::Conn; + +use crate::error::{flatten_errs, Error, Result}; +use crate::ice_transport::ice_parameters::RTCIceParameters; +use crate::ice_transport::ice_transport_state::RTCIceTransportState; +use crate::mux::endpoint::Endpoint; +use crate::mux::mux_func::MatchFunc; +use crate::mux::{Config, Mux}; +use crate::stats::stats_collector::StatsCollector; +use crate::stats::ICETransportStats; +use crate::stats::StatsReportType::Transport; + +#[cfg(test)] +mod ice_transport_test; + +pub mod ice_candidate; +pub mod ice_candidate_pair; +pub mod ice_candidate_type; +pub mod ice_connection_state; +pub mod ice_credential_type; +pub mod ice_gatherer; +pub mod ice_gatherer_state; +pub mod ice_gathering_state; +pub mod ice_parameters; +pub mod ice_protocol; +pub mod ice_role; +pub mod ice_server; +pub mod ice_transport_state; + +pub type OnConnectionStateChangeHdlrFn = Box< + dyn (FnMut(RTCIceTransportState) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnSelectedCandidatePairChangeHdlrFn = Box< + dyn (FnMut(RTCIceCandidatePair) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +#[derive(Default)] +struct ICETransportInternal { + role: RTCIceRole, + conn: Option>, //AgentConn + mux: Option, + cancel_tx: Option>, +} + +/// ICETransport allows an application access to information about the ICE +/// transport over which packets are sent and received. +#[derive(Default)] +pub struct RTCIceTransport { + pub(crate) gatherer: Arc, + on_connection_state_change_handler: Arc>>, + on_selected_candidate_pair_change_handler: + Arc>>, + state: Arc, // ICETransportState + internal: Mutex, +} + +impl RTCIceTransport { + /// creates a new new_icetransport. + pub(crate) fn new(gatherer: Arc) -> Self { + RTCIceTransport { + state: Arc::new(AtomicU8::new(RTCIceTransportState::New as u8)), + gatherer, + ..Default::default() + } + } + + /// get_selected_candidate_pair returns the selected candidate pair on which packets are sent + /// if there is no selected pair nil is returned + pub async fn get_selected_candidate_pair(&self) -> Option { + if let Some(agent) = self.gatherer.get_agent().await { + if let Some(ice_pair) = agent.get_selected_candidate_pair() { + let local = RTCIceCandidate::from(&ice_pair.local); + let remote = RTCIceCandidate::from(&ice_pair.remote); + return Some(RTCIceCandidatePair::new(local, remote)); + } + } + None + } + + /// Start incoming connectivity checks based on its configured role. + pub async fn start(&self, params: &RTCIceParameters, role: Option) -> Result<()> { + if self.state() != RTCIceTransportState::New { + return Err(Error::ErrICETransportNotInNew); + } + + self.ensure_gatherer().await?; + + if let Some(agent) = self.gatherer.get_agent().await { + let state = Arc::clone(&self.state); + + let on_connection_state_change_handler = + Arc::clone(&self.on_connection_state_change_handler); + agent.on_connection_state_change(Box::new(move |ice_state: ConnectionState| { + let s = RTCIceTransportState::from(ice_state); + let on_connection_state_change_handler_clone = + Arc::clone(&on_connection_state_change_handler); + state.store(s as u8, Ordering::SeqCst); + Box::pin(async move { + if let Some(handler) = &*on_connection_state_change_handler_clone.load() { + let mut f = handler.lock().await; + f(s).await; + } + }) + })); + + let on_selected_candidate_pair_change_handler = + Arc::clone(&self.on_selected_candidate_pair_change_handler); + agent.on_selected_candidate_pair_change(Box::new( + move |local: &Arc, + remote: &Arc| { + let on_selected_candidate_pair_change_handler_clone = + Arc::clone(&on_selected_candidate_pair_change_handler); + let local = RTCIceCandidate::from(local); + let remote = RTCIceCandidate::from(remote); + Box::pin(async move { + if let Some(handler) = + &*on_selected_candidate_pair_change_handler_clone.load() + { + let mut f = handler.lock().await; + f(RTCIceCandidatePair::new(local, remote)).await; + } + }) + }, + )); + + let role = if let Some(role) = role { + role + } else { + RTCIceRole::Controlled + }; + + let (cancel_tx, cancel_rx) = mpsc::channel(1); + { + let mut internal = self.internal.lock().await; + internal.role = role; + internal.cancel_tx = Some(cancel_tx); + } + + let conn: Arc = match role { + RTCIceRole::Controlling => { + agent + .dial( + cancel_rx, + params.username_fragment.clone(), + params.password.clone(), + ) + .await? + } + + RTCIceRole::Controlled => { + agent + .accept( + cancel_rx, + params.username_fragment.clone(), + params.password.clone(), + ) + .await? + } + + _ => return Err(Error::ErrICERoleUnknown), + }; + + let config = Config { + conn: Arc::clone(&conn), + buffer_size: self.gatherer.setting_engine.get_receive_mtu(), + }; + + { + let mut internal = self.internal.lock().await; + internal.conn = Some(conn); + internal.mux = Some(Mux::new(config)); + } + + Ok(()) + } else { + Err(Error::ErrICEAgentNotExist) + } + } + + /// restart is not exposed currently because ORTC has users create a whole new ICETransport + /// so for now lets keep it private so we don't cause ORTC users to depend on non-standard APIs + pub(crate) async fn restart(&self) -> Result<()> { + if let Some(agent) = self.gatherer.get_agent().await { + agent + .restart( + self.gatherer + .setting_engine + .candidates + .username_fragment + .clone(), + self.gatherer.setting_engine.candidates.password.clone(), + ) + .await?; + } else { + return Err(Error::ErrICEAgentNotExist); + } + self.gatherer.gather().await + } + + /// Stop irreversibly stops the ICETransport. + pub async fn stop(&self) -> Result<()> { + self.set_state(RTCIceTransportState::Closed); + + let mut errs: Vec = vec![]; + { + let mut internal = self.internal.lock().await; + internal.cancel_tx.take(); + if let Some(mut mux) = internal.mux.take() { + mux.close().await; + } + if let Some(conn) = internal.conn.take() { + if let Err(err) = conn.close().await { + errs.push(err.into()); + } + } + } + + if let Err(err) = self.gatherer.close().await { + errs.push(err); + } + + flatten_errs(errs) + } + + /// on_selected_candidate_pair_change sets a handler that is invoked when a new + /// ICE candidate pair is selected + pub fn on_selected_candidate_pair_change(&self, f: OnSelectedCandidatePairChangeHdlrFn) { + self.on_selected_candidate_pair_change_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// on_connection_state_change sets a handler that is fired when the ICE + /// connection state changes. + pub fn on_connection_state_change(&self, f: OnConnectionStateChangeHdlrFn) { + self.on_connection_state_change_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// Role indicates the current role of the ICE transport. + pub async fn role(&self) -> RTCIceRole { + let internal = self.internal.lock().await; + internal.role + } + + /// set_remote_candidates sets the sequence of candidates associated with the remote ICETransport. + pub async fn set_remote_candidates(&self, remote_candidates: &[RTCIceCandidate]) -> Result<()> { + self.ensure_gatherer().await?; + + if let Some(agent) = self.gatherer.get_agent().await { + for rc in remote_candidates { + let c: Arc = Arc::new(rc.to_ice()?); + agent.add_remote_candidate(&c)?; + } + Ok(()) + } else { + Err(Error::ErrICEAgentNotExist) + } + } + + /// adds a candidate associated with the remote ICETransport. + pub async fn add_remote_candidate( + &self, + remote_candidate: Option, + ) -> Result<()> { + self.ensure_gatherer().await?; + + if let Some(agent) = self.gatherer.get_agent().await { + if let Some(r) = remote_candidate { + let c: Arc = Arc::new(r.to_ice()?); + agent.add_remote_candidate(&c)?; + } + + Ok(()) + } else { + Err(Error::ErrICEAgentNotExist) + } + } + + /// State returns the current ice transport state. + pub fn state(&self) -> RTCIceTransportState { + RTCIceTransportState::from(self.state.load(Ordering::SeqCst)) + } + + pub(crate) fn set_state(&self, s: RTCIceTransportState) { + self.state.store(s as u8, Ordering::SeqCst) + } + + pub(crate) async fn new_endpoint(&self, f: MatchFunc) -> Option> { + let internal = self.internal.lock().await; + if let Some(mux) = &internal.mux { + Some(mux.new_endpoint(f).await) + } else { + None + } + } + + pub(crate) async fn ensure_gatherer(&self) -> Result<()> { + if self.gatherer.get_agent().await.is_none() { + self.gatherer.create_agent().await + } else { + Ok(()) + } + } + + pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { + if let Some(agent) = self.gatherer.get_agent().await { + let stats = ICETransportStats::new("ice_transport".to_string(), agent); + + collector.insert("ice_transport".to_string(), Transport(stats)); + } + } + + pub(crate) async fn have_remote_credentials_change( + &self, + new_ufrag: &str, + new_pwd: &str, + ) -> bool { + if let Some(agent) = self.gatherer.get_agent().await { + let (ufrag, upwd) = agent.get_remote_user_credentials().await; + ufrag != new_ufrag || upwd != new_pwd + } else { + false + } + } + + pub(crate) async fn set_remote_credentials( + &self, + new_ufrag: String, + new_pwd: String, + ) -> Result<()> { + if let Some(agent) = self.gatherer.get_agent().await { + Ok(agent.set_remote_credentials(new_ufrag, new_pwd).await?) + } else { + Err(Error::ErrICEAgentNotExist) + } + } +} diff --git a/rtc/src/lib.rs b/rtc/src/lib.rs index 80c6a85..bcb17bb 100644 --- a/rtc/src/lib.rs +++ b/rtc/src/lib.rs @@ -1,2 +1,32 @@ #![warn(rust_2018_idioms)] #![allow(dead_code)] + +/* +#[macro_use] +extern crate lazy_static; + +pub mod api; +pub mod data_channel; +pub mod dtls_transport; +pub mod ice_transport; +pub mod mux; +pub mod peer_connection; +pub mod rtp_transceiver; +pub mod sctp_transport; +pub mod stats; +pub mod track; + +// re-export sub-crates +pub use {data, dtls, ice, interceptor, mdns, media, rtcp, rtp, sctp, sdp, srtp, stun, turn, util}; +*/ + +pub(crate) const UNSPECIFIED_STR: &str = "Unspecified"; + +/// Equal to UDP MTU +pub(crate) const RECEIVE_MTU: usize = 1460; + +pub(crate) const SDP_ATTRIBUTE_RID: &str = "rid"; +pub(crate) const SDP_ATTRIBUTE_SIMULCAST: &str = "simulcast"; +pub(crate) const GENERATED_CERTIFICATE_ORIGIN: &str = "WebRTC"; +pub(crate) const SDES_REPAIR_RTP_STREAM_ID_URI: &str = + "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id"; diff --git a/rtc/src/mux/endpoint.rs b/rtc/src/mux/endpoint.rs new file mode 100644 index 0000000..64e8d51 --- /dev/null +++ b/rtc/src/mux/endpoint.rs @@ -0,0 +1,73 @@ +use std::collections::HashMap; +use std::io; +use std::net::SocketAddr; +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::Mutex; +use util::{Buffer, Conn}; + +use crate::mux::mux_func::MatchFunc; + +/// Endpoint implements net.Conn. It is used to read muxed packets. +pub struct Endpoint { + pub(crate) id: usize, + pub(crate) buffer: Buffer, + pub(crate) match_fn: MatchFunc, + pub(crate) next_conn: Arc, + pub(crate) endpoints: Arc>>>, +} + +impl Endpoint { + /// Close unregisters the endpoint from the Mux + pub async fn close(&self) -> Result<()> { + self.buffer.close().await; + + let mut endpoints = self.endpoints.lock().await; + endpoints.remove(&self.id); + + Ok(()) + } +} + +type Result = std::result::Result; + +#[async_trait] +impl Conn for Endpoint { + async fn connect(&self, _addr: SocketAddr) -> Result<()> { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + /// reads a packet of len(p) bytes from the underlying conn + /// that are matched by the associated MuxFunc + async fn recv(&self, buf: &mut [u8]) -> Result { + match self.buffer.read(buf, None).await { + Ok(n) => Ok(n), + Err(err) => Err(io::Error::new(io::ErrorKind::Other, err.to_string()).into()), + } + } + async fn recv_from(&self, _buf: &mut [u8]) -> Result<(usize, SocketAddr)> { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + /// writes bytes to the underlying conn + async fn send(&self, buf: &[u8]) -> Result { + self.next_conn.send(buf).await + } + + async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + fn local_addr(&self) -> Result { + self.next_conn.local_addr() + } + + fn remote_addr(&self) -> Option { + self.next_conn.remote_addr() + } + + async fn close(&self) -> Result<()> { + self.next_conn.close().await + } +} diff --git a/rtc/src/mux/mod.rs b/rtc/src/mux/mod.rs new file mode 100644 index 0000000..d5bea2e --- /dev/null +++ b/rtc/src/mux/mod.rs @@ -0,0 +1,156 @@ +#[cfg(test)] +mod mux_test; + +pub mod endpoint; +pub mod mux_func; + +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use tokio::sync::{mpsc, Mutex}; +use util::{Buffer, Conn}; + +use crate::error::Result; +use crate::mux::endpoint::Endpoint; +use crate::mux::mux_func::MatchFunc; +use crate::util::Error; + +/// mux multiplexes packets on a single socket (RFC7983) + +/// The maximum amount of data that can be buffered before returning errors. +const MAX_BUFFER_SIZE: usize = 1000 * 1000; // 1MB + +/// Config collects the arguments to mux.Mux construction into +/// a single structure +pub struct Config { + pub conn: Arc, + pub buffer_size: usize, +} + +/// Mux allows multiplexing +#[derive(Clone)] +pub struct Mux { + id: Arc, + next_conn: Arc, + endpoints: Arc>>>, + buffer_size: usize, + closed_ch_tx: Option>, +} + +impl Mux { + pub fn new(config: Config) -> Self { + let (closed_ch_tx, closed_ch_rx) = mpsc::channel(1); + let m = Mux { + id: Arc::new(AtomicUsize::new(0)), + next_conn: Arc::clone(&config.conn), + endpoints: Arc::new(Mutex::new(HashMap::new())), + buffer_size: config.buffer_size, + closed_ch_tx: Some(closed_ch_tx), + }; + + let buffer_size = m.buffer_size; + let next_conn = Arc::clone(&m.next_conn); + let endpoints = Arc::clone(&m.endpoints); + tokio::spawn(async move { + Mux::read_loop(buffer_size, next_conn, closed_ch_rx, endpoints).await; + }); + + m + } + + /// creates a new Endpoint + pub async fn new_endpoint(&self, f: MatchFunc) -> Arc { + let mut endpoints = self.endpoints.lock().await; + + let id = self.id.fetch_add(1, Ordering::SeqCst); + // Set a maximum size of the buffer in bytes. + let e = Arc::new(Endpoint { + id, + buffer: Buffer::new(0, MAX_BUFFER_SIZE), + match_fn: f, + next_conn: Arc::clone(&self.next_conn), + endpoints: Arc::clone(&self.endpoints), + }); + + endpoints.insert(e.id, Arc::clone(&e)); + + e + } + + /// remove_endpoint removes an endpoint from the Mux + pub async fn remove_endpoint(&mut self, e: &Endpoint) { + let mut endpoints = self.endpoints.lock().await; + endpoints.remove(&e.id); + } + + /// Close closes the Mux and all associated Endpoints. + pub async fn close(&mut self) { + self.closed_ch_tx.take(); + + let mut endpoints = self.endpoints.lock().await; + endpoints.clear(); + } + + async fn read_loop( + buffer_size: usize, + next_conn: Arc, + mut closed_ch_rx: mpsc::Receiver<()>, + endpoints: Arc>>>, + ) { + let mut buf = vec![0u8; buffer_size]; + let mut n = 0usize; + loop { + tokio::select! { + _ = closed_ch_rx.recv() => break, + result = next_conn.recv(&mut buf) => { + if let Ok(m) = result{ + n = m; + } + } + }; + + if let Err(err) = Mux::dispatch(&buf[..n], &endpoints).await { + log::error!("mux: ending readLoop dispatch error {:?}", err); + break; + } + } + } + + async fn dispatch( + buf: &[u8], + endpoints: &Arc>>>, + ) -> Result<()> { + let mut endpoint = None; + + { + let eps = endpoints.lock().await; + for ep in eps.values() { + if (ep.match_fn)(buf) { + endpoint = Some(Arc::clone(ep)); + break; + } + } + } + + if let Some(ep) = endpoint { + match ep.buffer.write(buf).await { + // Expected when bytes are received faster than the endpoint can process them + Err(Error::ErrBufferFull) => { + log::info!("mux: endpoint buffer is full, dropping packet") + } + Ok(_) => (), + Err(e) => return Err(crate::Error::Util(e)), + } + } else if !buf.is_empty() { + log::warn!( + "Warning: mux: no endpoint for packet starting with {}", + buf[0] + ); + } else { + log::warn!("Warning: mux: no endpoint for zero length packet"); + } + + Ok(()) + } +} diff --git a/rtc/src/mux/mux_func.rs b/rtc/src/mux/mux_func.rs new file mode 100644 index 0000000..ae81907 --- /dev/null +++ b/rtc/src/mux/mux_func.rs @@ -0,0 +1,63 @@ +/// MatchFunc allows custom logic for mapping packets to an Endpoint +pub type MatchFunc = Box bool) + Send + Sync>; + +/// match_all always returns true +pub fn match_all(_b: &[u8]) -> bool { + true +} + +/// match_range is a MatchFunc that accepts packets with the first byte in [lower..upper] +pub fn match_range(lower: u8, upper: u8) -> MatchFunc { + Box::new(move |buf: &[u8]| -> bool { + if buf.is_empty() { + return false; + } + let b = buf[0]; + b >= lower && b <= upper + }) +} + +/// MatchFuncs as described in RFC7983 +/// +/// +----------------+ +/// | [0..3] -+--> forward to STUN +/// | | +/// | [16..19] -+--> forward to ZRTP +/// | | +/// packet --> | [20..63] -+--> forward to DTLS +/// | | +/// | [64..79] -+--> forward to TURN Channel +/// | | +/// | [128..191] -+--> forward to RTP/RTCP +/// +----------------+ +/// match_dtls is a MatchFunc that accepts packets with the first byte in [20..63] +/// as defied in RFC7983 +pub fn match_dtls(b: &[u8]) -> bool { + match_range(20, 63)(b) +} + +// match_srtp_or_srtcp is a MatchFunc that accepts packets with the first byte in [128..191] +// as defied in RFC7983 +pub fn match_srtp_or_srtcp(b: &[u8]) -> bool { + match_range(128, 191)(b) +} + +pub(crate) fn is_rtcp(buf: &[u8]) -> bool { + // Not long enough to determine RTP/RTCP + if buf.len() < 4 { + return false; + } + + let rtcp_packet_type = buf[1]; + (192..=223).contains(&rtcp_packet_type) +} + +/// match_srtp is a MatchFunc that only matches SRTP and not SRTCP +pub fn match_srtp(buf: &[u8]) -> bool { + match_srtp_or_srtcp(buf) && !is_rtcp(buf) +} + +/// match_srtcp is a MatchFunc that only matches SRTCP and not SRTP +pub fn match_srtcp(buf: &[u8]) -> bool { + match_srtp_or_srtcp(buf) && is_rtcp(buf) +} diff --git a/rtc/src/mux/mux_test.rs b/rtc/src/mux/mux_test.rs new file mode 100644 index 0000000..36e6651 --- /dev/null +++ b/rtc/src/mux/mux_test.rs @@ -0,0 +1,137 @@ +use std::io; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; +use util::conn::conn_pipe::pipe; + +use super::*; +use crate::mux::mux_func::{match_all, match_srtp}; + +const TEST_PIPE_BUFFER_SIZE: usize = 8192; + +#[tokio::test] +async fn test_no_endpoints() -> crate::error::Result<()> { + // In memory pipe + let (ca, _) = pipe(); + + let mut m = Mux::new(Config { + conn: Arc::new(ca), + buffer_size: TEST_PIPE_BUFFER_SIZE, + }); + + Mux::dispatch(&[0], &m.endpoints).await?; + m.close().await; + + Ok(()) +} + +struct MuxErrorConn { + idx: AtomicUsize, + data: Vec>, +} + +type Result = std::result::Result; + +#[async_trait] +impl Conn for MuxErrorConn { + async fn connect(&self, _addr: SocketAddr) -> Result<()> { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn recv(&self, buf: &mut [u8]) -> Result { + let idx = self.idx.fetch_add(1, Ordering::SeqCst); + if idx < self.data.len() { + let n = std::cmp::min(buf.len(), self.data[idx].len()); + buf[..n].copy_from_slice(&self.data[idx][..n]); + Ok(n) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!("idx {} >= data.len {}", idx, self.data.len()), + ) + .into()) + } + } + + async fn recv_from(&self, _buf: &mut [u8]) -> Result<(usize, SocketAddr)> { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn send(&self, _buf: &[u8]) -> Result { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + async fn send_to(&self, _buf: &[u8], _target: SocketAddr) -> Result { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + fn local_addr(&self) -> Result { + Err(io::Error::new(io::ErrorKind::Other, "Not applicable").into()) + } + + fn remote_addr(&self) -> Option { + None + } + + async fn close(&self) -> Result<()> { + Ok(()) + } +} + +#[tokio::test] +async fn test_non_fatal_read() -> Result<()> { + let expected_data = b"expected_data".to_vec(); + + let conn = Arc::new(MuxErrorConn { + idx: AtomicUsize::new(0), + data: vec![ + expected_data.clone(), + expected_data.clone(), + expected_data.clone(), + ], + }); + + let mut m = Mux::new(Config { + conn, + buffer_size: TEST_PIPE_BUFFER_SIZE, + }); + + let e = m.new_endpoint(Box::new(match_all)).await; + let mut buff = vec![0u8; TEST_PIPE_BUFFER_SIZE]; + + let n = e.recv(&mut buff).await?; + assert_eq!(&buff[..n], expected_data); + + let n = e.recv(&mut buff).await?; + assert_eq!(&buff[..n], expected_data); + + let n = e.recv(&mut buff).await?; + assert_eq!(&buff[..n], expected_data); + + m.close().await; + + Ok(()) +} + +#[tokio::test] +async fn test_non_fatal_dispatch() -> Result<()> { + let (ca, cb) = pipe(); + + let mut m = Mux::new(Config { + conn: Arc::new(ca), + buffer_size: TEST_PIPE_BUFFER_SIZE, + }); + + let e = m.new_endpoint(Box::new(match_srtp)).await; + e.buffer.set_limit_size(1).await; + + for _ in 0..25 { + let srtp_packet = [128, 1, 2, 3, 4].to_vec(); + cb.send(&srtp_packet).await?; + } + + m.close().await; + + Ok(()) +} diff --git a/rtc/src/peer_connection/certificate.rs b/rtc/src/peer_connection/certificate.rs new file mode 100644 index 0000000..6efa905 --- /dev/null +++ b/rtc/src/peer_connection/certificate.rs @@ -0,0 +1,295 @@ +use std::ops::Add; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use dtls::crypto::{CryptoPrivateKey, CryptoPrivateKeyKind}; +use rcgen::{CertificateParams, KeyPair}; +use ring::rand::SystemRandom; +use ring::rsa; +use ring::signature::{EcdsaKeyPair, Ed25519KeyPair}; +use sha2::{Digest, Sha256}; + +use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; +use crate::error::{Error, Result}; +use crate::peer_connection::math_rand_alpha; +use crate::stats::stats_collector::StatsCollector; +use crate::stats::{CertificateStats, StatsReportType}; + +/// Certificate represents a X.509 certificate used to authenticate WebRTC communications. +#[derive(Clone, Debug)] +pub struct RTCCertificate { + /// DTLS certificate. + pub(crate) dtls_certificate: dtls::crypto::Certificate, + /// Timestamp after which this certificate is no longer valid. + pub(crate) expires: SystemTime, + /// Certificate's ID used for statistics. + /// + /// Example: "certificate-1667202302853538793" + /// + /// See [`CertificateStats`]. + pub(crate) stats_id: String, +} + +impl PartialEq for RTCCertificate { + fn eq(&self, other: &Self) -> bool { + self.dtls_certificate == other.dtls_certificate + } +} + +impl RTCCertificate { + /// Generates a new certificate from the given parameters. + /// + /// See [`rcgen::Certificate::from_params`]. + pub fn from_params(params: CertificateParams) -> Result { + let not_after = params.not_after; + let x509_cert = rcgen::Certificate::from_params(params)?; + + let key_pair = x509_cert.get_key_pair(); + let serialized_der = key_pair.serialize_der(); + + let private_key = if key_pair.is_compatible(&rcgen::PKCS_ED25519) { + CryptoPrivateKey { + kind: CryptoPrivateKeyKind::Ed25519( + Ed25519KeyPair::from_pkcs8(&serialized_der) + .map_err(|e| Error::new(e.to_string()))?, + ), + serialized_der, + } + } else if key_pair.is_compatible(&rcgen::PKCS_ECDSA_P256_SHA256) { + CryptoPrivateKey { + kind: CryptoPrivateKeyKind::Ecdsa256( + EcdsaKeyPair::from_pkcs8( + &ring::signature::ECDSA_P256_SHA256_ASN1_SIGNING, + &serialized_der, + &SystemRandom::new(), + ) + .map_err(|e| Error::new(e.to_string()))?, + ), + serialized_der, + } + } else if key_pair.is_compatible(&rcgen::PKCS_RSA_SHA256) { + CryptoPrivateKey { + kind: CryptoPrivateKeyKind::Rsa256( + rsa::KeyPair::from_pkcs8(&serialized_der) + .map_err(|e| Error::new(e.to_string()))?, + ), + serialized_der, + } + } else { + return Err(Error::new("Unsupported key_pair".to_owned())); + }; + + let expires = if cfg!(target_arch = "arm") { + // Workaround for issue overflow when adding duration to instant on armv7 + // https://github.com/webrtc-rs/examples/issues/5 https://github.com/chronotope/chrono/issues/343 + SystemTime::now().add(Duration::from_secs(172800)) //60*60*48 or 2 days + } else { + not_after.into() + }; + + Ok(Self { + dtls_certificate: dtls::crypto::Certificate { + certificate: vec![rustls::Certificate(x509_cert.serialize_der()?)], + private_key, + }, + expires, + stats_id: gen_stats_id(), + }) + } + + /// Generates a new certificate with default [`CertificateParams`] using the given keypair. + pub fn from_key_pair(key_pair: KeyPair) -> Result { + let mut params = CertificateParams::new(vec![math_rand_alpha(16)]); + + if key_pair.is_compatible(&rcgen::PKCS_ED25519) { + params.alg = &rcgen::PKCS_ED25519; + } else if key_pair.is_compatible(&rcgen::PKCS_ECDSA_P256_SHA256) { + params.alg = &rcgen::PKCS_ECDSA_P256_SHA256; + } else if key_pair.is_compatible(&rcgen::PKCS_RSA_SHA256) { + params.alg = &rcgen::PKCS_RSA_SHA256; + } else { + return Err(Error::new("Unsupported key_pair".to_owned())); + }; + params.key_pair = Some(key_pair); + + RTCCertificate::from_params(params) + } + + /// Parses a certificate from the ASCII PEM format. + #[cfg(feature = "pem")] + pub fn from_pem(pem_str: &str) -> Result { + let mut pem_blocks = pem_str.split("\n\n"); + let first_block = if let Some(b) = pem_blocks.next() { + b + } else { + return Err(Error::InvalidPEM("empty PEM".into())); + }; + let expires_pem = + pem::parse(first_block).map_err(|e| Error::new(format!("can't parse PEM: {e}")))?; + if expires_pem.tag() != "EXPIRES" { + return Err(Error::InvalidPEM(format!( + "invalid tag (expected: 'EXPIRES', got '{}')", + expires_pem.tag() + ))); + } + let mut bytes = [0u8; 8]; + bytes.copy_from_slice(&expires_pem.contents()[..8]); + let expires = if let Some(e) = + SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(u64::from_le_bytes(bytes))) + { + e + } else { + return Err(Error::InvalidPEM("failed to calculate SystemTime".into())); + }; + let dtls_certificate = + dtls::crypto::Certificate::from_pem(&pem_blocks.collect::>().join("\n\n"))?; + Ok(RTCCertificate::from_existing(dtls_certificate, expires)) + } + + /// Builds a [`RTCCertificate`] using the existing DTLS certificate. + /// + /// Use this method when you have a persistent certificate (i.e. you don't want to generate a + /// new one for each DTLS connection). + /// + /// NOTE: ID used for statistics will be different as it's neither derived from the given + /// certificate nor persisted along it when using [`serialize_pem`]. + pub fn from_existing(dtls_certificate: dtls::crypto::Certificate, expires: SystemTime) -> Self { + Self { + dtls_certificate, + expires, + // TODO: figure out if it needs to be persisted + stats_id: gen_stats_id(), + } + } + + /// Serializes the certificate (including the private key) in PKCS#8 format in PEM. + #[cfg(feature = "pem")] + pub fn serialize_pem(&self) -> String { + // Encode `expires` as a PEM block. + // + // TODO: serialize as nanos when https://github.com/rust-lang/rust/issues/103332 is fixed. + let expires_pem = pem::Pem::new( + "EXPIRES".to_string(), + self.expires + .duration_since(SystemTime::UNIX_EPOCH) + .expect("expires to be valid") + .as_secs() + .to_le_bytes() + .to_vec(), + ); + format!( + "{}\n{}", + pem::encode(&expires_pem), + self.dtls_certificate.serialize_pem() + ) + } + + /// get_fingerprints returns a SHA-256 fingerprint of this certificate. + /// + /// TODO: return a fingerprint computed with the digest algorithm used in the certificate + /// signature. + pub fn get_fingerprints(&self) -> Vec { + let mut fingerprints = Vec::new(); + + for c in &self.dtls_certificate.certificate { + let mut h = Sha256::new(); + h.update(c.as_ref()); + let hashed = h.finalize(); + let values: Vec = hashed.iter().map(|x| format! {"{x:02x}"}).collect(); + + fingerprints.push(RTCDtlsFingerprint { + algorithm: "sha-256".to_owned(), + value: values.join(":"), + }); + } + + fingerprints + } + + pub(crate) async fn collect_stats(&self, collector: &StatsCollector) { + if let Some(fingerprint) = self.get_fingerprints().into_iter().next() { + let stats = CertificateStats::new(self, fingerprint); + collector.insert( + self.stats_id.clone(), + StatsReportType::CertificateStats(stats), + ); + } + } +} + +fn gen_stats_id() -> String { + format!( + "certificate-{}", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() as u64 + ) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_generate_certificate_rsa() -> Result<()> { + let key_pair = KeyPair::generate(&rcgen::PKCS_RSA_SHA256); + assert!(key_pair.is_err(), "RcgenError::KeyGenerationUnavailable"); + + Ok(()) + } + + #[test] + fn test_generate_certificate_ecdsa() -> Result<()> { + let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let _cert = RTCCertificate::from_key_pair(kp)?; + + Ok(()) + } + + #[test] + fn test_generate_certificate_eddsa() -> Result<()> { + let kp = KeyPair::generate(&rcgen::PKCS_ED25519)?; + let _cert = RTCCertificate::from_key_pair(kp)?; + + Ok(()) + } + + #[test] + fn test_certificate_equal() -> Result<()> { + let kp1 = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let cert1 = RTCCertificate::from_key_pair(kp1)?; + + let kp2 = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let cert2 = RTCCertificate::from_key_pair(kp2)?; + + assert_ne!(cert1, cert2); + + Ok(()) + } + + #[test] + fn test_generate_certificate_expires_and_stats_id() -> Result<()> { + let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let cert = RTCCertificate::from_key_pair(kp)?; + + let now = SystemTime::now(); + assert!(cert.expires.duration_since(now).is_ok()); + assert!(cert.stats_id.contains("certificate")); + + Ok(()) + } + + #[cfg(feature = "pem")] + #[test] + fn test_certificate_serialize_pem_and_from_pem() -> Result<()> { + let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let cert = RTCCertificate::from_key_pair(kp)?; + + let pem = cert.serialize_pem(); + let loaded_cert = RTCCertificate::from_pem(&pem)?; + + assert_eq!(loaded_cert, cert); + + Ok(()) + } +} diff --git a/rtc/src/peer_connection/configuration.rs b/rtc/src/peer_connection/configuration.rs new file mode 100644 index 0000000..f5a723f --- /dev/null +++ b/rtc/src/peer_connection/configuration.rs @@ -0,0 +1,147 @@ +use crate::ice_transport::ice_server::RTCIceServer; +use crate::peer_connection::certificate::RTCCertificate; +use crate::peer_connection::policy::bundle_policy::RTCBundlePolicy; +use crate::peer_connection::policy::ice_transport_policy::RTCIceTransportPolicy; +use crate::peer_connection::policy::rtcp_mux_policy::RTCRtcpMuxPolicy; + +/// A Configuration defines how peer-to-peer communication via PeerConnection +/// is established or re-established. +/// Configurations may be set up once and reused across multiple connections. +/// Configurations are treated as readonly. As long as they are unmodified, +/// they are safe for concurrent use. +#[derive(Default, Clone)] +pub struct RTCConfiguration { + /// iceservers defines a slice describing servers available to be used by + /// ICE, such as STUN and TURN servers. + pub ice_servers: Vec, + + /// icetransport_policy indicates which candidates the ICEAgent is allowed + /// to use. + pub ice_transport_policy: RTCIceTransportPolicy, + + /// bundle_policy indicates which media-bundling policy to use when gathering + /// ICE candidates. + pub bundle_policy: RTCBundlePolicy, + + /// rtcp_mux_policy indicates which rtcp-mux policy to use when gathering ICE + /// candidates. + pub rtcp_mux_policy: RTCRtcpMuxPolicy, + + /// peer_identity sets the target peer identity for the PeerConnection. + /// The PeerConnection will not establish a connection to a remote peer + /// unless it can be successfully authenticated with the provided name. + pub peer_identity: String, + + /// Certificates describes a set of certificates that the PeerConnection + /// uses to authenticate. Valid values for this parameter are created + /// through calls to the generate_certificate function. Although any given + /// DTLS connection will use only one certificate, this attribute allows the + /// caller to provide multiple certificates that support different + /// algorithms. The final certificate will be selected based on the DTLS + /// handshake, which establishes which certificates are allowed. The + /// PeerConnection implementation selects which of the certificates is + /// used for a given connection; how certificates are selected is outside + /// the scope of this specification. If this value is absent, then a default + /// set of certificates is generated for each PeerConnection instance. + pub certificates: Vec, + + /// icecandidate_pool_size describes the size of the prefetched ICE pool. + pub ice_candidate_pool_size: u8, +} + +impl RTCConfiguration { + /// get_iceservers side-steps the strict parsing mode of the ice package + /// (as defined in https://tools.ietf.org/html/rfc7064) by copying and then + /// stripping any erroneous queries from "stun(s):" URLs before parsing. + pub(crate) fn get_ice_servers(&self) -> Vec { + let mut ice_servers = self.ice_servers.clone(); + + for ice_server in &mut ice_servers { + for raw_url in &mut ice_server.urls { + if raw_url.starts_with("stun") { + // strip the query from "stun(s):" if present + let parts: Vec<&str> = raw_url.split('?').collect(); + *raw_url = parts[0].to_owned(); + } + } + } + + ice_servers + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_configuration_get_iceservers() { + { + let expected_server_str = "stun:stun.l.google.com:19302"; + let cfg = RTCConfiguration { + ice_servers: vec![RTCIceServer { + urls: vec![expected_server_str.to_owned()], + ..Default::default() + }], + ..Default::default() + }; + + let parsed_urls = cfg.get_ice_servers(); + assert_eq!(parsed_urls[0].urls[0], expected_server_str); + } + + { + // ignore the fact that stun URLs shouldn't have a query + let server_str = "stun:global.stun.twilio.com:3478?transport=udp"; + let expected_server_str = "stun:global.stun.twilio.com:3478"; + let cfg = RTCConfiguration { + ice_servers: vec![RTCIceServer { + urls: vec![server_str.to_owned()], + ..Default::default() + }], + ..Default::default() + }; + + let parsed_urls = cfg.get_ice_servers(); + assert_eq!(parsed_urls[0].urls[0], expected_server_str); + } + } + + /*TODO:#[test] fn test_configuration_json() { + + let j = r#" + { + "iceServers": [{"URLs": ["turn:turn.example.org"], + "username": "jch", + "credential": "topsecret" + }], + "iceTransportPolicy": "relay", + "bundlePolicy": "balanced", + "rtcpMuxPolicy": "require" + }"#; + + conf := Configuration{ + ICEServers: []ICEServer{ + { + URLs: []string{"turn:turn.example.org"}, + Username: "jch", + Credential: "topsecret", + }, + }, + ICETransportPolicy: ICETransportPolicyRelay, + BundlePolicy: BundlePolicyBalanced, + RTCPMuxPolicy: RTCPMuxPolicyRequire, + } + + var conf2 Configuration + assert.NoError(t, json.Unmarshal([]byte(j), &conf2)) + assert.Equal(t, conf, conf2) + + j2, err := json.Marshal(conf2) + assert.NoError(t, err) + + var conf3 Configuration + assert.NoError(t, json.Unmarshal(j2, &conf3)) + assert.Equal(t, conf2, conf3) + }*/ +} diff --git a/rtc/src/peer_connection/mod.rs b/rtc/src/peer_connection/mod.rs new file mode 100644 index 0000000..b4989ce --- /dev/null +++ b/rtc/src/peer_connection/mod.rs @@ -0,0 +1,2108 @@ +#[cfg(test)] +pub(crate) mod peer_connection_test; + +pub mod certificate; +pub mod configuration; +pub mod offer_answer_options; +pub(crate) mod operation; +mod peer_connection_internal; +pub mod peer_connection_state; +pub mod policy; +pub mod sdp; +pub mod signaling_state; + +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use ::ice::candidate::candidate_base::unmarshal_candidate; +use ::ice::candidate::Candidate; +use ::sdp::description::session::*; +use ::sdp::util::ConnectionRole; +use arc_swap::ArcSwapOption; +use async_trait::async_trait; +use interceptor::{stats, Attributes, Interceptor, RTCPWriter}; +use peer_connection_internal::*; +use rand::{thread_rng, Rng}; +use rcgen::KeyPair; +use smol_str::SmolStr; +use srtp::stream::Stream; +use tokio::sync::{mpsc, Mutex}; + +use crate::api::media_engine::MediaEngine; +use crate::api::setting_engine::SettingEngine; +use crate::api::API; +use crate::data_channel::data_channel_init::RTCDataChannelInit; +use crate::data_channel::data_channel_parameters::DataChannelParameters; +use crate::data_channel::data_channel_state::RTCDataChannelState; +use crate::data_channel::RTCDataChannel; +use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; +use crate::dtls_transport::dtls_parameters::DTLSParameters; +use crate::dtls_transport::dtls_role::{ + DTLSRole, DEFAULT_DTLS_ROLE_ANSWER, DEFAULT_DTLS_ROLE_OFFER, +}; +use crate::dtls_transport::dtls_transport_state::RTCDtlsTransportState; +use crate::dtls_transport::RTCDtlsTransport; +use crate::error::{flatten_errs, Error, Result}; +use crate::ice_transport::ice_candidate::{RTCIceCandidate, RTCIceCandidateInit}; +use crate::ice_transport::ice_connection_state::RTCIceConnectionState; +use crate::ice_transport::ice_gatherer::{ + OnGatheringCompleteHdlrFn, OnICEGathererStateChangeHdlrFn, OnLocalCandidateHdlrFn, + RTCIceGatherOptions, RTCIceGatherer, +}; +use crate::ice_transport::ice_gatherer_state::RTCIceGathererState; +use crate::ice_transport::ice_gathering_state::RTCIceGatheringState; +use crate::ice_transport::ice_parameters::RTCIceParameters; +use crate::ice_transport::ice_role::RTCIceRole; +use crate::ice_transport::ice_transport_state::RTCIceTransportState; +use crate::ice_transport::RTCIceTransport; +use crate::peer_connection::certificate::RTCCertificate; +use crate::peer_connection::configuration::RTCConfiguration; +use crate::peer_connection::offer_answer_options::{RTCAnswerOptions, RTCOfferOptions}; +use crate::peer_connection::operation::{Operation, Operations}; +use crate::peer_connection::peer_connection_state::{ + NegotiationNeededState, RTCPeerConnectionState, +}; +use crate::peer_connection::sdp::sdp_type::RTCSdpType; +use crate::peer_connection::sdp::session_description::RTCSessionDescription; +use crate::peer_connection::sdp::*; +use crate::peer_connection::signaling_state::{ + check_next_signaling_state, RTCSignalingState, StateChangeOp, +}; +use crate::rtp_transceiver::rtp_codec::{RTCRtpHeaderExtensionCapability, RTPCodecType}; +use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver; +use crate::rtp_transceiver::rtp_sender::RTCRtpSender; +use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; +use crate::rtp_transceiver::{ + find_by_mid, handle_unknown_rtp_packet, satisfy_type_and_direction, RTCRtpTransceiver, + RTCRtpTransceiverInit, SSRC, +}; +use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; +use crate::sctp_transport::sctp_transport_state::RTCSctpTransportState; +use crate::sctp_transport::RTCSctpTransport; +use crate::stats::StatsReport; +use crate::track::track_local::TrackLocal; +use crate::track::track_remote::TrackRemote; + +/// SIMULCAST_PROBE_COUNT is the amount of RTP Packets +/// that handleUndeclaredSSRC will read and try to dispatch from +/// mid and rid values +pub(crate) const SIMULCAST_PROBE_COUNT: usize = 10; + +/// SIMULCAST_MAX_PROBE_ROUTINES is how many active routines can be used to probe +/// If the total amount of incoming SSRCes exceeds this new requests will be ignored +pub(crate) const SIMULCAST_MAX_PROBE_ROUTINES: u64 = 25; + +pub(crate) const MEDIA_SECTION_APPLICATION: &str = "application"; + +const RUNES_ALPHA: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + +/// math_rand_alpha generates a mathematical random alphabet sequence of the requested length. +pub fn math_rand_alpha(n: usize) -> String { + let mut rng = thread_rng(); + + let rand_string: String = (0..n) + .map(|_| { + let idx = rng.gen_range(0..RUNES_ALPHA.len()); + RUNES_ALPHA[idx] as char + }) + .collect(); + + rand_string +} + +pub type OnSignalingStateChangeHdlrFn = Box< + dyn (FnMut(RTCSignalingState) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnICEConnectionStateChangeHdlrFn = Box< + dyn (FnMut(RTCIceConnectionState) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnPeerConnectionStateChangeHdlrFn = Box< + dyn (FnMut(RTCPeerConnectionState) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnDataChannelHdlrFn = Box< + dyn (FnMut(Arc) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnTrackHdlrFn = Box< + dyn (FnMut( + Arc, + Arc, + Arc, + ) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnNegotiationNeededHdlrFn = + Box Pin + Send + 'static>>) + Send + Sync>; + +#[derive(Clone)] +struct StartTransportsParams { + ice_transport: Arc, + dtls_transport: Arc, + on_peer_connection_state_change_handler: Arc>>, + is_closed: Arc, + peer_connection_state: Arc, + ice_connection_state: Arc, +} + +#[derive(Clone)] +struct CheckNegotiationNeededParams { + sctp_transport: Arc, + rtp_transceivers: Arc>>>, + current_local_description: Arc>>, + current_remote_description: Arc>>, +} + +#[derive(Clone)] +struct NegotiationNeededParams { + on_negotiation_needed_handler: Arc>>, + is_closed: Arc, + ops: Arc, + negotiation_needed_state: Arc, + is_negotiation_needed: Arc, + signaling_state: Arc, + check_negotiation_needed_params: CheckNegotiationNeededParams, +} + +/// PeerConnection represents a WebRTC connection that establishes a +/// peer-to-peer communications with another PeerConnection instance in a +/// browser, or to another endpoint implementing the required protocols. +pub struct RTCPeerConnection { + stats_id: String, + idp_login_url: Option, + + configuration: RTCConfiguration, + + interceptor_rtcp_writer: Arc, + + interceptor: Arc, + + pub(crate) internal: Arc, +} + +impl std::fmt::Debug for RTCPeerConnection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RTCPeerConnection") + .field("stats_id", &self.stats_id) + .field("idp_login_url", &self.idp_login_url) + .field("signaling_state", &self.signaling_state()) + .field("ice_connection_state", &self.ice_connection_state()) + .finish() + } +} + +impl std::fmt::Display for RTCPeerConnection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "(RTCPeerConnection {})", self.stats_id) + } +} + +impl RTCPeerConnection { + /// creates a PeerConnection with the default codecs and + /// interceptors. See register_default_codecs and register_default_interceptors. + /// + /// If you wish to customize the set of available codecs or the set of + /// active interceptors, create a MediaEngine and call api.new_peer_connection + /// instead of this function. + pub(crate) async fn new(api: &API, mut configuration: RTCConfiguration) -> Result { + RTCPeerConnection::init_configuration(&mut configuration)?; + + let (interceptor, stats_interceptor): (Arc, _) = { + let mut chain = api.interceptor_registry.build_chain("")?; + let stats_interceptor = stats::make_stats_interceptor(""); + chain.add(stats_interceptor.clone()); + + (Arc::new(chain), stats_interceptor) + }; + + let weak_interceptor = Arc::downgrade(&interceptor); + let (internal, configuration) = + PeerConnectionInternal::new(api, weak_interceptor, stats_interceptor, configuration) + .await?; + let internal_rtcp_writer = Arc::clone(&internal) as Arc; + let interceptor_rtcp_writer = interceptor.bind_rtcp_writer(internal_rtcp_writer).await; + + // (Step #2) + // Some variables defined explicitly despite their implicit zero values to + // allow better readability to understand what is happening. + Ok(RTCPeerConnection { + stats_id: format!( + "PeerConnection-{}", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + ), + interceptor, + interceptor_rtcp_writer, + internal, + configuration, + idp_login_url: None, + }) + } + + /// init_configuration defines validation of the specified Configuration and + /// its assignment to the internal configuration variable. This function differs + /// from its set_configuration counterpart because most of the checks do not + /// include verification statements related to the existing state. Thus the + /// function describes only minor verification of some the struct variables. + fn init_configuration(configuration: &mut RTCConfiguration) -> Result<()> { + let sanitized_ice_servers = configuration.get_ice_servers(); + if !sanitized_ice_servers.is_empty() { + for server in &sanitized_ice_servers { + server.validate()?; + } + } + + // (step #3) + if !configuration.certificates.is_empty() { + let now = SystemTime::now(); + for cert in &configuration.certificates { + cert.expires + .duration_since(now) + .map_err(|_| Error::ErrCertificateExpired)?; + } + } else { + let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let cert = RTCCertificate::from_key_pair(kp)?; + configuration.certificates = vec![cert]; + }; + + Ok(()) + } + + /// on_signaling_state_change sets an event handler which is invoked when the + /// peer connection's signaling state changes + pub fn on_signaling_state_change(&self, f: OnSignalingStateChangeHdlrFn) { + self.internal + .on_signaling_state_change_handler + .store(Some(Arc::new(Mutex::new(f)))) + } + + async fn do_signaling_state_change(&self, new_state: RTCSignalingState) { + log::info!("signaling state changed to {}", new_state); + if let Some(handler) = &*self.internal.on_signaling_state_change_handler.load() { + let mut f = handler.lock().await; + f(new_state).await; + } + } + + /// on_data_channel sets an event handler which is invoked when a data + /// channel message arrives from a remote peer. + pub fn on_data_channel(&self, f: OnDataChannelHdlrFn) { + self.internal + .on_data_channel_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// on_negotiation_needed sets an event handler which is invoked when + /// a change has occurred which requires session negotiation + pub fn on_negotiation_needed(&self, f: OnNegotiationNeededHdlrFn) { + self.internal + .on_negotiation_needed_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + fn do_negotiation_needed_inner(params: &NegotiationNeededParams) -> bool { + // https://w3c.github.io/webrtc-pc/#updating-the-negotiation-needed-flag + // non-canon step 1 + let state: NegotiationNeededState = params + .negotiation_needed_state + .load(Ordering::SeqCst) + .into(); + if state == NegotiationNeededState::Run { + params + .negotiation_needed_state + .store(NegotiationNeededState::Queue as u8, Ordering::SeqCst); + false + } else if state == NegotiationNeededState::Queue { + false + } else { + params + .negotiation_needed_state + .store(NegotiationNeededState::Run as u8, Ordering::SeqCst); + true + } + } + /// do_negotiation_needed enqueues negotiation_needed_op if necessary + /// caller of this method should hold `pc.mu` lock + async fn do_negotiation_needed(params: NegotiationNeededParams) { + if !RTCPeerConnection::do_negotiation_needed_inner(¶ms) { + return; + } + + let params2 = params.clone(); + let _ = params + .ops + .enqueue(Operation::new( + move || { + let params3 = params2.clone(); + Box::pin(async move { RTCPeerConnection::negotiation_needed_op(params3).await }) + }, + "do_negotiation_needed", + )) + .await; + } + + async fn after_negotiation_needed_op(params: NegotiationNeededParams) -> bool { + let old_negotiation_needed_state = params.negotiation_needed_state.load(Ordering::SeqCst); + + params + .negotiation_needed_state + .store(NegotiationNeededState::Empty as u8, Ordering::SeqCst); + + if old_negotiation_needed_state == NegotiationNeededState::Queue as u8 { + RTCPeerConnection::do_negotiation_needed_inner(¶ms) + } else { + false + } + } + + async fn negotiation_needed_op(params: NegotiationNeededParams) -> bool { + // Don't run NegotiatedNeeded checks if on_negotiation_needed is not set + let handler = &*params.on_negotiation_needed_handler.load(); + if handler.is_none() { + return false; + } + + // https://www.w3.org/TR/webrtc/#updating-the-negotiation-needed-flag + // Step 2.1 + if params.is_closed.load(Ordering::SeqCst) { + return false; + } + // non-canon step 2.2 + if !params.ops.is_empty().await { + //enqueue negotiation_needed_op again by return true + return true; + } + + // non-canon, run again if there was a request + // starting defer(after_do_negotiation_needed(params).await); + + // Step 2.3 + if params.signaling_state.load(Ordering::SeqCst) != RTCSignalingState::Stable as u8 { + return RTCPeerConnection::after_negotiation_needed_op(params).await; + } + + // Step 2.4 + if !RTCPeerConnection::check_negotiation_needed(¶ms.check_negotiation_needed_params) + .await + { + params.is_negotiation_needed.store(false, Ordering::SeqCst); + return RTCPeerConnection::after_negotiation_needed_op(params).await; + } + + // Step 2.5 + if params.is_negotiation_needed.load(Ordering::SeqCst) { + return RTCPeerConnection::after_negotiation_needed_op(params).await; + } + + // Step 2.6 + params.is_negotiation_needed.store(true, Ordering::SeqCst); + + // Step 2.7 + if let Some(handler) = handler { + let mut f = handler.lock().await; + f().await; + } + + RTCPeerConnection::after_negotiation_needed_op(params).await + } + + async fn check_negotiation_needed(params: &CheckNegotiationNeededParams) -> bool { + // To check if negotiation is needed for connection, perform the following checks: + // Skip 1, 2 steps + // Step 3 + let current_local_description = { + let current_local_description = params.current_local_description.lock().await; + current_local_description.clone() + }; + let current_remote_description = { + let current_remote_description = params.current_remote_description.lock().await; + current_remote_description.clone() + }; + + if let Some(local_desc) = ¤t_local_description { + let len_data_channel = { + let data_channels = params.sctp_transport.data_channels.lock().await; + data_channels.len() + }; + + if len_data_channel != 0 && have_data_channel(local_desc).is_none() { + return true; + } + + let transceivers = params.rtp_transceivers.lock().await; + for t in &*transceivers { + // https://www.w3.org/TR/webrtc/#dfn-update-the-negotiation-needed-flag + // Step 5.1 + // if t.stopping && !t.stopped { + // return true + // } + let mid = t.mid(); + let m = mid + .as_ref() + .and_then(|mid| get_by_mid(mid.as_str(), local_desc)); + // Step 5.2 + if !t.stopped.load(Ordering::SeqCst) { + if m.is_none() { + return true; + } + + if let Some(m) = m { + // Step 5.3.1 + if t.direction().has_send() { + let dmsid = match m.attribute(ATTR_KEY_MSID).and_then(|o| o) { + Some(m) => m, + None => return true, // doesn't contain a single a=msid line + }; + + let sender = t.sender().await; + // (...)or the number of MSIDs from the a=msid lines in this m= section, + // or the MSID values themselves, differ from what is in + // transceiver.sender.[[AssociatedMediaStreamIds]], return true. + + // TODO: This check should be robuster by storing all streams in the + // local description so we can compare all of them. For no we only + // consider the first one. + + let stream_ids = sender.associated_media_stream_ids(); + // Different number of lines, 1 vs 0 + if stream_ids.is_empty() { + return true; + } + + // different stream id + if dmsid.split_whitespace().next() != Some(&stream_ids[0]) { + return true; + } + } + match local_desc.sdp_type { + RTCSdpType::Offer => { + // Step 5.3.2 + if let Some(remote_desc) = ¤t_remote_description { + if let Some(rm) = t + .mid() + .and_then(|mid| get_by_mid(mid.as_str(), remote_desc)) + { + if get_peer_direction(m) != t.direction() + && get_peer_direction(rm) != t.direction().reverse() + { + return true; + } + } else { + return true; + } + } + } + RTCSdpType::Answer => { + let remote_desc = match ¤t_remote_description { + Some(d) => d, + None => return true, + }; + let offered_direction = match t + .mid() + .and_then(|mid| get_by_mid(mid.as_str(), remote_desc)) + { + Some(d) => { + let dir = get_peer_direction(d); + if dir == RTCRtpTransceiverDirection::Unspecified { + RTCRtpTransceiverDirection::Inactive + } else { + dir + } + } + None => RTCRtpTransceiverDirection::Inactive, + }; + + let current_direction = get_peer_direction(m); + // Step 5.3.3 + if current_direction + != t.direction().intersect(offered_direction.reverse()) + { + return true; + } + } + _ => {} + }; + } + } + // Step 5.4 + if t.stopped.load(Ordering::SeqCst) { + let search_mid = match t.mid() { + Some(mid) => mid, + None => return false, + }; + + if let Some(remote_desc) = &*params.current_remote_description.lock().await { + return get_by_mid(search_mid.as_str(), local_desc).is_some() + || get_by_mid(search_mid.as_str(), remote_desc).is_some(); + } + } + } + // Step 6 + false + } else { + true + } + } + + /// on_ice_candidate sets an event handler which is invoked when a new ICE + /// candidate is found. + /// Take note that the handler is gonna be called with a nil pointer when + /// gathering is finished. + pub fn on_ice_candidate(&self, f: OnLocalCandidateHdlrFn) { + self.internal.ice_gatherer.on_local_candidate(f) + } + + /// on_ice_gathering_state_change sets an event handler which is invoked when the + /// ICE candidate gathering state has changed. + pub fn on_ice_gathering_state_change(&self, f: OnICEGathererStateChangeHdlrFn) { + self.internal.ice_gatherer.on_state_change(f) + } + + /// on_track sets an event handler which is called when remote track + /// arrives from a remote peer. + pub fn on_track(&self, f: OnTrackHdlrFn) { + self.internal + .on_track_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + fn do_track( + on_track_handler: Arc>>, + track: Arc, + receiver: Arc, + transceiver: Arc, + ) { + log::debug!("got new track: {:?}", track); + + tokio::spawn(async move { + if let Some(handler) = &*on_track_handler.load() { + let mut f = handler.lock().await; + f(track, receiver, transceiver).await; + } else { + log::warn!("on_track unset, unable to handle incoming media streams"); + } + }); + } + + /// on_ice_connection_state_change sets an event handler which is called + /// when an ICE connection state is changed. + pub fn on_ice_connection_state_change(&self, f: OnICEConnectionStateChangeHdlrFn) { + self.internal + .on_ice_connection_state_change_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + async fn do_ice_connection_state_change( + handler: &Arc>>, + ice_connection_state: &Arc, + cs: RTCIceConnectionState, + ) { + ice_connection_state.store(cs as u8, Ordering::SeqCst); + + log::info!("ICE connection state changed: {}", cs); + if let Some(handler) = &*handler.load() { + let mut f = handler.lock().await; + f(cs).await; + } + } + + /// on_peer_connection_state_change sets an event handler which is called + /// when the PeerConnectionState has changed + pub fn on_peer_connection_state_change(&self, f: OnPeerConnectionStateChangeHdlrFn) { + self.internal + .on_peer_connection_state_change_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + async fn do_peer_connection_state_change( + handler: &Arc>>, + cs: RTCPeerConnectionState, + ) { + if let Some(handler) = &*handler.load() { + let mut f = handler.lock().await; + f(cs).await; + } + } + + /*TODO: // set_configuration updates the configuration of this PeerConnection object. + pub async fn set_configuration(&mut self, configuration: Configuration) -> Result<()> { + //nolint:gocognit + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-setconfiguration (step #2) + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed.into()); + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #3) + if !configuration.peer_identity.is_empty() { + if configuration.peer_identity != self.configuration.peer_identity { + return Err(Error::ErrModifyingPeerIdentity.into()); + } + self.configuration.peer_identity = configuration.peer_identity; + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #4) + if !configuration.certificates.is_empty() { + if configuration.certificates.len() != self.configuration.certificates.len() { + return Err(Error::ErrModifyingCertificates.into()); + } + + self.configuration.certificates = configuration.certificates; + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #5) + if configuration.bundle_policy != BundlePolicy::Unspecified { + if configuration.bundle_policy != self.configuration.bundle_policy { + return Err(Error::ErrModifyingBundlePolicy.into()); + } + self.configuration.bundle_policy = configuration.bundle_policy; + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #6) + if configuration.rtcp_mux_policy != RTCPMuxPolicy::Unspecified { + if configuration.rtcp_mux_policy != self.configuration.rtcp_mux_policy { + return Err(Error::ErrModifyingRTCPMuxPolicy.into()); + } + self.configuration.rtcp_mux_policy = configuration.rtcp_mux_policy; + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #7) + if configuration.ice_candidate_pool_size != 0 { + if self.configuration.ice_candidate_pool_size != configuration.ice_candidate_pool_size + && self.local_description().await.is_some() + { + return Err(Error::ErrModifyingICECandidatePoolSize.into()); + } + self.configuration.ice_candidate_pool_size = configuration.ice_candidate_pool_size; + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #8) + if configuration.ice_transport_policy != ICETransportPolicy::Unspecified { + self.configuration.ice_transport_policy = configuration.ice_transport_policy + } + + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11) + if !configuration.ice_servers.is_empty() { + // https://www.w3.org/TR/webrtc/#set-the-configuration (step #11.3) + for server in &configuration.ice_servers { + server.validate()?; + } + self.configuration.ice_servers = configuration.ice_servers + } + Ok(()) + }*/ + + /// get_configuration returns a Configuration object representing the current + /// configuration of this PeerConnection object. The returned object is a + /// copy and direct mutation on it will not take affect until set_configuration + /// has been called with Configuration passed as its only argument. + /// + pub fn get_configuration(&self) -> &RTCConfiguration { + &self.configuration + } + + pub fn get_stats_id(&self) -> &str { + self.stats_id.as_str() + } + + /// create_offer starts the PeerConnection and generates the localDescription + /// + pub async fn create_offer( + &self, + options: Option, + ) -> Result { + let use_identity = self.idp_login_url.is_some(); + if use_identity { + return Err(Error::ErrIdentityProviderNotImplemented); + } else if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + if let Some(options) = options { + if options.ice_restart { + self.internal.ice_transport.restart().await?; + } + } + + // This may be necessary to recompute if, for example, createOffer was called when only an + // audio RTCRtpTransceiver was added to connection, but while performing the in-parallel + // steps to create an offer, a video RTCRtpTransceiver was added, requiring additional + // inspection of video system resources. + let mut count = 0; + let mut offer; + + loop { + // We cache current transceivers to ensure they aren't + // mutated during offer generation. We later check if they have + // been mutated and recompute the offer if necessary. + let current_transceivers = { + let rtp_transceivers = self.internal.rtp_transceivers.lock().await; + rtp_transceivers.clone() + }; + + // include unmatched local transceivers + // update the greater mid if the remote description provides a greater one + { + let current_remote_description = + self.internal.current_remote_description.lock().await; + if let Some(d) = &*current_remote_description { + if let Some(parsed) = &d.parsed { + for media in &parsed.media_descriptions { + if let Some(mid) = get_mid_value(media) { + if mid.is_empty() { + continue; + } + let numeric_mid = match mid.parse::() { + Ok(n) => n, + Err(_) => continue, + }; + if numeric_mid > self.internal.greater_mid.load(Ordering::SeqCst) { + self.internal + .greater_mid + .store(numeric_mid, Ordering::SeqCst); + } + } + } + } + } + } + for t in ¤t_transceivers { + if t.mid().is_some() { + continue; + } + + if let Some(gen) = &self.internal.setting_engine.mid_generator { + let current_greatest = self.internal.greater_mid.load(Ordering::SeqCst); + let mid = (gen)(current_greatest); + + // If it's possible to parse the returned mid as numeric, we will update the greater_mid field. + if let Ok(numeric_mid) = mid.parse::() { + if numeric_mid > self.internal.greater_mid.load(Ordering::SeqCst) { + self.internal + .greater_mid + .store(numeric_mid, Ordering::SeqCst); + } + } + + t.set_mid(SmolStr::from(mid))?; + } else { + let greater_mid = self.internal.greater_mid.fetch_add(1, Ordering::SeqCst); + t.set_mid(SmolStr::from(format!("{}", greater_mid + 1)))?; + } + } + + let current_remote_description_is_none = { + let current_remote_description = + self.internal.current_remote_description.lock().await; + current_remote_description.is_none() + }; + + let mut d = if current_remote_description_is_none { + self.internal + .generate_unmatched_sdp(current_transceivers, use_identity) + .await? + } else { + self.internal + .generate_matched_sdp( + current_transceivers, + use_identity, + true, /*includeUnmatched */ + DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), + ) + .await? + }; + + { + let mut sdp_origin = self.internal.sdp_origin.lock().await; + update_sdp_origin(&mut sdp_origin, &mut d); + } + let sdp = d.marshal(); + + offer = RTCSessionDescription { + sdp_type: RTCSdpType::Offer, + sdp, + parsed: Some(d), + }; + + // Verify local media hasn't changed during offer + // generation. Recompute if necessary + if !self.internal.has_local_description_changed(&offer).await { + break; + } + count += 1; + if count >= 128 { + return Err(Error::ErrExcessiveRetries); + } + } + + { + let mut last_offer = self.internal.last_offer.lock().await; + *last_offer = offer.sdp.clone(); + } + Ok(offer) + } + + /// Update the PeerConnectionState given the state of relevant transports + /// + async fn update_connection_state( + on_peer_connection_state_change_handler: &Arc< + ArcSwapOption>, + >, + is_closed: &Arc, + peer_connection_state: &Arc, + ice_connection_state: RTCIceConnectionState, + dtls_transport_state: RTCDtlsTransportState, + ) { + let connection_state = + // The RTCPeerConnection object's [[IsClosed]] slot is true. + if is_closed.load(Ordering::SeqCst) { + RTCPeerConnectionState::Closed + } else if ice_connection_state == RTCIceConnectionState::Failed || dtls_transport_state == RTCDtlsTransportState::Failed { + // Any of the RTCIceTransports or RTCDtlsTransports are in a "failed" state. + RTCPeerConnectionState::Failed + } else if ice_connection_state == RTCIceConnectionState::Disconnected { + // Any of the RTCIceTransports or RTCDtlsTransports are in the "disconnected" + // state and none of them are in the "failed" or "connecting" or "checking" state. + RTCPeerConnectionState::Disconnected + } else if ice_connection_state == RTCIceConnectionState::Connected && dtls_transport_state == RTCDtlsTransportState::Connected { + // All RTCIceTransports and RTCDtlsTransports are in the "connected", "completed" or "closed" + // state and at least one of them is in the "connected" or "completed" state. + RTCPeerConnectionState::Connected + } else if ice_connection_state == RTCIceConnectionState::Checking && dtls_transport_state == RTCDtlsTransportState::Connecting { + // Any of the RTCIceTransports or RTCDtlsTransports are in the "connecting" or + // "checking" state and none of them is in the "failed" state. + RTCPeerConnectionState::Connecting + } else { + RTCPeerConnectionState::New + }; + + if peer_connection_state.load(Ordering::SeqCst) == connection_state as u8 { + return; + } + + log::info!("peer connection state changed: {}", connection_state); + peer_connection_state.store(connection_state as u8, Ordering::SeqCst); + + RTCPeerConnection::do_peer_connection_state_change( + on_peer_connection_state_change_handler, + connection_state, + ) + .await; + } + + /// create_answer starts the PeerConnection and generates the localDescription + pub async fn create_answer( + &self, + _options: Option, + ) -> Result { + let use_identity = self.idp_login_url.is_some(); + let remote_desc = self.remote_description().await; + let remote_description: RTCSessionDescription; + if let Some(desc) = remote_desc { + remote_description = desc; + } else { + return Err(Error::ErrNoRemoteDescription); + } + if use_identity { + return Err(Error::ErrIdentityProviderNotImplemented); + } else if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } else if self.signaling_state() != RTCSignalingState::HaveRemoteOffer + && self.signaling_state() != RTCSignalingState::HaveLocalPranswer + { + return Err(Error::ErrIncorrectSignalingState); + } + + let mut connection_role = self + .internal + .setting_engine + .answering_dtls_role + .to_connection_role(); + if connection_role == ConnectionRole::Unspecified { + connection_role = DEFAULT_DTLS_ROLE_ANSWER.to_connection_role(); + if let Some(parsed) = remote_description.parsed { + if Self::is_lite_set(&parsed) && !self.internal.setting_engine.candidates.ice_lite { + connection_role = DTLSRole::Server.to_connection_role(); + } + } + } + + let local_transceivers = self.get_transceivers().await; + let mut d = self + .internal + .generate_matched_sdp( + local_transceivers, + use_identity, + false, /*includeUnmatched */ + connection_role, + ) + .await?; + + { + let mut sdp_origin = self.internal.sdp_origin.lock().await; + update_sdp_origin(&mut sdp_origin, &mut d); + } + let sdp = d.marshal(); + + let answer = RTCSessionDescription { + sdp_type: RTCSdpType::Answer, + sdp, + parsed: Some(d), + }; + + { + let mut last_answer = self.internal.last_answer.lock().await; + *last_answer = answer.sdp.clone(); + } + Ok(answer) + } + + // 4.4.1.6 Set the SessionDescription + pub(crate) async fn set_description( + &self, + sd: &RTCSessionDescription, + op: StateChangeOp, + ) -> Result<()> { + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } else if sd.sdp_type == RTCSdpType::Unspecified { + return Err(Error::ErrPeerConnSDPTypeInvalidValue); + } + + let next_state = { + let cur = self.signaling_state(); + let new_sdpdoes_not_match_offer = Error::ErrSDPDoesNotMatchOffer; + let new_sdpdoes_not_match_answer = Error::ErrSDPDoesNotMatchAnswer; + + match op { + StateChangeOp::SetLocal => { + match sd.sdp_type { + // stable->SetLocal(offer)->have-local-offer + RTCSdpType::Offer => { + let check = { + let last_offer = self.internal.last_offer.lock().await; + sd.sdp != *last_offer + }; + if check { + Err(new_sdpdoes_not_match_offer) + } else { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::HaveLocalOffer, + StateChangeOp::SetLocal, + sd.sdp_type, + ); + if next_state.is_ok() { + let mut pending_local_description = + self.internal.pending_local_description.lock().await; + *pending_local_description = Some(sd.clone()); + } + next_state + } + } + // have-remote-offer->SetLocal(answer)->stable + // have-local-pranswer->SetLocal(answer)->stable + RTCSdpType::Answer => { + let check = { + let last_answer = self.internal.last_answer.lock().await; + sd.sdp != *last_answer + }; + if check { + Err(new_sdpdoes_not_match_answer) + } else { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::Stable, + StateChangeOp::SetLocal, + sd.sdp_type, + ); + if next_state.is_ok() { + let pending_remote_description = { + let mut pending_remote_description = + self.internal.pending_remote_description.lock().await; + pending_remote_description.take() + }; + let _pending_local_description = { + let mut pending_local_description = + self.internal.pending_local_description.lock().await; + pending_local_description.take() + }; + + { + let mut current_local_description = + self.internal.current_local_description.lock().await; + *current_local_description = Some(sd.clone()); + } + { + let mut current_remote_description = + self.internal.current_remote_description.lock().await; + *current_remote_description = pending_remote_description; + } + } + next_state + } + } + RTCSdpType::Rollback => { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::Stable, + StateChangeOp::SetLocal, + sd.sdp_type, + ); + if next_state.is_ok() { + let mut pending_local_description = + self.internal.pending_local_description.lock().await; + *pending_local_description = None; + } + next_state + } + // have-remote-offer->SetLocal(pranswer)->have-local-pranswer + RTCSdpType::Pranswer => { + let check = { + let last_answer = self.internal.last_answer.lock().await; + sd.sdp != *last_answer + }; + if check { + Err(new_sdpdoes_not_match_answer) + } else { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::HaveLocalPranswer, + StateChangeOp::SetLocal, + sd.sdp_type, + ); + if next_state.is_ok() { + let mut pending_local_description = + self.internal.pending_local_description.lock().await; + *pending_local_description = Some(sd.clone()); + } + next_state + } + } + _ => Err(Error::ErrPeerConnStateChangeInvalid), + } + } + StateChangeOp::SetRemote => { + match sd.sdp_type { + // stable->SetRemote(offer)->have-remote-offer + RTCSdpType::Offer => { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::HaveRemoteOffer, + StateChangeOp::SetRemote, + sd.sdp_type, + ); + if next_state.is_ok() { + let mut pending_remote_description = + self.internal.pending_remote_description.lock().await; + *pending_remote_description = Some(sd.clone()); + } + next_state + } + // have-local-offer->SetRemote(answer)->stable + // have-remote-pranswer->SetRemote(answer)->stable + RTCSdpType::Answer => { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::Stable, + StateChangeOp::SetRemote, + sd.sdp_type, + ); + if next_state.is_ok() { + let pending_local_description = { + let mut pending_local_description = + self.internal.pending_local_description.lock().await; + pending_local_description.take() + }; + + let _pending_remote_description = { + let mut pending_remote_description = + self.internal.pending_remote_description.lock().await; + pending_remote_description.take() + }; + + { + let mut current_remote_description = + self.internal.current_remote_description.lock().await; + *current_remote_description = Some(sd.clone()); + } + { + let mut current_local_description = + self.internal.current_local_description.lock().await; + *current_local_description = pending_local_description; + } + } + next_state + } + RTCSdpType::Rollback => { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::Stable, + StateChangeOp::SetRemote, + sd.sdp_type, + ); + if next_state.is_ok() { + let mut pending_remote_description = + self.internal.pending_remote_description.lock().await; + *pending_remote_description = None; + } + next_state + } + // have-local-offer->SetRemote(pranswer)->have-remote-pranswer + RTCSdpType::Pranswer => { + let next_state = check_next_signaling_state( + cur, + RTCSignalingState::HaveRemotePranswer, + StateChangeOp::SetRemote, + sd.sdp_type, + ); + if next_state.is_ok() { + let mut pending_remote_description = + self.internal.pending_remote_description.lock().await; + *pending_remote_description = Some(sd.clone()); + } + next_state + } + _ => Err(Error::ErrPeerConnStateChangeInvalid), + } + } //_ => Err(Error::ErrPeerConnStateChangeUnhandled.into()), + } + }; + + match next_state { + Ok(next_state) => { + self.internal + .signaling_state + .store(next_state as u8, Ordering::SeqCst); + if self.signaling_state() == RTCSignalingState::Stable { + self.internal + .is_negotiation_needed + .store(false, Ordering::SeqCst); + self.internal.trigger_negotiation_needed().await; + } + self.do_signaling_state_change(next_state).await; + Ok(()) + } + Err(err) => Err(err), + } + } + + /// set_local_description sets the SessionDescription of the local peer + pub async fn set_local_description(&self, mut desc: RTCSessionDescription) -> Result<()> { + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + let have_local_description = { + let current_local_description = self.internal.current_local_description.lock().await; + current_local_description.is_some() + }; + + // JSEP 5.4 + if desc.sdp.is_empty() { + match desc.sdp_type { + RTCSdpType::Answer | RTCSdpType::Pranswer => { + let last_answer = self.internal.last_answer.lock().await; + desc.sdp = last_answer.clone(); + } + RTCSdpType::Offer => { + let last_offer = self.internal.last_offer.lock().await; + desc.sdp = last_offer.clone(); + } + _ => return Err(Error::ErrPeerConnSDPTypeInvalidValueSetLocalDescription), + } + } + + desc.parsed = Some(desc.unmarshal()?); + self.set_description(&desc, StateChangeOp::SetLocal).await?; + + let we_answer = desc.sdp_type == RTCSdpType::Answer; + let remote_description = self.remote_description().await; + let mut local_transceivers = self.get_transceivers().await; + if we_answer { + if let Some(parsed) = desc.parsed { + // WebRTC Spec 1.0 https://www.w3.org/TR/webrtc/ + // Section 4.4.1.5 + for media in &parsed.media_descriptions { + if media.media_name.media == MEDIA_SECTION_APPLICATION { + continue; + } + + let kind = RTPCodecType::from(media.media_name.media.as_str()); + let direction = get_peer_direction(media); + if kind == RTPCodecType::Unspecified + || direction == RTCRtpTransceiverDirection::Unspecified + { + continue; + } + + let mid_value = match get_mid_value(media) { + Some(mid) if !mid.is_empty() => mid, + _ => continue, + }; + + let t = match find_by_mid(mid_value, &mut local_transceivers).await { + Some(t) => t, + None => continue, + }; + let previous_direction = t.current_direction(); + // 4.9.1.7.3 applying a local answer or pranswer + // Set transceiver.[[CurrentDirection]] and transceiver.[[FiredDirection]] to direction. + + // TODO: Also set FiredDirection here. + t.set_current_direction(direction); + t.process_new_current_direction(previous_direction).await?; + } + } + + if let Some(remote_desc) = remote_description { + self.start_rtp_senders().await?; + + let pci = Arc::clone(&self.internal); + let remote_desc = Arc::new(remote_desc); + self.internal + .ops + .enqueue(Operation::new( + move || { + let pc = Arc::clone(&pci); + let rd = Arc::clone(&remote_desc); + Box::pin(async move { + let _ = pc.start_rtp(have_local_description, rd).await; + false + }) + }, + "set_local_description", + )) + .await?; + } + } + + if self.internal.ice_gatherer.state() == RTCIceGathererState::New { + self.internal.ice_gatherer.gather().await + } else { + Ok(()) + } + } + + /// local_description returns PendingLocalDescription if it is not null and + /// otherwise it returns CurrentLocalDescription. This property is used to + /// determine if set_local_description has already been called. + /// + pub async fn local_description(&self) -> Option { + if let Some(pending_local_description) = self.pending_local_description().await { + return Some(pending_local_description); + } + self.current_local_description().await + } + + pub fn is_lite_set(desc: &SessionDescription) -> bool { + for a in &desc.attributes { + if a.key.trim() == ATTR_KEY_ICELITE { + return true; + } + } + false + } + + /// set_remote_description sets the SessionDescription of the remote peer + pub async fn set_remote_description(&self, mut desc: RTCSessionDescription) -> Result<()> { + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + let is_renegotiation = { + let current_remote_description = self.internal.current_remote_description.lock().await; + current_remote_description.is_some() + }; + + desc.parsed = Some(desc.unmarshal()?); + self.set_description(&desc, StateChangeOp::SetRemote) + .await?; + + if let Some(parsed) = &desc.parsed { + self.internal + .media_engine + .update_from_remote_description(parsed) + .await?; + + let mut local_transceivers = self.get_transceivers().await; + let remote_description = self.remote_description().await; + let we_offer = desc.sdp_type == RTCSdpType::Answer; + + if !we_offer { + if let Some(parsed) = remote_description.as_ref().and_then(|r| r.parsed.as_ref()) { + for media in &parsed.media_descriptions { + let mid_value = match get_mid_value(media) { + Some(m) => { + if m.is_empty() { + return Err(Error::ErrPeerConnRemoteDescriptionWithoutMidValue); + } else { + m + } + } + None => continue, + }; + + if media.media_name.media == MEDIA_SECTION_APPLICATION { + continue; + } + + let kind = RTPCodecType::from(media.media_name.media.as_str()); + let direction = get_peer_direction(media); + if kind == RTPCodecType::Unspecified + || direction == RTCRtpTransceiverDirection::Unspecified + { + continue; + } + + let t = if let Some(t) = + find_by_mid(mid_value, &mut local_transceivers).await + { + Some(t) + } else { + satisfy_type_and_direction(kind, direction, &mut local_transceivers) + .await + }; + + if let Some(t) = t { + if t.mid().is_none() { + t.set_mid(SmolStr::from(mid_value))?; + } + } else { + let local_direction = + if direction == RTCRtpTransceiverDirection::Recvonly { + RTCRtpTransceiverDirection::Sendonly + } else { + RTCRtpTransceiverDirection::Recvonly + }; + + let receive_mtu = self.internal.setting_engine.get_receive_mtu(); + + let receiver = Arc::new(RTCRtpReceiver::new( + receive_mtu, + kind, + Arc::clone(&self.internal.dtls_transport), + Arc::clone(&self.internal.media_engine), + Arc::clone(&self.interceptor), + )); + + let sender = Arc::new( + RTCRtpSender::new( + receive_mtu, + None, + Arc::clone(&self.internal.dtls_transport), + Arc::clone(&self.internal.media_engine), + Arc::clone(&self.interceptor), + false, + ) + .await, + ); + + let t = RTCRtpTransceiver::new( + receiver, + sender, + local_direction, + kind, + vec![], + Arc::clone(&self.internal.media_engine), + Some(Box::new(self.internal.make_negotiation_needed_trigger())), + ) + .await; + + self.internal.add_rtp_transceiver(Arc::clone(&t)).await; + + if t.mid().is_none() { + t.set_mid(SmolStr::from(mid_value))?; + } + } + } + } + } + + if we_offer { + // WebRTC Spec 1.0 https://www.w3.org/TR/webrtc/ + // 4.5.9.2 + // This is an answer from the remote. + if let Some(parsed) = remote_description.as_ref().and_then(|r| r.parsed.as_ref()) { + for media in &parsed.media_descriptions { + let mid_value = match get_mid_value(media) { + Some(m) => { + if m.is_empty() { + return Err(Error::ErrPeerConnRemoteDescriptionWithoutMidValue); + } else { + m + } + } + None => continue, + }; + + if media.media_name.media == MEDIA_SECTION_APPLICATION { + continue; + } + let kind = RTPCodecType::from(media.media_name.media.as_str()); + let direction = get_peer_direction(media); + if kind == RTPCodecType::Unspecified + || direction == RTCRtpTransceiverDirection::Unspecified + { + continue; + } + + if let Some(t) = find_by_mid(mid_value, &mut local_transceivers).await { + let previous_direction = t.current_direction(); + + // 4.5.9.2.9 + // Let direction be an RTCRtpTransceiverDirection value representing the direction + // from the media description, but with the send and receive directions reversed to + // represent this peer's point of view. If the media description is rejected, + // set direction to "inactive". + let reversed_direction = direction.reverse(); + + // 4.5.9.2.13.2 + // Set transceiver.[[CurrentDirection]] and transceiver.[[Direction]]s to direction. + t.set_current_direction(reversed_direction); + // TODO: According to the specification we should set + // transceiver.[[Direction]] here, however libWebrtc doesn't do this. + // NOTE: After raising this it seems like the specification might + // change to remove the setting of transceiver.[[Direction]]. + // See https://github.com/w3c/webrtc-pc/issues/2751#issuecomment-1185901962 + // t.set_direction_internal(reversed_direction); + t.process_new_current_direction(previous_direction).await?; + } + } + } + } + + let (remote_ufrag, remote_pwd, candidates) = extract_ice_details(parsed).await?; + + if is_renegotiation + && self + .internal + .ice_transport + .have_remote_credentials_change(&remote_ufrag, &remote_pwd) + .await + { + // An ICE Restart only happens implicitly for a set_remote_description of type offer + if !we_offer { + self.internal.ice_transport.restart().await?; + } + + self.internal + .ice_transport + .set_remote_credentials(remote_ufrag.clone(), remote_pwd.clone()) + .await?; + } + + for candidate in candidates { + self.internal + .ice_transport + .add_remote_candidate(Some(candidate)) + .await?; + } + + if is_renegotiation { + if we_offer { + self.start_rtp_senders().await?; + + let pci = Arc::clone(&self.internal); + let remote_desc = Arc::new(desc); + self.internal + .ops + .enqueue(Operation::new( + move || { + let pc = Arc::clone(&pci); + let rd = Arc::clone(&remote_desc); + Box::pin(async move { + let _ = pc.start_rtp(true, rd).await; + false + }) + }, + "set_remote_description renegotiation", + )) + .await?; + } + return Ok(()); + } + + let remote_is_lite = Self::is_lite_set(parsed); + + let (fingerprint, fingerprint_hash) = extract_fingerprint(parsed)?; + + // If one of the agents is lite and the other one is not, the lite agent must be the controlling agent. + // If both or neither agents are lite the offering agent is controlling. + // RFC 8445 S6.1.1 + let ice_role = if (we_offer + && remote_is_lite == self.internal.setting_engine.candidates.ice_lite) + || (remote_is_lite && !self.internal.setting_engine.candidates.ice_lite) + { + RTCIceRole::Controlling + } else { + RTCIceRole::Controlled + }; + + // Start the networking in a new routine since it will block until + // the connection is actually established. + if we_offer { + self.start_rtp_senders().await?; + } + + //log::trace!("start_transports: parsed={:?}", parsed); + + let pci = Arc::clone(&self.internal); + let dtls_role = DTLSRole::from(parsed); + let remote_desc = Arc::new(desc); + self.internal + .ops + .enqueue(Operation::new( + move || { + let pc = Arc::clone(&pci); + let rd = Arc::clone(&remote_desc); + let ru = remote_ufrag.clone(); + let rp = remote_pwd.clone(); + let fp = fingerprint.clone(); + let fp_hash = fingerprint_hash.clone(); + Box::pin(async move { + log::trace!( + "start_transports: ice_role={}, dtls_role={}", + ice_role, + dtls_role, + ); + pc.start_transports(ice_role, dtls_role, ru, rp, fp, fp_hash) + .await; + + if we_offer { + let _ = pc.start_rtp(false, rd).await; + } + false + }) + }, + "set_remote_description", + )) + .await?; + } + + Ok(()) + } + + /// start_rtp_senders starts all outbound RTP streams + pub(crate) async fn start_rtp_senders(&self) -> Result<()> { + let current_transceivers = self.internal.rtp_transceivers.lock().await; + for transceiver in &*current_transceivers { + let sender = transceiver.sender().await; + if sender.is_negotiated() && !sender.has_sent() { + sender.send(&sender.get_parameters().await).await?; + } + } + + Ok(()) + } + + /// remote_description returns pending_remote_description if it is not null and + /// otherwise it returns current_remote_description. This property is used to + /// determine if setRemoteDescription has already been called. + /// + pub async fn remote_description(&self) -> Option { + self.internal.remote_description().await + } + + /// add_ice_candidate accepts an ICE candidate string and adds it + /// to the existing set of candidates. + pub async fn add_ice_candidate(&self, candidate: RTCIceCandidateInit) -> Result<()> { + if self.remote_description().await.is_none() { + return Err(Error::ErrNoRemoteDescription); + } + + let candidate_value = match candidate.candidate.strip_prefix("candidate:") { + Some(s) => s, + None => candidate.candidate.as_str(), + }; + + let ice_candidate = if !candidate_value.is_empty() { + let candidate: Arc = + Arc::new(unmarshal_candidate(candidate_value)?); + + Some(RTCIceCandidate::from(&candidate)) + } else { + None + }; + + self.internal + .ice_transport + .add_remote_candidate(ice_candidate) + .await + } + + /// ice_connection_state returns the ICE connection state of the + /// PeerConnection instance. + pub fn ice_connection_state(&self) -> RTCIceConnectionState { + self.internal + .ice_connection_state + .load(Ordering::SeqCst) + .into() + } + + /// get_senders returns the RTPSender that are currently attached to this PeerConnection + pub async fn get_senders(&self) -> Vec> { + let mut senders = vec![]; + let rtp_transceivers = self.internal.rtp_transceivers.lock().await; + for transceiver in &*rtp_transceivers { + let sender = transceiver.sender().await; + senders.push(sender); + } + senders + } + + /// get_receivers returns the RTPReceivers that are currently attached to this PeerConnection + pub async fn get_receivers(&self) -> Vec> { + let mut receivers = vec![]; + let rtp_transceivers = self.internal.rtp_transceivers.lock().await; + for transceiver in &*rtp_transceivers { + receivers.push(transceiver.receiver().await); + } + receivers + } + + /// get_transceivers returns the RtpTransceiver that are currently attached to this PeerConnection + pub async fn get_transceivers(&self) -> Vec> { + let rtp_transceivers = self.internal.rtp_transceivers.lock().await; + rtp_transceivers.clone() + } + + /// add_track adds a Track to the PeerConnection + pub async fn add_track( + &self, + track: Arc, + ) -> Result> { + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + { + let rtp_transceivers = self.internal.rtp_transceivers.lock().await; + for t in &*rtp_transceivers { + if !t.stopped.load(Ordering::SeqCst) + && t.kind == track.kind() + && track.id() == t.sender().await.id + { + let sender = t.sender().await; + if sender.track().await.is_none() { + if let Err(err) = sender.replace_track(Some(track)).await { + let _ = sender.stop().await; + return Err(err); + } + + t.set_direction_internal(RTCRtpTransceiverDirection::from_send_recv( + true, + t.direction().has_recv(), + )); + + self.internal.trigger_negotiation_needed().await; + return Ok(sender); + } + } + } + } + + let transceiver = self + .internal + .new_transceiver_from_track(RTCRtpTransceiverDirection::Sendrecv, track) + .await?; + self.internal + .add_rtp_transceiver(Arc::clone(&transceiver)) + .await; + + Ok(transceiver.sender().await) + } + + /// remove_track removes a Track from the PeerConnection + pub async fn remove_track(&self, sender: &Arc) -> Result<()> { + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + let mut transceiver = None; + { + let rtp_transceivers = self.internal.rtp_transceivers.lock().await; + for t in &*rtp_transceivers { + if t.sender().await.id == sender.id { + if sender.track().await.is_none() { + return Ok(()); + } + transceiver = Some(t.clone()); + break; + } + } + } + + let t = transceiver.ok_or(Error::ErrSenderNotCreatedByConnection)?; + + // This also happens in `set_sending_track` but we need to make sure we do this + // before we call sender.stop to avoid a race condition when removing tracks and + // generating offers. + t.set_direction_internal(RTCRtpTransceiverDirection::from_send_recv( + false, + t.direction().has_recv(), + )); + // Stop the sender + let sender_result = sender.stop().await; + // This also updates direction + let sending_track_result = t.set_sending_track(None).await; + + if sender_result.is_ok() && sending_track_result.is_ok() { + self.internal.trigger_negotiation_needed().await; + } + Ok(()) + } + + /// add_transceiver_from_kind Create a new RtpTransceiver and adds it to the set of transceivers. + pub async fn add_transceiver_from_kind( + &self, + kind: RTPCodecType, + init: Option, + ) -> Result> { + self.internal.add_transceiver_from_kind(kind, init).await + } + + /// add_transceiver_from_track Create a new RtpTransceiver(SendRecv or SendOnly) and add it to the set of transceivers. + pub async fn add_transceiver_from_track( + &self, + track: Arc, + init: Option, + ) -> Result> { + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + let direction = init + .map(|init| init.direction) + .unwrap_or(RTCRtpTransceiverDirection::Sendrecv); + + let t = self + .internal + .new_transceiver_from_track(direction, track) + .await?; + + self.internal.add_rtp_transceiver(Arc::clone(&t)).await; + + Ok(t) + } + + /// create_data_channel creates a new DataChannel object with the given label + /// and optional DataChannelInit used to configure properties of the + /// underlying channel such as data reliability. + pub async fn create_data_channel( + &self, + label: &str, + options: Option, + ) -> Result> { + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #2) + if self.internal.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + let mut params = DataChannelParameters { + label: label.to_owned(), + ordered: true, + ..Default::default() + }; + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #19) + if let Some(options) = options { + // Ordered indicates if data is allowed to be delivered out of order. The + // default value of true, guarantees that data will be delivered in order. + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #9) + if let Some(ordered) = options.ordered { + params.ordered = ordered; + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #7) + if let Some(max_packet_life_time) = options.max_packet_life_time { + params.max_packet_life_time = max_packet_life_time; + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #8) + if let Some(max_retransmits) = options.max_retransmits { + params.max_retransmits = max_retransmits; + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #10) + if let Some(protocol) = options.protocol { + params.protocol = protocol; + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #11) + if params.protocol.len() > 65535 { + return Err(Error::ErrProtocolTooLarge); + } + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #12) + params.negotiated = options.negotiated; + } + + let d = Arc::new(RTCDataChannel::new( + params, + Arc::clone(&self.internal.setting_engine), + )); + + // https://w3c.github.io/webrtc-pc/#peer-to-peer-data-api (Step #16) + if d.max_packet_lifetime != 0 && d.max_retransmits != 0 { + return Err(Error::ErrRetransmitsOrPacketLifeTime); + } + + { + let mut data_channels = self.internal.sctp_transport.data_channels.lock().await; + data_channels.push(Arc::clone(&d)); + } + self.internal + .sctp_transport + .data_channels_requested + .fetch_add(1, Ordering::SeqCst); + + // If SCTP already connected open all the channels + if self.internal.sctp_transport.state() == RTCSctpTransportState::Connected { + d.open(Arc::clone(&self.internal.sctp_transport)).await?; + } + + self.internal.trigger_negotiation_needed().await; + + Ok(d) + } + + /// set_identity_provider is used to configure an identity provider to generate identity assertions + pub fn set_identity_provider(&self, _provider: &str) -> Result<()> { + Err(Error::ErrPeerConnSetIdentityProviderNotImplemented) + } + + /// write_rtcp sends a user provided RTCP packet to the connected peer. If no peer is connected the + /// packet is discarded. It also runs any configured interceptors. + pub async fn write_rtcp( + &self, + pkts: &[Box], + ) -> Result { + let a = Attributes::new(); + Ok(self.interceptor_rtcp_writer.write(pkts, &a).await?) + } + + /// close ends the PeerConnection + pub async fn close(&self) -> Result<()> { + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #1) + if self.internal.is_closed.load(Ordering::SeqCst) { + return Ok(()); + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #2) + self.internal.is_closed.store(true, Ordering::SeqCst); + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #3) + self.internal + .signaling_state + .store(RTCSignalingState::Closed as u8, Ordering::SeqCst); + + // Try closing everything and collect the errors + // Shutdown strategy: + // 1. All Conn close by closing their underlying Conn. + // 2. A Mux stops this chain. It won't close the underlying + // Conn if one of the endpoints is closed down. To + // continue the chain the Mux has to be closed. + let mut close_errs = vec![]; + + if let Err(err) = self.interceptor.close().await { + close_errs.push(Error::new(format!("interceptor: {err}"))); + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #4) + { + let mut rtp_transceivers = self.internal.rtp_transceivers.lock().await; + for t in &*rtp_transceivers { + if let Err(err) = t.stop().await { + close_errs.push(Error::new(format!("rtp_transceivers: {err}"))); + } + } + rtp_transceivers.clear(); + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #5) + { + let mut data_channels = self.internal.sctp_transport.data_channels.lock().await; + for d in &*data_channels { + if let Err(err) = d.close().await { + close_errs.push(Error::new(format!("data_channels: {err}"))); + } + } + data_channels.clear(); + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #6) + if let Err(err) = self.internal.sctp_transport.stop().await { + close_errs.push(Error::new(format!("sctp_transport: {err}"))); + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #7) + if let Err(err) = self.internal.dtls_transport.stop().await { + close_errs.push(Error::new(format!("dtls_transport: {err}"))); + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #8, #9, #10) + if let Err(err) = self.internal.ice_transport.stop().await { + close_errs.push(Error::new(format!("dtls_transport: {err}"))); + } + + // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11) + RTCPeerConnection::update_connection_state( + &self.internal.on_peer_connection_state_change_handler, + &self.internal.is_closed, + &self.internal.peer_connection_state, + self.ice_connection_state(), + self.internal.dtls_transport.state(), + ) + .await; + + if let Err(err) = self.internal.ops.close().await { + close_errs.push(Error::new(format!("ops: {err}"))); + } + + flatten_errs(close_errs) + } + + /// CurrentLocalDescription represents the local description that was + /// successfully negotiated the last time the PeerConnection transitioned + /// into the stable state plus any local candidates that have been generated + /// by the ICEAgent since the offer or answer was created. + pub async fn current_local_description(&self) -> Option { + let local_description = { + let current_local_description = self.internal.current_local_description.lock().await; + current_local_description.clone() + }; + let ice_gather = Some(&self.internal.ice_gatherer); + let ice_gathering_state = self.ice_gathering_state(); + + populate_local_candidates(local_description.as_ref(), ice_gather, ice_gathering_state).await + } + + /// PendingLocalDescription represents a local description that is in the + /// process of being negotiated plus any local candidates that have been + /// generated by the ICEAgent since the offer or answer was created. If the + /// PeerConnection is in the stable state, the value is null. + pub async fn pending_local_description(&self) -> Option { + let local_description = { + let pending_local_description = self.internal.pending_local_description.lock().await; + pending_local_description.clone() + }; + let ice_gather = Some(&self.internal.ice_gatherer); + let ice_gathering_state = self.ice_gathering_state(); + + populate_local_candidates(local_description.as_ref(), ice_gather, ice_gathering_state).await + } + + /// current_remote_description represents the last remote description that was + /// successfully negotiated the last time the PeerConnection transitioned + /// into the stable state plus any remote candidates that have been supplied + /// via add_icecandidate() since the offer or answer was created. + pub async fn current_remote_description(&self) -> Option { + let current_remote_description = self.internal.current_remote_description.lock().await; + current_remote_description.clone() + } + + /// pending_remote_description represents a remote description that is in the + /// process of being negotiated, complete with any remote candidates that + /// have been supplied via add_icecandidate() since the offer or answer was + /// created. If the PeerConnection is in the stable state, the value is + /// null. + pub async fn pending_remote_description(&self) -> Option { + let pending_remote_description = self.internal.pending_remote_description.lock().await; + pending_remote_description.clone() + } + + /// signaling_state attribute returns the signaling state of the + /// PeerConnection instance. + pub fn signaling_state(&self) -> RTCSignalingState { + self.internal.signaling_state.load(Ordering::SeqCst).into() + } + + /// icegathering_state attribute returns the ICE gathering state of the + /// PeerConnection instance. + pub fn ice_gathering_state(&self) -> RTCIceGatheringState { + self.internal.ice_gathering_state() + } + + /// connection_state attribute returns the connection state of the + /// PeerConnection instance. + pub fn connection_state(&self) -> RTCPeerConnectionState { + self.internal + .peer_connection_state + .load(Ordering::SeqCst) + .into() + } + + pub async fn get_stats(&self) -> StatsReport { + self.internal + .get_stats(self.get_stats_id().to_owned()) + .await + .into() + } + + /// sctp returns the SCTPTransport for this PeerConnection + /// + /// The SCTP transport over which SCTP data is sent and received. If SCTP has not been negotiated, the value is nil. + /// + pub fn sctp(&self) -> Arc { + Arc::clone(&self.internal.sctp_transport) + } + + /// gathering_complete_promise is a Pion specific helper function that returns a channel that is closed when gathering is complete. + /// This function may be helpful in cases where you are unable to trickle your ICE Candidates. + /// + /// It is better to not use this function, and instead trickle candidates. If you use this function you will see longer connection startup times. + /// When the call is connected you will see no impact however. + pub async fn gathering_complete_promise(&self) -> mpsc::Receiver<()> { + let (gathering_complete_tx, gathering_complete_rx) = mpsc::channel(1); + + // It's possible to miss the GatherComplete event since setGatherCompleteHandler is an atomic operation and the + // promise might have been created after the gathering is finished. Therefore, we need to check if the ICE gathering + // state has changed to complete so that we don't block the caller forever. + let done = Arc::new(Mutex::new(Some(gathering_complete_tx))); + let done2 = Arc::clone(&done); + self.internal.set_gather_complete_handler(Box::new(move || { + log::trace!("setGatherCompleteHandler"); + let done3 = Arc::clone(&done2); + Box::pin(async move { + let mut d = done3.lock().await; + d.take(); + }) + })); + + if self.ice_gathering_state() == RTCIceGatheringState::Complete { + log::trace!("ICEGatheringState::Complete"); + let mut d = done.lock().await; + d.take(); + } + + gathering_complete_rx + } + + /// Returns the internal [`RTCDtlsTransport`]. + pub fn dtls_transport(&self) -> Arc { + Arc::clone(&self.internal.dtls_transport) + } + + /// Adds the specified [`RTCRtpTransceiver`] to this [`RTCPeerConnection`]. + pub async fn add_transceiver(&self, t: Arc) { + self.internal.add_rtp_transceiver(t).await + } +} diff --git a/rtc/src/peer_connection/offer_answer_options.rs b/rtc/src/peer_connection/offer_answer_options.rs new file mode 100644 index 0000000..a7be649 --- /dev/null +++ b/rtc/src/peer_connection/offer_answer_options.rs @@ -0,0 +1,22 @@ +/// AnswerOptions structure describes the options used to control the answer +/// creation process. +#[derive(Default, Debug, PartialEq, Eq, Copy, Clone)] +pub struct RTCAnswerOptions { + /// voice_activity_detection allows the application to provide information + /// about whether it wishes voice detection feature to be enabled or disabled. + pub voice_activity_detection: bool, +} + +/// OfferOptions structure describes the options used to control the offer +/// creation process +#[derive(Default, Debug, PartialEq, Eq, Copy, Clone)] +pub struct RTCOfferOptions { + /// voice_activity_detection allows the application to provide information + /// about whether it wishes voice detection feature to be enabled or disabled. + pub voice_activity_detection: bool, + + /// ice_restart forces the underlying ice gathering process to be restarted. + /// When this value is true, the generated description will have ICE + /// credentials that are different from the current credentials + pub ice_restart: bool, +} diff --git a/rtc/src/peer_connection/operation/mod.rs b/rtc/src/peer_connection/operation/mod.rs new file mode 100644 index 0000000..9f9e493 --- /dev/null +++ b/rtc/src/peer_connection/operation/mod.rs @@ -0,0 +1,139 @@ +#[cfg(test)] +mod operation_test; + +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use tokio::sync::mpsc; +use waitgroup::WaitGroup; + +use crate::error::Result; + +/// Operation is a function +pub struct Operation( + pub Box Pin + Send + 'static>>) + Send + Sync>, + pub &'static str, +); + +impl Operation { + pub(crate) fn new( + op: impl FnMut() -> Pin + Send + 'static>> + Send + Sync + 'static, + description: &'static str, + ) -> Self { + Self(Box::new(op), description) + } +} + +impl fmt::Debug for Operation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Operation") + .field(&"_") + .field(&self.1) + .finish() + } +} + +/// Operations is a task executor. +#[derive(Default)] +pub(crate) struct Operations { + length: Arc, + ops_tx: Option>>, + close_tx: Option>, +} + +impl Operations { + pub(crate) fn new() -> Self { + let length = Arc::new(AtomicUsize::new(0)); + let (ops_tx, ops_rx) = mpsc::unbounded_channel(); + let (close_tx, close_rx) = mpsc::channel(1); + let l = Arc::clone(&length); + let ops_tx = Arc::new(ops_tx); + let ops_tx2 = Arc::clone(&ops_tx); + tokio::spawn(async move { + Operations::start(l, ops_tx, ops_rx, close_rx).await; + }); + + Operations { + length, + ops_tx: Some(ops_tx2), + close_tx: Some(close_tx), + } + } + + /// enqueue adds a new action to be executed. If there are no actions scheduled, + /// the execution will start immediately in a new goroutine. + pub(crate) async fn enqueue(&self, op: Operation) -> Result<()> { + if let Some(ops_tx) = &self.ops_tx { + return Operations::enqueue_inner(op, ops_tx, &self.length); + } + + Ok(()) + } + + fn enqueue_inner( + op: Operation, + ops_tx: &Arc>, + length: &Arc, + ) -> Result<()> { + length.fetch_add(1, Ordering::SeqCst); + ops_tx.send(op)?; + + Ok(()) + } + + /// is_empty checks if there are tasks in the queue + pub(crate) async fn is_empty(&self) -> bool { + self.length.load(Ordering::SeqCst) == 0 + } + + /// Done blocks until all currently enqueued operations are finished executing. + /// For more complex synchronization, use Enqueue directly. + pub(crate) async fn done(&self) { + let wg = WaitGroup::new(); + let mut w = Some(wg.worker()); + let _ = self + .enqueue(Operation::new( + move || { + let _d = w.take(); + Box::pin(async { false }) + }, + "Operation::done", + )) + .await; + wg.wait().await; + } + + pub(crate) async fn start( + length: Arc, + ops_tx: Arc>, + mut ops_rx: mpsc::UnboundedReceiver, + mut close_rx: mpsc::Receiver<()>, + ) { + loop { + tokio::select! { + _ = close_rx.recv() => { + break; + } + result = ops_rx.recv() => { + if let Some(mut f) = result { + length.fetch_sub(1, Ordering::SeqCst); + if f.0().await { + // Requeue this operation + let _ = Operations::enqueue_inner(f, &ops_tx, &length); + } + } + } + } + } + } + + pub(crate) async fn close(&self) -> Result<()> { + if let Some(close_tx) = &self.close_tx { + close_tx.send(()).await?; + } + Ok(()) + } +} diff --git a/rtc/src/peer_connection/operation/operation_test.rs b/rtc/src/peer_connection/operation/operation_test.rs new file mode 100644 index 0000000..0ecc344 --- /dev/null +++ b/rtc/src/peer_connection/operation/operation_test.rs @@ -0,0 +1,47 @@ +use tokio::sync::Mutex; + +use super::*; +use crate::error::Result; + +#[tokio::test] +async fn test_operations_enqueue() -> Result<()> { + let ops = Operations::new(); + for _ in 0..100 { + let results = Arc::new(Mutex::new(vec![0; 16])); + for k in 0..16 { + let r = Arc::clone(&results); + ops.enqueue(Operation::new( + move || { + let r2 = Arc::clone(&r); + Box::pin(async move { + let mut r3 = r2.lock().await; + r3[k] += k * k; + r3[k] == 225 + }) + }, + "test_operations_enqueue", + )) + .await?; + } + + ops.done().await; + let expected = vec![ + 0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 450, + ]; + { + let r = results.lock().await; + assert_eq!(r.len(), expected.len()); + assert_eq!(&*r, &expected); + } + } + + Ok(()) +} + +#[tokio::test] +async fn test_operations_done() -> Result<()> { + let ops = Operations::new(); + ops.done().await; + + Ok(()) +} diff --git a/rtc/src/peer_connection/peer_connection_internal.rs b/rtc/src/peer_connection/peer_connection_internal.rs new file mode 100644 index 0000000..a02bd20 --- /dev/null +++ b/rtc/src/peer_connection/peer_connection_internal.rs @@ -0,0 +1,1492 @@ +use std::collections::VecDeque; +use std::sync::atomic::AtomicIsize; +use std::sync::Weak; + +use arc_swap::ArcSwapOption; +use smol_str::SmolStr; +use tokio::time::Instant; +use util::Unmarshal; + +use super::*; +use crate::rtp_transceiver::create_stream_info; +use crate::stats::stats_collector::StatsCollector; +use crate::stats::{ + InboundRTPStats, OutboundRTPStats, RTCStatsType, RemoteInboundRTPStats, RemoteOutboundRTPStats, + StatsReportType, +}; +use crate::track::TrackStream; +use crate::{SDES_REPAIR_RTP_STREAM_ID_URI, SDP_ATTRIBUTE_RID}; + +pub(crate) struct PeerConnectionInternal { + /// a value containing the last known greater mid value + /// we internally generate mids as numbers. Needed since JSEP + /// requires that when reusing a media section a new unique mid + /// should be defined (see JSEP 3.4.1). + pub(super) greater_mid: AtomicIsize, + pub(super) sdp_origin: Mutex<::sdp::description::session::Origin>, + pub(super) last_offer: Mutex, + pub(super) last_answer: Mutex, + + pub(super) on_negotiation_needed_handler: Arc>>, + pub(super) is_closed: Arc, + + /// ops is an operations queue which will ensure the enqueued actions are + /// executed in order. It is used for asynchronously, but serially processing + /// remote and local descriptions + pub(crate) ops: Arc, + pub(super) negotiation_needed_state: Arc, + pub(super) is_negotiation_needed: Arc, + pub(super) signaling_state: Arc, + + pub(super) ice_transport: Arc, + pub(super) dtls_transport: Arc, + pub(super) on_peer_connection_state_change_handler: + Arc>>, + pub(super) peer_connection_state: Arc, + pub(super) ice_connection_state: Arc, + + pub(super) sctp_transport: Arc, + pub(super) rtp_transceivers: Arc>>>, + + pub(super) on_track_handler: Arc>>, + pub(super) on_signaling_state_change_handler: + ArcSwapOption>, + pub(super) on_ice_connection_state_change_handler: + Arc>>, + pub(super) on_data_channel_handler: Arc>>, + + pub(super) ice_gatherer: Arc, + + pub(super) current_local_description: Arc>>, + pub(super) current_remote_description: Arc>>, + pub(super) pending_local_description: Arc>>, + pub(super) pending_remote_description: Arc>>, + + // A reference to the associated API state used by this connection + pub(super) setting_engine: Arc, + pub(crate) media_engine: Arc, + pub(super) interceptor: Weak, + stats_interceptor: Arc, +} + +impl PeerConnectionInternal { + pub(super) async fn new( + api: &API, + interceptor: Weak, + stats_interceptor: Arc, + mut configuration: RTCConfiguration, + ) -> Result<(Arc, RTCConfiguration)> { + let mut pc = PeerConnectionInternal { + greater_mid: AtomicIsize::new(-1), + sdp_origin: Mutex::new(Default::default()), + last_offer: Mutex::new("".to_owned()), + last_answer: Mutex::new("".to_owned()), + + on_negotiation_needed_handler: Arc::new(ArcSwapOption::empty()), + ops: Arc::new(Operations::new()), + is_closed: Arc::new(AtomicBool::new(false)), + is_negotiation_needed: Arc::new(AtomicBool::new(false)), + negotiation_needed_state: Arc::new(AtomicU8::new(NegotiationNeededState::Empty as u8)), + signaling_state: Arc::new(AtomicU8::new(RTCSignalingState::Stable as u8)), + ice_transport: Arc::new(Default::default()), + dtls_transport: Arc::new(Default::default()), + ice_connection_state: Arc::new(AtomicU8::new(RTCIceConnectionState::New as u8)), + sctp_transport: Arc::new(Default::default()), + rtp_transceivers: Arc::new(Default::default()), + on_track_handler: Arc::new(ArcSwapOption::empty()), + on_signaling_state_change_handler: ArcSwapOption::empty(), + on_ice_connection_state_change_handler: Arc::new(ArcSwapOption::empty()), + on_data_channel_handler: Arc::new(Default::default()), + ice_gatherer: Arc::new(Default::default()), + current_local_description: Arc::new(Default::default()), + current_remote_description: Arc::new(Default::default()), + pending_local_description: Arc::new(Default::default()), + peer_connection_state: Arc::new(AtomicU8::new(RTCPeerConnectionState::New as u8)), + + setting_engine: Arc::clone(&api.setting_engine), + media_engine: if !api.setting_engine.disable_media_engine_copy { + Arc::new(api.media_engine.clone_to()) + } else { + Arc::clone(&api.media_engine) + }, + interceptor, + stats_interceptor, + on_peer_connection_state_change_handler: Arc::new(ArcSwapOption::empty()), + pending_remote_description: Arc::new(Default::default()), + }; + + // Create the ice gatherer + pc.ice_gatherer = Arc::new(api.new_ice_gatherer(RTCIceGatherOptions { + ice_servers: configuration.get_ice_servers(), + ice_gather_policy: configuration.ice_transport_policy, + })?); + + // Create the ice transport + pc.ice_transport = pc.create_ice_transport(api).await; + + // Create the DTLS transport + let certificates = configuration.certificates.drain(..).collect(); + pc.dtls_transport = + Arc::new(api.new_dtls_transport(Arc::clone(&pc.ice_transport), certificates)?); + + // Create the SCTP transport + pc.sctp_transport = Arc::new(api.new_sctp_transport(Arc::clone(&pc.dtls_transport))?); + + // Wire up the on datachannel handler + let on_data_channel_handler = Arc::clone(&pc.on_data_channel_handler); + pc.sctp_transport + .on_data_channel(Box::new(move |d: Arc| { + let on_data_channel_handler2 = Arc::clone(&on_data_channel_handler); + Box::pin(async move { + if let Some(handler) = &*on_data_channel_handler2.load() { + let mut f = handler.lock().await; + f(d).await; + } + }) + })); + + Ok((Arc::new(pc), configuration)) + } + + pub(super) async fn start_rtp( + self: &Arc, + is_renegotiation: bool, + remote_desc: Arc, + ) -> Result<()> { + let mut track_details = if let Some(parsed) = &remote_desc.parsed { + track_details_from_sdp(parsed, false) + } else { + vec![] + }; + + let current_transceivers = { + let current_transceivers = self.rtp_transceivers.lock().await; + current_transceivers.clone() + }; + + if !is_renegotiation { + self.undeclared_media_processor(); + } else { + for t in ¤t_transceivers { + let receiver = t.receiver().await; + let tracks = receiver.tracks().await; + if tracks.is_empty() { + continue; + } + + let mut receiver_needs_stopped = false; + + for t in tracks { + if !t.rid().is_empty() { + if let Some(details) = + track_details_for_rid(&track_details, SmolStr::from(t.rid())) + { + t.set_id(details.id.clone()); + t.set_stream_id(details.stream_id.clone()); + continue; + } + } else if t.ssrc() != 0 { + if let Some(details) = track_details_for_ssrc(&track_details, t.ssrc()) { + t.set_id(details.id.clone()); + t.set_stream_id(details.stream_id.clone()); + continue; + } + } + + receiver_needs_stopped = true; + } + + if !receiver_needs_stopped { + continue; + } + + log::info!("Stopping receiver {:?}", receiver); + if let Err(err) = receiver.stop().await { + log::warn!("Failed to stop RtpReceiver: {}", err); + continue; + } + + let interceptor = self + .interceptor + .upgrade() + .ok_or(Error::ErrInterceptorNotBind)?; + + let receiver = Arc::new(RTCRtpReceiver::new( + self.setting_engine.get_receive_mtu(), + receiver.kind(), + Arc::clone(&self.dtls_transport), + Arc::clone(&self.media_engine), + interceptor, + )); + t.set_receiver(receiver).await; + } + } + + self.start_rtp_receivers(&mut track_details, ¤t_transceivers) + .await?; + if let Some(parsed) = &remote_desc.parsed { + if have_application_media_section(parsed) { + self.start_sctp().await; + } + } + + Ok(()) + } + + /// undeclared_media_processor handles RTP/RTCP packets that don't match any a:ssrc lines + fn undeclared_media_processor(self: &Arc) { + let dtls_transport = Arc::clone(&self.dtls_transport); + let is_closed = Arc::clone(&self.is_closed); + let pci = Arc::clone(self); + + // SRTP acceptor + tokio::spawn(async move { + let simulcast_routine_count = Arc::new(AtomicU64::new(0)); + loop { + let srtp_session = match dtls_transport.get_srtp_session().await { + Some(s) => s, + None => { + log::warn!("undeclared_media_processor failed to open SrtpSession"); + return; + } + }; + + let stream = match srtp_session.accept().await { + Ok(stream) => stream, + Err(err) => { + log::warn!("Failed to accept RTP {}", err); + return; + } + }; + + if is_closed.load(Ordering::SeqCst) { + if let Err(err) = stream.close().await { + log::warn!("Failed to close RTP stream {}", err); + } + continue; + } + + if simulcast_routine_count.fetch_add(1, Ordering::SeqCst) + 1 + >= SIMULCAST_MAX_PROBE_ROUTINES + { + simulcast_routine_count.fetch_sub(1, Ordering::SeqCst); + log::warn!("{:?}", Error::ErrSimulcastProbeOverflow); + continue; + } + + { + let dtls_transport = Arc::clone(&dtls_transport); + let simulcast_routine_count = Arc::clone(&simulcast_routine_count); + let pci = Arc::clone(&pci); + tokio::spawn(async move { + let ssrc = stream.get_ssrc(); + + dtls_transport + .store_simulcast_stream(ssrc, Arc::clone(&stream)) + .await; + + if let Err(err) = pci.handle_incoming_ssrc(stream, ssrc).await { + log::error!( + "Incoming unhandled RTP ssrc({}), on_track will not be fired. {}", + ssrc, + err + ); + } + + simulcast_routine_count.fetch_sub(1, Ordering::SeqCst); + }); + } + } + }); + + // SRTCP acceptor + { + let dtls_transport = Arc::clone(&self.dtls_transport); + tokio::spawn(async move { + loop { + let srtcp_session = match dtls_transport.get_srtcp_session().await { + Some(s) => s, + None => { + log::warn!("undeclared_media_processor failed to open SrtcpSession"); + return; + } + }; + + let stream = match srtcp_session.accept().await { + Ok(stream) => stream, + Err(err) => { + log::warn!("Failed to accept RTCP {}", err); + return; + } + }; + log::warn!( + "Incoming unhandled RTCP ssrc({}), on_track will not be fired", + stream.get_ssrc() + ); + } + }); + } + } + + /// start_rtp_receivers opens knows inbound SRTP streams from the remote_description + async fn start_rtp_receivers( + self: &Arc, + incoming_tracks: &mut Vec, + local_transceivers: &[Arc], + ) -> Result<()> { + // Ensure we haven't already started a transceiver for this ssrc + let mut filtered_tracks = incoming_tracks.clone(); + for incoming_track in incoming_tracks { + // If we already have a TrackRemote for a given SSRC don't handle it again + for t in local_transceivers { + let receiver = t.receiver().await; + for track in receiver.tracks().await { + for ssrc in &incoming_track.ssrcs { + if *ssrc == track.ssrc() { + filter_track_with_ssrc(&mut filtered_tracks, track.ssrc()); + } + } + } + } + } + + let mut unhandled_tracks = vec![]; // filtered_tracks[:0] + for incoming_track in filtered_tracks.iter() { + let mut track_handled = false; + for t in local_transceivers { + if t.mid().as_ref() != Some(&incoming_track.mid) { + continue; + } + + if (incoming_track.kind != t.kind()) + || (t.direction() != RTCRtpTransceiverDirection::Recvonly + && t.direction() != RTCRtpTransceiverDirection::Sendrecv) + { + continue; + } + + let receiver = t.receiver().await; + if receiver.have_received().await { + continue; + } + PeerConnectionInternal::start_receiver( + self.setting_engine.get_receive_mtu(), + incoming_track, + receiver, + Arc::clone(t), + Arc::clone(&self.on_track_handler), + ) + .await; + track_handled = true; + } + + if !track_handled { + unhandled_tracks.push(incoming_track); + } + } + + Ok(()) + } + + /// Start SCTP subsystem + async fn start_sctp(&self) { + // Start sctp + if let Err(err) = self + .sctp_transport + .start(SCTPTransportCapabilities { + max_message_size: 0, + }) + .await + { + log::warn!("Failed to start SCTP: {}", err); + if let Err(err) = self.sctp_transport.stop().await { + log::warn!("Failed to stop SCTPTransport: {}", err); + } + + return; + } + + // DataChannels that need to be opened now that SCTP is available + // make a copy we may have incoming DataChannels mutating this while we open + let data_channels = { + let data_channels = self.sctp_transport.data_channels.lock().await; + data_channels.clone() + }; + + let mut opened_dc_count = 0; + for d in data_channels { + if d.ready_state() == RTCDataChannelState::Connecting { + if let Err(err) = d.open(Arc::clone(&self.sctp_transport)).await { + log::warn!("failed to open data channel: {}", err); + continue; + } + opened_dc_count += 1; + } + } + + self.sctp_transport + .data_channels_opened + .fetch_add(opened_dc_count, Ordering::SeqCst); + } + + pub(super) async fn add_transceiver_from_kind( + &self, + kind: RTPCodecType, + init: Option, + ) -> Result> { + if self.is_closed.load(Ordering::SeqCst) { + return Err(Error::ErrConnectionClosed); + } + + let direction = init + .map(|value| value.direction) + .unwrap_or(RTCRtpTransceiverDirection::Sendrecv); + + if direction == RTCRtpTransceiverDirection::Unspecified { + return Err(Error::ErrPeerConnAddTransceiverFromKindSupport); + } + + let interceptor = self + .interceptor + .upgrade() + .ok_or(Error::ErrInterceptorNotBind)?; + let receiver = Arc::new(RTCRtpReceiver::new( + self.setting_engine.get_receive_mtu(), + kind, + Arc::clone(&self.dtls_transport), + Arc::clone(&self.media_engine), + Arc::clone(&interceptor), + )); + + let sender = Arc::new( + RTCRtpSender::new( + self.setting_engine.get_receive_mtu(), + None, + Arc::clone(&self.dtls_transport), + Arc::clone(&self.media_engine), + interceptor, + false, + ) + .await, + ); + + let t = RTCRtpTransceiver::new( + receiver, + sender, + direction, + kind, + vec![], + Arc::clone(&self.media_engine), + Some(Box::new(self.make_negotiation_needed_trigger())), + ) + .await; + + self.add_rtp_transceiver(Arc::clone(&t)).await; + + Ok(t) + } + + pub(super) async fn new_transceiver_from_track( + &self, + direction: RTCRtpTransceiverDirection, + track: Arc, + ) -> Result> { + let interceptor = self + .interceptor + .upgrade() + .ok_or(Error::ErrInterceptorNotBind)?; + + if direction == RTCRtpTransceiverDirection::Unspecified { + return Err(Error::ErrPeerConnAddTransceiverFromTrackSupport); + } + + let r = Arc::new(RTCRtpReceiver::new( + self.setting_engine.get_receive_mtu(), + track.kind(), + Arc::clone(&self.dtls_transport), + Arc::clone(&self.media_engine), + Arc::clone(&interceptor), + )); + + let s = Arc::new( + RTCRtpSender::new( + self.setting_engine.get_receive_mtu(), + Some(Arc::clone(&track)), + Arc::clone(&self.dtls_transport), + Arc::clone(&self.media_engine), + Arc::clone(&interceptor), + false, + ) + .await, + ); + + Ok(RTCRtpTransceiver::new( + r, + s, + direction, + track.kind(), + vec![], + Arc::clone(&self.media_engine), + Some(Box::new(self.make_negotiation_needed_trigger())), + ) + .await) + } + + /// add_rtp_transceiver appends t into rtp_transceivers + /// and fires onNegotiationNeeded; + /// caller of this method should hold `self.mu` lock + pub(super) async fn add_rtp_transceiver(&self, t: Arc) { + { + let mut rtp_transceivers = self.rtp_transceivers.lock().await; + rtp_transceivers.push(t); + } + self.trigger_negotiation_needed().await; + } + + /// Helper to trigger a negotiation needed. + pub(crate) async fn trigger_negotiation_needed(&self) { + RTCPeerConnection::do_negotiation_needed(self.create_negotiation_needed_params()).await; + } + + /// Creates the parameters needed to trigger a negotiation needed. + fn create_negotiation_needed_params(&self) -> NegotiationNeededParams { + NegotiationNeededParams { + on_negotiation_needed_handler: Arc::clone(&self.on_negotiation_needed_handler), + is_closed: Arc::clone(&self.is_closed), + ops: Arc::clone(&self.ops), + negotiation_needed_state: Arc::clone(&self.negotiation_needed_state), + is_negotiation_needed: Arc::clone(&self.is_negotiation_needed), + signaling_state: Arc::clone(&self.signaling_state), + check_negotiation_needed_params: CheckNegotiationNeededParams { + sctp_transport: Arc::clone(&self.sctp_transport), + rtp_transceivers: Arc::clone(&self.rtp_transceivers), + current_local_description: Arc::clone(&self.current_local_description), + current_remote_description: Arc::clone(&self.current_remote_description), + }, + } + } + + pub(crate) fn make_negotiation_needed_trigger( + &self, + ) -> impl Fn() -> Pin + Send + Sync>> + Send + Sync { + let params = self.create_negotiation_needed_params(); + move || { + let params = params.clone(); + Box::pin(async move { + let params = params.clone(); + RTCPeerConnection::do_negotiation_needed(params).await; + }) + } + } + + pub(super) async fn remote_description(&self) -> Option { + let pending_remote_description = self.pending_remote_description.lock().await; + if pending_remote_description.is_some() { + pending_remote_description.clone() + } else { + let current_remote_description = self.current_remote_description.lock().await; + current_remote_description.clone() + } + } + + pub(super) fn set_gather_complete_handler(&self, f: OnGatheringCompleteHdlrFn) { + self.ice_gatherer.on_gathering_complete(f); + } + + /// Start all transports. PeerConnection now has enough state + pub(super) async fn start_transports( + self: &Arc, + ice_role: RTCIceRole, + dtls_role: DTLSRole, + remote_ufrag: String, + remote_pwd: String, + fingerprint: String, + fingerprint_hash: String, + ) { + // Start the ice transport + if let Err(err) = self + .ice_transport + .start( + &RTCIceParameters { + username_fragment: remote_ufrag, + password: remote_pwd, + ice_lite: false, + }, + Some(ice_role), + ) + .await + { + log::warn!("Failed to start manager ice: {}", err); + return; + } + + // Start the dtls_transport transport + let result = self + .dtls_transport + .start(DTLSParameters { + role: dtls_role, + fingerprints: vec![RTCDtlsFingerprint { + algorithm: fingerprint_hash, + value: fingerprint, + }], + }) + .await; + RTCPeerConnection::update_connection_state( + &self.on_peer_connection_state_change_handler, + &self.is_closed, + &self.peer_connection_state, + self.ice_connection_state.load(Ordering::SeqCst).into(), + self.dtls_transport.state(), + ) + .await; + if let Err(err) = result { + log::warn!("Failed to start manager dtls: {}", err); + } + } + + /// generate_unmatched_sdp generates an SDP that doesn't take remote state into account + /// This is used for the initial call for CreateOffer + pub(super) async fn generate_unmatched_sdp( + &self, + local_transceivers: Vec>, + use_identity: bool, + ) -> Result { + let d = SessionDescription::new_jsep_session_description(use_identity); + + let ice_params = self.ice_gatherer.get_local_parameters().await?; + + let candidates = self.ice_gatherer.get_local_candidates().await?; + + let mut media_sections = vec![]; + + for t in &local_transceivers { + if t.stopped.load(Ordering::SeqCst) { + // An "m=" section is generated for each + // RtpTransceiver that has been added to the PeerConnection, excluding + // any stopped RtpTransceivers; + continue; + } + + // TODO: This is dubious because of rollbacks. + t.sender().await.set_negotiated(); + media_sections.push(MediaSection { + id: t.mid().unwrap().to_string(), + transceivers: vec![Arc::clone(t)], + ..Default::default() + }); + } + + if self + .sctp_transport + .data_channels_requested + .load(Ordering::SeqCst) + != 0 + { + media_sections.push(MediaSection { + id: format!("{}", media_sections.len()), + data: true, + ..Default::default() + }); + } + + let dtls_fingerprints = if let Some(cert) = self.dtls_transport.certificates.first() { + cert.get_fingerprints() + } else { + return Err(Error::ErrNonCertificate); + }; + + let params = PopulateSdpParams { + media_description_fingerprint: self.setting_engine.sdp_media_level_fingerprints, + is_icelite: self.setting_engine.candidates.ice_lite, + connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), + ice_gathering_state: self.ice_gathering_state(), + }; + populate_sdp( + d, + &dtls_fingerprints, + &self.media_engine, + &candidates, + &ice_params, + &media_sections, + params, + ) + .await + } + + /// generate_matched_sdp generates a SDP and takes the remote state into account + /// this is used everytime we have a remote_description + pub(super) async fn generate_matched_sdp( + &self, + mut local_transceivers: Vec>, + use_identity: bool, + include_unmatched: bool, + connection_role: ConnectionRole, + ) -> Result { + let d = SessionDescription::new_jsep_session_description(use_identity); + + let ice_params = self.ice_gatherer.get_local_parameters().await?; + let candidates = self.ice_gatherer.get_local_candidates().await?; + + let remote_description = self.remote_description().await; + let mut media_sections = vec![]; + let mut already_have_application_media_section = false; + if let Some(remote_description) = remote_description.as_ref() { + if let Some(parsed) = &remote_description.parsed { + for media in &parsed.media_descriptions { + if let Some(mid_value) = get_mid_value(media) { + if mid_value.is_empty() { + return Err(Error::ErrPeerConnRemoteDescriptionWithoutMidValue); + } + + if media.media_name.media == MEDIA_SECTION_APPLICATION { + media_sections.push(MediaSection { + id: mid_value.to_owned(), + data: true, + ..Default::default() + }); + already_have_application_media_section = true; + continue; + } + + let kind = RTPCodecType::from(media.media_name.media.as_str()); + let direction = get_peer_direction(media); + if kind == RTPCodecType::Unspecified + || direction == RTCRtpTransceiverDirection::Unspecified + { + continue; + } + + if let Some(t) = find_by_mid(mid_value, &mut local_transceivers).await { + t.sender().await.set_negotiated(); + let media_transceivers = vec![t]; + + // NB: The below could use `then_some`, but with our current MSRV + // it's not possible to actually do this. The clippy version that + // ships with 1.64.0 complains about this so we disable it for now. + #[allow(clippy::unnecessary_lazy_evaluations)] + media_sections.push(MediaSection { + id: mid_value.to_owned(), + transceivers: media_transceivers, + rid_map: get_rids(media), + offered_direction: (!include_unmatched).then(|| direction), + ..Default::default() + }); + } else { + return Err(Error::ErrPeerConnTransceiverMidNil); + } + } + } + } + } + + // If we are offering also include unmatched local transceivers + if include_unmatched { + for t in &local_transceivers { + t.sender().await.set_negotiated(); + media_sections.push(MediaSection { + id: t.mid().unwrap().to_string(), + transceivers: vec![Arc::clone(t)], + ..Default::default() + }); + } + + if self + .sctp_transport + .data_channels_requested + .load(Ordering::SeqCst) + != 0 + && !already_have_application_media_section + { + media_sections.push(MediaSection { + id: format!("{}", media_sections.len()), + data: true, + ..Default::default() + }); + } + } + + let dtls_fingerprints = if let Some(cert) = self.dtls_transport.certificates.first() { + cert.get_fingerprints() + } else { + return Err(Error::ErrNonCertificate); + }; + + let params = PopulateSdpParams { + media_description_fingerprint: self.setting_engine.sdp_media_level_fingerprints, + is_icelite: self.setting_engine.candidates.ice_lite, + connection_role, + ice_gathering_state: self.ice_gathering_state(), + }; + populate_sdp( + d, + &dtls_fingerprints, + &self.media_engine, + &candidates, + &ice_params, + &media_sections, + params, + ) + .await + } + + pub(super) fn ice_gathering_state(&self) -> RTCIceGatheringState { + match self.ice_gatherer.state() { + RTCIceGathererState::New => RTCIceGatheringState::New, + RTCIceGathererState::Gathering => RTCIceGatheringState::Gathering, + _ => RTCIceGatheringState::Complete, + } + } + + async fn handle_undeclared_ssrc( + self: &Arc, + ssrc: SSRC, + remote_description: &SessionDescription, + ) -> Result { + if remote_description.media_descriptions.len() != 1 { + return Ok(false); + } + + let only_media_section = &remote_description.media_descriptions[0]; + let mut stream_id = ""; + let mut id = ""; + + for a in &only_media_section.attributes { + match a.key.as_str() { + ATTR_KEY_MSID => { + if let Some(value) = &a.value { + let split: Vec<&str> = value.split(' ').collect(); + if split.len() == 2 { + stream_id = split[0]; + id = split[1]; + } + } + } + ATTR_KEY_SSRC => return Err(Error::ErrPeerConnSingleMediaSectionHasExplicitSSRC), + SDP_ATTRIBUTE_RID => return Ok(false), + _ => {} + }; + } + + let mut incoming = TrackDetails { + ssrcs: vec![ssrc], + kind: RTPCodecType::Video, + stream_id: stream_id.to_owned(), + id: id.to_owned(), + ..Default::default() + }; + if only_media_section.media_name.media == RTPCodecType::Audio.to_string() { + incoming.kind = RTPCodecType::Audio; + } + + let t = self + .add_transceiver_from_kind( + incoming.kind, + Some(RTCRtpTransceiverInit { + direction: RTCRtpTransceiverDirection::Sendrecv, + send_encodings: vec![], + }), + ) + .await?; + + let receiver = t.receiver().await; + PeerConnectionInternal::start_receiver( + self.setting_engine.get_receive_mtu(), + &incoming, + receiver, + t, + Arc::clone(&self.on_track_handler), + ) + .await; + Ok(true) + } + + async fn handle_incoming_ssrc( + self: &Arc, + rtp_stream: Arc, + ssrc: SSRC, + ) -> Result<()> { + let parsed = match self.remote_description().await.and_then(|rd| rd.parsed) { + Some(r) => r, + None => return Err(Error::ErrPeerConnRemoteDescriptionNil), + }; + // If the remote SDP was only one media section the ssrc doesn't have to be explicitly declared + let handled = self.handle_undeclared_ssrc(ssrc, &parsed).await?; + if handled { + return Ok(()); + } + + // Get MID extension ID + let (mid_extension_id, audio_supported, video_supported) = self + .media_engine + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: ::sdp::extmap::SDES_MID_URI.to_owned(), + }) + .await; + if !audio_supported && !video_supported { + return Err(Error::ErrPeerConnSimulcastMidRTPExtensionRequired); + } + + // Get RID extension ID + let (sid_extension_id, audio_supported, video_supported) = self + .media_engine + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: ::sdp::extmap::SDES_RTP_STREAM_ID_URI.to_owned(), + }) + .await; + if !audio_supported && !video_supported { + return Err(Error::ErrPeerConnSimulcastStreamIDRTPExtensionRequired); + } + + let (rsid_extension_id, _, _) = self + .media_engine + .get_header_extension_id(RTCRtpHeaderExtensionCapability { + uri: SDES_REPAIR_RTP_STREAM_ID_URI.to_owned(), + }) + .await; + + let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()]; + // Packets that we read as part of simulcast probing that we need to make available + // if we do find a track later. + let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default(); + + let n = rtp_stream.read(&mut buf).await?; + + let (mut mid, mut rid, mut rsid, payload_type) = handle_unknown_rtp_packet( + &buf[..n], + mid_extension_id as u8, + sid_extension_id as u8, + rsid_extension_id as u8, + )?; + + let packet = rtp::packet::Packet::unmarshal(&mut buf.as_slice()).unwrap(); + + // TODO: Can we have attributes on the first packets? + buffered_packets.push_back((packet, Attributes::new())); + + let params = self + .media_engine + .get_rtp_parameters_by_payload_type(payload_type) + .await?; + + let icpr = match self.interceptor.upgrade() { + Some(i) => i, + None => return Err(Error::ErrInterceptorNotBind), + }; + + let stream_info = create_stream_info( + "".to_owned(), + ssrc, + params.codecs[0].payload_type, + params.codecs[0].capability.clone(), + ¶ms.header_extensions, + ); + let (rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = self + .dtls_transport + .streams_for_ssrc(ssrc, &stream_info, &icpr) + .await?; + + let a = Attributes::new(); + for _ in 0..=SIMULCAST_PROBE_COUNT { + if mid.is_empty() || (rid.is_empty() && rsid.is_empty()) { + let (pkt, _) = rtp_interceptor.read(&mut buf, &a).await?; + let (m, r, rs, _) = handle_unknown_rtp_packet( + &buf[..n], + mid_extension_id as u8, + sid_extension_id as u8, + rsid_extension_id as u8, + )?; + mid = m; + rid = r; + rsid = rs; + + buffered_packets.push_back((pkt, a.clone())); + continue; + } + + let transceivers = self.rtp_transceivers.lock().await; + for t in &*transceivers { + if t.mid().as_ref() != Some(&SmolStr::from(&mid)) { + continue; + } + + let receiver = t.receiver().await; + + if !rsid.is_empty() { + return receiver + .receive_for_rtx( + 0, + rsid, + TrackStream { + stream_info: Some(stream_info.clone()), + rtp_read_stream: Some(rtp_read_stream), + rtp_interceptor: Some(rtp_interceptor), + rtcp_read_stream: Some(rtcp_read_stream), + rtcp_interceptor: Some(rtcp_interceptor), + }, + ) + .await; + } + + let track = receiver + .receive_for_rid( + SmolStr::from(rid), + params, + TrackStream { + stream_info: Some(stream_info.clone()), + rtp_read_stream: Some(rtp_read_stream), + rtp_interceptor: Some(rtp_interceptor), + rtcp_read_stream: Some(rtcp_read_stream), + rtcp_interceptor: Some(rtcp_interceptor), + }, + ) + .await?; + track.prepopulate_peeked_data(buffered_packets).await; + + RTCPeerConnection::do_track( + Arc::clone(&self.on_track_handler), + track, + receiver, + Arc::clone(t), + ); + return Ok(()); + } + } + + let _ = rtp_read_stream.close().await; + let _ = rtcp_read_stream.close().await; + icpr.unbind_remote_stream(&stream_info).await; + self.dtls_transport.remove_simulcast_stream(ssrc).await; + + Err(Error::ErrPeerConnSimulcastIncomingSSRCFailed) + } + + async fn start_receiver( + receive_mtu: usize, + incoming: &TrackDetails, + receiver: Arc, + transceiver: Arc, + on_track_handler: Arc>>, + ) { + receiver.start(incoming).await; + for t in receiver.tracks().await { + if t.ssrc() == 0 { + return; + } + + let receiver = Arc::clone(&receiver); + let transceiver = Arc::clone(&transceiver); + let on_track_handler = Arc::clone(&on_track_handler); + tokio::spawn(async move { + if let Some(track) = receiver.track().await { + let mut b = vec![0u8; receive_mtu]; + let pkt = match track.peek(&mut b).await { + Ok((pkt, _)) => pkt, + Err(err) => { + log::warn!( + "Could not determine PayloadType for SSRC {} ({})", + track.ssrc(), + err + ); + return; + } + }; + + if let Err(err) = track.check_and_update_track(&pkt).await { + log::warn!( + "Failed to set codec settings for track SSRC {} ({})", + track.ssrc(), + err + ); + return; + } + + RTCPeerConnection::do_track(on_track_handler, track, receiver, transceiver); + } + }); + } + } + + pub(super) async fn create_ice_transport(&self, api: &API) -> Arc { + let ice_transport = Arc::new(api.new_ice_transport(Arc::clone(&self.ice_gatherer))); + + let ice_connection_state = Arc::clone(&self.ice_connection_state); + let peer_connection_state = Arc::clone(&self.peer_connection_state); + let is_closed = Arc::clone(&self.is_closed); + let dtls_transport = Arc::clone(&self.dtls_transport); + let on_ice_connection_state_change_handler = + Arc::clone(&self.on_ice_connection_state_change_handler); + let on_peer_connection_state_change_handler = + Arc::clone(&self.on_peer_connection_state_change_handler); + + ice_transport.on_connection_state_change(Box::new(move |state: RTCIceTransportState| { + let cs = match state { + RTCIceTransportState::New => RTCIceConnectionState::New, + RTCIceTransportState::Checking => RTCIceConnectionState::Checking, + RTCIceTransportState::Connected => RTCIceConnectionState::Connected, + RTCIceTransportState::Completed => RTCIceConnectionState::Completed, + RTCIceTransportState::Failed => RTCIceConnectionState::Failed, + RTCIceTransportState::Disconnected => RTCIceConnectionState::Disconnected, + RTCIceTransportState::Closed => RTCIceConnectionState::Closed, + _ => { + log::warn!("on_connection_state_change: unhandled ICE state: {}", state); + return Box::pin(async {}); + } + }; + + let ice_connection_state2 = Arc::clone(&ice_connection_state); + let on_ice_connection_state_change_handler2 = + Arc::clone(&on_ice_connection_state_change_handler); + let on_peer_connection_state_change_handler2 = + Arc::clone(&on_peer_connection_state_change_handler); + let is_closed2 = Arc::clone(&is_closed); + let dtls_transport_state = dtls_transport.state(); + let peer_connection_state2 = Arc::clone(&peer_connection_state); + Box::pin(async move { + RTCPeerConnection::do_ice_connection_state_change( + &on_ice_connection_state_change_handler2, + &ice_connection_state2, + cs, + ) + .await; + + RTCPeerConnection::update_connection_state( + &on_peer_connection_state_change_handler2, + &is_closed2, + &peer_connection_state2, + cs, + dtls_transport_state, + ) + .await; + }) + })); + + ice_transport + } + + /// has_local_description_changed returns whether local media (rtp_transceivers) has changed + /// caller of this method should hold `pc.mu` lock + pub(super) async fn has_local_description_changed(&self, desc: &RTCSessionDescription) -> bool { + let rtp_transceivers = self.rtp_transceivers.lock().await; + for t in &*rtp_transceivers { + let m = match t.mid().and_then(|mid| get_by_mid(mid.as_str(), desc)) { + Some(m) => m, + None => return true, + }; + + if get_peer_direction(m) != t.direction() { + return true; + } + } + false + } + + pub(super) async fn get_stats(&self, stats_id: String) -> StatsCollector { + let collector = StatsCollector::new(); + let transceivers = { self.rtp_transceivers.lock().await.clone() }; + + tokio::join!( + self.ice_gatherer.collect_stats(&collector), + self.ice_transport.collect_stats(&collector), + self.sctp_transport.collect_stats(&collector, stats_id), + self.dtls_transport.collect_stats(&collector), + self.media_engine.collect_stats(&collector), + self.collect_inbound_stats(&collector, transceivers.clone()), + self.collect_outbound_stats(&collector, transceivers) + ); + + collector + } + + async fn collect_inbound_stats( + &self, + collector: &StatsCollector, + transceivers: Vec>, + ) { + // TODO: There's a lot of await points here that could run concurrently with `futures::join_all`. + struct TrackInfo { + ssrc: SSRC, + mid: SmolStr, + track_id: String, + kind: &'static str, + } + let mut track_infos = vec![]; + for transeiver in transceivers { + let receiver = transeiver.receiver().await; + + if let Some(mid) = transeiver.mid() { + let tracks = receiver.tracks().await; + + for track in tracks { + let track_id = track.id(); + let kind = match track.kind() { + RTPCodecType::Unspecified => continue, + RTPCodecType::Audio => "audio", + RTPCodecType::Video => "video", + }; + + track_infos.push(TrackInfo { + ssrc: track.ssrc(), + mid: mid.clone(), + track_id, + kind, + }); + } + } + } + + let stream_stats = self + .stats_interceptor + .fetch_inbound_stats(track_infos.iter().map(|t| t.ssrc).collect()) + .await; + + for (stats, info) in + (stream_stats.into_iter().zip(track_infos)).filter_map(|(s, i)| s.map(|s| (s, i))) + { + let ssrc = info.ssrc; + let kind = info.kind; + + let id = format!("RTCInboundRTP{}Stream_{}", capitalize(kind), ssrc); + let ( + packets_received, + header_bytes_received, + bytes_received, + last_packet_received_timestamp, + nack_count, + remote_packets_sent, + remote_bytes_sent, + remote_reports_sent, + remote_round_trip_time, + remote_total_round_trip_time, + remote_round_trip_time_measurements, + ) = ( + stats.packets_received(), + stats.header_bytes_received(), + stats.payload_bytes_received(), + stats.last_packet_received_timestamp(), + stats.nacks_sent(), + stats.remote_packets_sent(), + stats.remote_bytes_sent(), + stats.remote_reports_sent(), + stats.remote_round_trip_time(), + stats.remote_total_round_trip_time(), + stats.remote_round_trip_time_measurements(), + ); + + collector.insert( + id.clone(), + crate::stats::StatsReportType::InboundRTP(InboundRTPStats { + timestamp: Instant::now(), + stats_type: RTCStatsType::InboundRTP, + id: id.clone(), + ssrc, + kind, + packets_received, + track_identifier: info.track_id, + mid: info.mid, + last_packet_received_timestamp, + header_bytes_received, + bytes_received, + nack_count, + + fir_count: (info.kind == "video").then(|| stats.firs_sent()), + pli_count: (info.kind == "video").then(|| stats.plis_sent()), + }), + ); + + let local_id = id; + let id = format!( + "RTCRemoteOutboundRTP{}Stream_{}", + capitalize(info.kind), + info.ssrc + ); + collector.insert( + id.clone(), + crate::stats::StatsReportType::RemoteOutboundRTP(RemoteOutboundRTPStats { + timestamp: Instant::now(), + stats_type: RTCStatsType::RemoteOutboundRTP, + id, + + ssrc, + kind, + + packets_sent: remote_packets_sent as u64, + bytes_sent: remote_bytes_sent as u64, + local_id, + reports_sent: remote_reports_sent, + round_trip_time: remote_round_trip_time, + total_round_trip_time: remote_total_round_trip_time, + round_trip_time_measurements: remote_round_trip_time_measurements, + }), + ); + } + } + + async fn collect_outbound_stats( + &self, + collector: &StatsCollector, + transceivers: Vec>, + ) { + // TODO: There's a lot of await points here that could run concurrently with `futures::join_all`. + struct TrackInfo { + track_id: String, + ssrc: SSRC, + mid: SmolStr, + rid: Option, + kind: &'static str, + } + let mut track_infos = vec![]; + for transceiver in transceivers { + let sender = transceiver.sender().await; + + let mid = match transceiver.mid() { + Some(mid) => mid, + None => continue, + }; + + let track = match sender.track().await { + Some(track) => track, + None => continue, + }; + + let track_id = track.id().to_string(); + let kind = match track.kind() { + RTPCodecType::Unspecified => continue, + RTPCodecType::Audio => "audio", + RTPCodecType::Video => "video", + }; + + track_infos.push(TrackInfo { + track_id, + ssrc: sender.ssrc, + mid, + rid: None, + kind, + }); + } + + let stream_stats = self + .stats_interceptor + .fetch_outbound_stats(track_infos.iter().map(|t| t.ssrc).collect()) + .await; + + for (stats, info) in stream_stats + .into_iter() + .zip(track_infos) + .filter_map(|(s, i)| s.map(|s| (s, i))) + { + // RTCOutboundRtpStreamStats + let id = format!( + "RTCOutboundRTP{}Stream_{}", + capitalize(info.kind), + info.ssrc + ); + let ( + packets_sent, + bytes_sent, + header_bytes_sent, + nack_count, + remote_inbound_packets_received, + remote_inbound_packets_lost, + remote_rtt_ms, + remote_total_rtt_ms, + remote_rtt_measurements, + remote_fraction_lost, + ) = ( + stats.packets_sent(), + stats.payload_bytes_sent(), + stats.header_bytes_sent(), + stats.nacks_received(), + stats.remote_packets_received(), + stats.remote_total_lost(), + stats.remote_round_trip_time(), + stats.remote_total_round_trip_time(), + stats.remote_round_trip_time_measurements(), + stats.remote_fraction_lost(), + ); + + let TrackInfo { + mid, + ssrc, + rid, + kind, + track_id: track_identifier, + } = info; + + collector.insert( + id.clone(), + crate::stats::StatsReportType::OutboundRTP(OutboundRTPStats { + timestamp: Instant::now(), + stats_type: RTCStatsType::OutboundRTP, + track_identifier, + id: id.clone(), + ssrc, + kind, + packets_sent, + mid, + rid, + header_bytes_sent, + bytes_sent, + nack_count, + + fir_count: (info.kind == "video").then(|| stats.firs_received()), + pli_count: (info.kind == "video").then(|| stats.plis_received()), + }), + ); + + let local_id = id; + let id = format!( + "RTCRemoteInboundRTP{}Stream_{}", + capitalize(info.kind), + info.ssrc + ); + + collector.insert( + id.clone(), + StatsReportType::RemoteInboundRTP(RemoteInboundRTPStats { + timestamp: Instant::now(), + stats_type: RTCStatsType::RemoteInboundRTP, + id, + ssrc, + kind, + + packets_received: remote_inbound_packets_received, + packets_lost: remote_inbound_packets_lost as i64, + + local_id, + + round_trip_time: remote_rtt_ms, + total_round_trip_time: remote_total_rtt_ms, + fraction_lost: remote_fraction_lost.unwrap_or(0.0), + round_trip_time_measurements: remote_rtt_measurements, + }), + ); + } + } +} + +type IResult = std::result::Result; + +#[async_trait] +impl RTCPWriter for PeerConnectionInternal { + async fn write( + &self, + pkts: &[Box], + _a: &Attributes, + ) -> IResult { + Ok(self.dtls_transport.write_rtcp(pkts).await?) + } +} + +fn capitalize(s: &str) -> String { + let first = s + .chars() + .next() + .expect("Must have at least one character to uppercase") + .to_uppercase(); + let mut result = String::new(); + + result.extend(first); + result.extend(s.chars().skip(1)); + + result +} diff --git a/rtc/src/peer_connection/peer_connection_state.rs b/rtc/src/peer_connection/peer_connection_state.rs new file mode 100644 index 0000000..905e267 --- /dev/null +++ b/rtc/src/peer_connection/peer_connection_state.rs @@ -0,0 +1,151 @@ +use std::fmt; + +/// PeerConnectionState indicates the state of the PeerConnection. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCPeerConnectionState { + #[default] + Unspecified, + + /// PeerConnectionStateNew indicates that any of the ICETransports or + /// DTLSTransports are in the "new" state and none of the transports are + /// in the "connecting", "checking", "failed" or "disconnected" state, or + /// all transports are in the "closed" state, or there are no transports. + New, + + /// PeerConnectionStateConnecting indicates that any of the + /// ICETransports or DTLSTransports are in the "connecting" or + /// "checking" state and none of them is in the "failed" state. + Connecting, + + /// PeerConnectionStateConnected indicates that all ICETransports and + /// DTLSTransports are in the "connected", "completed" or "closed" state + /// and at least one of them is in the "connected" or "completed" state. + Connected, + + /// PeerConnectionStateDisconnected indicates that any of the + /// ICETransports or DTLSTransports are in the "disconnected" state + /// and none of them are in the "failed" or "connecting" or "checking" state. + Disconnected, + + /// PeerConnectionStateFailed indicates that any of the ICETransports + /// or DTLSTransports are in a "failed" state. + Failed, + + /// PeerConnectionStateClosed indicates the peer connection is closed + /// and the isClosed member variable of PeerConnection is true. + Closed, +} + +const PEER_CONNECTION_STATE_NEW_STR: &str = "new"; +const PEER_CONNECTION_STATE_CONNECTING_STR: &str = "connecting"; +const PEER_CONNECTION_STATE_CONNECTED_STR: &str = "connected"; +const PEER_CONNECTION_STATE_DISCONNECTED_STR: &str = "disconnected"; +const PEER_CONNECTION_STATE_FAILED_STR: &str = "failed"; +const PEER_CONNECTION_STATE_CLOSED_STR: &str = "closed"; + +impl From<&str> for RTCPeerConnectionState { + fn from(raw: &str) -> Self { + match raw { + PEER_CONNECTION_STATE_NEW_STR => RTCPeerConnectionState::New, + PEER_CONNECTION_STATE_CONNECTING_STR => RTCPeerConnectionState::Connecting, + PEER_CONNECTION_STATE_CONNECTED_STR => RTCPeerConnectionState::Connected, + PEER_CONNECTION_STATE_DISCONNECTED_STR => RTCPeerConnectionState::Disconnected, + PEER_CONNECTION_STATE_FAILED_STR => RTCPeerConnectionState::Failed, + PEER_CONNECTION_STATE_CLOSED_STR => RTCPeerConnectionState::Closed, + _ => RTCPeerConnectionState::Unspecified, + } + } +} + +impl From for RTCPeerConnectionState { + fn from(v: u8) -> Self { + match v { + 1 => RTCPeerConnectionState::New, + 2 => RTCPeerConnectionState::Connecting, + 3 => RTCPeerConnectionState::Connected, + 4 => RTCPeerConnectionState::Disconnected, + 5 => RTCPeerConnectionState::Failed, + 6 => RTCPeerConnectionState::Closed, + _ => RTCPeerConnectionState::Unspecified, + } + } +} + +impl fmt::Display for RTCPeerConnectionState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCPeerConnectionState::New => PEER_CONNECTION_STATE_NEW_STR, + RTCPeerConnectionState::Connecting => PEER_CONNECTION_STATE_CONNECTING_STR, + RTCPeerConnectionState::Connected => PEER_CONNECTION_STATE_CONNECTED_STR, + RTCPeerConnectionState::Disconnected => PEER_CONNECTION_STATE_DISCONNECTED_STR, + RTCPeerConnectionState::Failed => PEER_CONNECTION_STATE_FAILED_STR, + RTCPeerConnectionState::Closed => PEER_CONNECTION_STATE_CLOSED_STR, + RTCPeerConnectionState::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[derive(Default, Debug, Copy, Clone, PartialEq)] +pub(crate) enum NegotiationNeededState { + /// NegotiationNeededStateEmpty not running and queue is empty + #[default] + Empty, + /// NegotiationNeededStateEmpty running and queue is empty + Run, + /// NegotiationNeededStateEmpty running and queue + Queue, +} + +impl From for NegotiationNeededState { + fn from(v: u8) -> Self { + match v { + 1 => NegotiationNeededState::Run, + 2 => NegotiationNeededState::Queue, + _ => NegotiationNeededState::Empty, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_peer_connection_state() { + let tests = vec![ + (crate::UNSPECIFIED_STR, RTCPeerConnectionState::Unspecified), + ("new", RTCPeerConnectionState::New), + ("connecting", RTCPeerConnectionState::Connecting), + ("connected", RTCPeerConnectionState::Connected), + ("disconnected", RTCPeerConnectionState::Disconnected), + ("failed", RTCPeerConnectionState::Failed), + ("closed", RTCPeerConnectionState::Closed), + ]; + + for (state_string, expected_state) in tests { + assert_eq!( + RTCPeerConnectionState::from(state_string), + expected_state, + "testCase: {expected_state}", + ); + } + } + + #[test] + fn test_peer_connection_state_string() { + let tests = vec![ + (RTCPeerConnectionState::Unspecified, crate::UNSPECIFIED_STR), + (RTCPeerConnectionState::New, "new"), + (RTCPeerConnectionState::Connecting, "connecting"), + (RTCPeerConnectionState::Connected, "connected"), + (RTCPeerConnectionState::Disconnected, "disconnected"), + (RTCPeerConnectionState::Failed, "failed"), + (RTCPeerConnectionState::Closed, "closed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string) + } + } +} diff --git a/rtc/src/peer_connection/peer_connection_test.rs b/rtc/src/peer_connection/peer_connection_test.rs new file mode 100644 index 0000000..6073054 --- /dev/null +++ b/rtc/src/peer_connection/peer_connection_test.rs @@ -0,0 +1,423 @@ +use std::sync::atomic::AtomicU32; +use std::sync::Arc; + +use bytes::Bytes; +use interceptor::registry::Registry; +use media::Sample; +use tokio::time::Duration; +use util::vnet::net::{Net, NetConfig}; +use util::vnet::router::{Router, RouterConfig}; +use waitgroup::WaitGroup; + +use super::*; +use crate::api::interceptor_registry::register_default_interceptors; +use crate::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; +use crate::api::APIBuilder; +use crate::ice_transport::ice_candidate_pair::RTCIceCandidatePair; +use crate::ice_transport::ice_server::RTCIceServer; +use crate::peer_connection::configuration::RTCConfiguration; +use crate::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; +use crate::stats::StatsReportType; +use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; +use crate::Error; + +pub(crate) async fn create_vnet_pair( +) -> Result<(RTCPeerConnection, RTCPeerConnection, Arc>)> { + // Create a root router + let wan = Arc::new(Mutex::new(Router::new(RouterConfig { + cidr: "1.2.3.0/24".to_owned(), + ..Default::default() + })?)); + + // Create a network interface for offerer + let offer_vnet = Arc::new(Net::new(Some(NetConfig { + static_ips: vec!["1.2.3.4".to_owned()], + ..Default::default() + }))); + + // Add the network interface to the router + let nic = offer_vnet.get_nic()?; + { + let mut w = wan.lock().await; + w.add_net(Arc::clone(&nic)).await?; + } + { + let n = nic.lock().await; + n.set_router(Arc::clone(&wan)).await?; + } + + let mut offer_setting_engine = SettingEngine::default(); + offer_setting_engine.set_vnet(Some(offer_vnet)); + offer_setting_engine.set_ice_timeouts( + Some(Duration::from_secs(1)), + Some(Duration::from_secs(1)), + Some(Duration::from_millis(200)), + ); + + // Create a network interface for answerer + let answer_vnet = Arc::new(Net::new(Some(NetConfig { + static_ips: vec!["1.2.3.5".to_owned()], + ..Default::default() + }))); + + // Add the network interface to the router + let nic = answer_vnet.get_nic()?; + { + let mut w = wan.lock().await; + w.add_net(Arc::clone(&nic)).await?; + } + { + let n = nic.lock().await; + n.set_router(Arc::clone(&wan)).await?; + } + + let mut answer_setting_engine = SettingEngine::default(); + answer_setting_engine.set_vnet(Some(answer_vnet)); + answer_setting_engine.set_ice_timeouts( + Some(Duration::from_secs(1)), + Some(Duration::from_secs(1)), + Some(Duration::from_millis(200)), + ); + + // Start the virtual network by calling Start() on the root router + { + let mut w = wan.lock().await; + w.start().await?; + } + + let mut offer_media_engine = MediaEngine::default(); + offer_media_engine.register_default_codecs()?; + let offer_peer_connection = APIBuilder::new() + .with_setting_engine(offer_setting_engine) + .with_media_engine(offer_media_engine) + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + + let mut answer_media_engine = MediaEngine::default(); + answer_media_engine.register_default_codecs()?; + let answer_peer_connection = APIBuilder::new() + .with_setting_engine(answer_setting_engine) + .with_media_engine(answer_media_engine) + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + + Ok((offer_peer_connection, answer_peer_connection, wan)) +} + +/// new_pair creates two new peer connections (an offerer and an answerer) +/// *without* using an api (i.e. using the default settings). +pub(crate) async fn new_pair(api: &API) -> Result<(RTCPeerConnection, RTCPeerConnection)> { + let pca = api.new_peer_connection(RTCConfiguration::default()).await?; + let pcb = api.new_peer_connection(RTCConfiguration::default()).await?; + + Ok((pca, pcb)) +} + +pub(crate) async fn signal_pair( + pc_offer: &mut RTCPeerConnection, + pc_answer: &mut RTCPeerConnection, +) -> Result<()> { + // Note(albrow): We need to create a data channel in order to trigger ICE + // candidate gathering in the background for the JavaScript/Wasm bindings. If + // we don't do this, the complete offer including ICE candidates will never be + // generated. + pc_offer + .create_data_channel("initial_data_channel", None) + .await?; + + let offer = pc_offer.create_offer(None).await?; + + let mut offer_gathering_complete = pc_offer.gathering_complete_promise().await; + pc_offer.set_local_description(offer).await?; + + let _ = offer_gathering_complete.recv().await; + + pc_answer + .set_remote_description( + pc_offer + .local_description() + .await + .ok_or(Error::new("non local description".to_owned()))?, + ) + .await?; + + let answer = pc_answer.create_answer(None).await?; + + let mut answer_gathering_complete = pc_answer.gathering_complete_promise().await; + pc_answer.set_local_description(answer).await?; + + let _ = answer_gathering_complete.recv().await; + + pc_offer + .set_remote_description( + pc_answer + .local_description() + .await + .ok_or(Error::new("non local description".to_owned()))?, + ) + .await +} + +pub(crate) async fn close_pair_now(pc1: &RTCPeerConnection, pc2: &RTCPeerConnection) { + let mut fail = false; + if let Err(err) = pc1.close().await { + log::error!("Failed to close PeerConnection: {}", err); + fail = true; + } + if let Err(err) = pc2.close().await { + log::error!("Failed to close PeerConnection: {}", err); + fail = true; + } + + assert!(!fail); +} + +pub(crate) async fn close_pair( + pc1: &RTCPeerConnection, + pc2: &RTCPeerConnection, + mut done_rx: mpsc::Receiver<()>, +) { + let timeout = tokio::time::sleep(Duration::from_secs(10)); + tokio::pin!(timeout); + + tokio::select! { + _ = timeout.as_mut() =>{ + panic!("close_pair timed out waiting for done signal"); + } + _ = done_rx.recv() =>{ + close_pair_now(pc1, pc2).await; + } + } +} + +/* +func offerMediaHasDirection(offer SessionDescription, kind RTPCodecType, direction RTPTransceiverDirection) bool { + parsed := &sdp.SessionDescription{} + if err := parsed.Unmarshal([]byte(offer.SDP)); err != nil { + return false + } + + for _, media := range parsed.MediaDescriptions { + if media.MediaName.Media == kind.String() { + _, exists := media.Attribute(direction.String()) + return exists + } + } + return false +}*/ + +pub(crate) async fn send_video_until_done( + mut done_rx: mpsc::Receiver<()>, + tracks: Vec>, + data: Bytes, + max_sends: Option, +) -> bool { + let mut sends = 0; + + loop { + let timeout = tokio::time::sleep(Duration::from_millis(20)); + tokio::pin!(timeout); + + tokio::select! { + biased; + + _ = done_rx.recv() =>{ + log::debug!("sendVideoUntilDone received done"); + return false; + } + + _ = timeout.as_mut() =>{ + if max_sends.map(|s| sends >= s).unwrap_or(false) { + continue; + } + + log::debug!("sendVideoUntilDone timeout"); + for track in &tracks { + log::debug!("sendVideoUntilDone track.WriteSample"); + let result = track.write_sample(&Sample{ + data: data.clone(), + duration: Duration::from_secs(1), + ..Default::default() + }).await; + assert!(result.is_ok()); + sends += 1; + } + } + } + } +} + +pub(crate) async fn until_connection_state( + pc: &mut RTCPeerConnection, + wg: &WaitGroup, + state: RTCPeerConnectionState, +) { + let w = Arc::new(Mutex::new(Some(wg.worker()))); + pc.on_peer_connection_state_change(Box::new(move |pcs: RTCPeerConnectionState| { + let w2 = Arc::clone(&w); + Box::pin(async move { + if pcs == state { + let mut worker = w2.lock().await; + worker.take(); + } + }) + })); +} + +#[tokio::test] +async fn test_get_stats() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; + + let (ice_complete_tx, mut ice_complete_rx) = mpsc::channel::<()>(1); + let ice_complete_tx = Arc::new(Mutex::new(Some(ice_complete_tx))); + pc_answer.on_ice_connection_state_change(Box::new(move |ice_state: RTCIceConnectionState| { + let ice_complete_tx2 = Arc::clone(&ice_complete_tx); + Box::pin(async move { + if ice_state == RTCIceConnectionState::Connected { + tokio::time::sleep(Duration::from_secs(1)).await; + let mut done = ice_complete_tx2.lock().await; + done.take(); + } + }) + })); + + let sender_called_candidate_change = Arc::new(AtomicU32::new(0)); + let sender_called_candidate_change2 = Arc::clone(&sender_called_candidate_change); + pc_offer + .sctp() + .transport() + .ice_transport() + .on_selected_candidate_pair_change(Box::new(move |_: RTCIceCandidatePair| { + sender_called_candidate_change2.store(1, Ordering::SeqCst); + Box::pin(async {}) + })); + let track = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + pc_offer + .add_track(track.clone()) + .await + .expect("Failed to add track"); + let (packet_tx, packet_rx) = mpsc::channel(1); + + pc_answer.on_track(Box::new(move |track, _, _| { + let packet_tx = packet_tx.clone(); + tokio::spawn(async move { + while let Ok((pkt, _)) = track.read_rtp().await { + dbg!(&pkt); + let last = pkt.payload[pkt.payload.len() - 1]; + + if last == 0xAA { + let _ = packet_tx.send(()).await; + break; + } + } + }); + + Box::pin(async move {}) + })); + + signal_pair(&mut pc_offer, &mut pc_answer).await?; + + let _ = ice_complete_rx.recv().await; + send_video_until_done( + packet_rx, + vec![track], + Bytes::from_static(b"\xDE\xAD\xBE\xEF\xAA"), + Some(1), + ) + .await; + + let offer_stats = pc_offer.get_stats().await; + assert!(!offer_stats.reports.is_empty()); + + match offer_stats.reports.get("ice_transport") { + Some(StatsReportType::Transport(ice_transport_stats)) => { + assert!(ice_transport_stats.bytes_received > 0); + assert!(ice_transport_stats.bytes_sent > 0); + } + Some(_other) => panic!("found the wrong type"), + None => panic!("missed it"), + } + let outbound_stats = offer_stats + .reports + .values() + .find_map(|v| match v { + StatsReportType::OutboundRTP(d) => Some(d), + _ => None, + }) + .expect("Should have produced an RTP Outbound stat"); + assert_eq!(outbound_stats.packets_sent, 1); + assert_eq!(outbound_stats.kind, "video"); + assert_eq!(outbound_stats.bytes_sent, 8); + assert_eq!(outbound_stats.header_bytes_sent, 12); + + let answer_stats = pc_answer.get_stats().await; + let inbound_stats = answer_stats + .reports + .values() + .find_map(|v| match v { + StatsReportType::InboundRTP(d) => Some(d), + _ => None, + }) + .expect("Should have produced an RTP inbound stat"); + assert_eq!(inbound_stats.packets_received, 1); + assert_eq!(inbound_stats.kind, "video"); + assert_eq!(inbound_stats.bytes_received, 8); + assert_eq!(inbound_stats.header_bytes_received, 12); + + close_pair_now(&pc_offer, &pc_answer).await; + + Ok(()) +} + +#[tokio::test] +async fn test_peer_connection_close_is_send() -> Result<()> { + let handle = tokio::spawn(async move { peer().await }); + tokio::join!(handle).0.unwrap() +} + +async fn peer() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let mut registry = Registry::new(); + registry = register_default_interceptors(registry, &mut m)?; + let api = APIBuilder::new() + .with_media_engine(m) + .with_interceptor_registry(registry) + .build(); + + let config = RTCConfiguration { + ice_servers: vec![RTCIceServer { + urls: vec!["stun:stun.l.google.com:19302".to_owned()], + ..Default::default() + }], + ..Default::default() + }; + + let peer_connection = Arc::new(api.new_peer_connection(config).await?); + + let offer = peer_connection.create_offer(None).await?; + let mut gather_complete = peer_connection.gathering_complete_promise().await; + peer_connection.set_local_description(offer).await?; + let _ = gather_complete.recv().await; + + if peer_connection.local_description().await.is_some() { + //TODO? + } + + peer_connection.close().await?; + + Ok(()) +} diff --git a/rtc/src/peer_connection/policy/bundle_policy.rs b/rtc/src/peer_connection/policy/bundle_policy.rs new file mode 100644 index 0000000..040228f --- /dev/null +++ b/rtc/src/peer_connection/policy/bundle_policy.rs @@ -0,0 +1,94 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// BundlePolicy affects which media tracks are negotiated if the remote +/// endpoint is not bundle-aware, and what ICE candidates are gathered. If the +/// remote endpoint is bundle-aware, all media tracks and data channels are +/// bundled onto the same transport. +#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub enum RTCBundlePolicy { + #[default] + Unspecified = 0, + + /// BundlePolicyBalanced indicates to gather ICE candidates for each + /// media type in use (audio, video, and data). If the remote endpoint is + /// not bundle-aware, negotiate only one audio and video track on separate + /// transports. + #[serde(rename = "balanced")] + Balanced = 1, + + /// BundlePolicyMaxCompat indicates to gather ICE candidates for each + /// track. If the remote endpoint is not bundle-aware, negotiate all media + /// tracks on separate transports. + #[serde(rename = "max-compat")] + MaxCompat = 2, + + /// BundlePolicyMaxBundle indicates to gather ICE candidates for only + /// one track. If the remote endpoint is not bundle-aware, negotiate only + /// one media track. + #[serde(rename = "max-bundle")] + MaxBundle = 3, +} + +/// This is done this way because of a linter. +const BUNDLE_POLICY_BALANCED_STR: &str = "balanced"; +const BUNDLE_POLICY_MAX_COMPAT_STR: &str = "max-compat"; +const BUNDLE_POLICY_MAX_BUNDLE_STR: &str = "max-bundle"; + +impl From<&str> for RTCBundlePolicy { + /// NewSchemeType defines a procedure for creating a new SchemeType from a raw + /// string naming the scheme type. + fn from(raw: &str) -> Self { + match raw { + BUNDLE_POLICY_BALANCED_STR => RTCBundlePolicy::Balanced, + BUNDLE_POLICY_MAX_COMPAT_STR => RTCBundlePolicy::MaxCompat, + BUNDLE_POLICY_MAX_BUNDLE_STR => RTCBundlePolicy::MaxBundle, + _ => RTCBundlePolicy::Unspecified, + } + } +} + +impl fmt::Display for RTCBundlePolicy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCBundlePolicy::Balanced => write!(f, "{BUNDLE_POLICY_BALANCED_STR}"), + RTCBundlePolicy::MaxCompat => write!(f, "{BUNDLE_POLICY_MAX_COMPAT_STR}"), + RTCBundlePolicy::MaxBundle => write!(f, "{BUNDLE_POLICY_MAX_BUNDLE_STR}"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_bundle_policy() { + let tests = vec![ + ("Unspecified", RTCBundlePolicy::Unspecified), + ("balanced", RTCBundlePolicy::Balanced), + ("max-compat", RTCBundlePolicy::MaxCompat), + ("max-bundle", RTCBundlePolicy::MaxBundle), + ]; + + for (policy_string, expected_policy) in tests { + assert_eq!(RTCBundlePolicy::from(policy_string), expected_policy); + } + } + + #[test] + fn test_bundle_policy_string() { + let tests = vec![ + (RTCBundlePolicy::Unspecified, "Unspecified"), + (RTCBundlePolicy::Balanced, "balanced"), + (RTCBundlePolicy::MaxCompat, "max-compat"), + (RTCBundlePolicy::MaxBundle, "max-bundle"), + ]; + + for (policy, expected_string) in tests { + assert_eq!(policy.to_string(), expected_string); + } + } +} diff --git a/rtc/src/peer_connection/policy/ice_transport_policy.rs b/rtc/src/peer_connection/policy/ice_transport_policy.rs new file mode 100644 index 0000000..331ebe6 --- /dev/null +++ b/rtc/src/peer_connection/policy/ice_transport_policy.rs @@ -0,0 +1,77 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// ICETransportPolicy defines the ICE candidate policy surface the +/// permitted candidates. Only these candidates are used for connectivity checks. +#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub enum RTCIceTransportPolicy { + #[default] + Unspecified = 0, + + /// ICETransportPolicyAll indicates any type of candidate is used. + #[serde(rename = "all")] + All = 1, + + /// ICETransportPolicyRelay indicates only media relay candidates such + /// as candidates passing through a TURN server are used. + #[serde(rename = "relay")] + Relay = 2, +} + +/// ICEGatherPolicy is the ORTC equivalent of ICETransportPolicy +pub type ICEGatherPolicy = RTCIceTransportPolicy; + +const ICE_TRANSPORT_POLICY_RELAY_STR: &str = "relay"; +const ICE_TRANSPORT_POLICY_ALL_STR: &str = "all"; + +/// takes a string and converts it to ICETransportPolicy +impl From<&str> for RTCIceTransportPolicy { + fn from(raw: &str) -> Self { + match raw { + ICE_TRANSPORT_POLICY_RELAY_STR => RTCIceTransportPolicy::Relay, + ICE_TRANSPORT_POLICY_ALL_STR => RTCIceTransportPolicy::All, + _ => RTCIceTransportPolicy::Unspecified, + } + } +} + +impl fmt::Display for RTCIceTransportPolicy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCIceTransportPolicy::Relay => ICE_TRANSPORT_POLICY_RELAY_STR, + RTCIceTransportPolicy::All => ICE_TRANSPORT_POLICY_ALL_STR, + RTCIceTransportPolicy::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_ice_transport_policy() { + let tests = vec![ + ("relay", RTCIceTransportPolicy::Relay), + ("all", RTCIceTransportPolicy::All), + ]; + + for (policy_string, expected_policy) in tests { + assert_eq!(RTCIceTransportPolicy::from(policy_string), expected_policy); + } + } + + #[test] + fn test_ice_transport_policy_string() { + let tests = vec![ + (RTCIceTransportPolicy::Relay, "relay"), + (RTCIceTransportPolicy::All, "all"), + ]; + + for (policy, expected_string) in tests { + assert_eq!(policy.to_string(), expected_string); + } + } +} diff --git a/rtc/src/peer_connection/policy/mod.rs b/rtc/src/peer_connection/policy/mod.rs new file mode 100644 index 0000000..82036c5 --- /dev/null +++ b/rtc/src/peer_connection/policy/mod.rs @@ -0,0 +1,4 @@ +pub mod bundle_policy; +pub mod ice_transport_policy; +pub mod rtcp_mux_policy; +pub mod sdp_semantics; diff --git a/rtc/src/peer_connection/policy/rtcp_mux_policy.rs b/rtc/src/peer_connection/policy/rtcp_mux_policy.rs new file mode 100644 index 0000000..35eccca --- /dev/null +++ b/rtc/src/peer_connection/policy/rtcp_mux_policy.rs @@ -0,0 +1,79 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// RTCPMuxPolicy affects what ICE candidates are gathered to support +/// non-multiplexed RTCP. +#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub enum RTCRtcpMuxPolicy { + #[default] + Unspecified = 0, + + /// RTCPMuxPolicyNegotiate indicates to gather ICE candidates for both + /// RTP and RTCP candidates. If the remote-endpoint is capable of + /// multiplexing RTCP, multiplex RTCP on the RTP candidates. If it is not, + /// use both the RTP and RTCP candidates separately. + #[serde(rename = "negotiate")] + Negotiate = 1, + + /// RTCPMuxPolicyRequire indicates to gather ICE candidates only for + /// RTP and multiplex RTCP on the RTP candidates. If the remote endpoint is + /// not capable of rtcp-mux, session negotiation will fail. + #[serde(rename = "require")] + Require = 2, +} + +const RTCP_MUX_POLICY_NEGOTIATE_STR: &str = "negotiate"; +const RTCP_MUX_POLICY_REQUIRE_STR: &str = "require"; + +impl From<&str> for RTCRtcpMuxPolicy { + fn from(raw: &str) -> Self { + match raw { + RTCP_MUX_POLICY_NEGOTIATE_STR => RTCRtcpMuxPolicy::Negotiate, + RTCP_MUX_POLICY_REQUIRE_STR => RTCRtcpMuxPolicy::Require, + _ => RTCRtcpMuxPolicy::Unspecified, + } + } +} + +impl fmt::Display for RTCRtcpMuxPolicy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCRtcpMuxPolicy::Negotiate => RTCP_MUX_POLICY_NEGOTIATE_STR, + RTCRtcpMuxPolicy::Require => RTCP_MUX_POLICY_REQUIRE_STR, + RTCRtcpMuxPolicy::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_rtcp_mux_policy() { + let tests = vec![ + ("Unspecified", RTCRtcpMuxPolicy::Unspecified), + ("negotiate", RTCRtcpMuxPolicy::Negotiate), + ("require", RTCRtcpMuxPolicy::Require), + ]; + + for (policy_string, expected_policy) in tests { + assert_eq!(RTCRtcpMuxPolicy::from(policy_string), expected_policy); + } + } + + #[test] + fn test_rtcp_mux_policy_string() { + let tests = vec![ + (RTCRtcpMuxPolicy::Unspecified, "Unspecified"), + (RTCRtcpMuxPolicy::Negotiate, "negotiate"), + (RTCRtcpMuxPolicy::Require, "require"), + ]; + + for (policy, expected_string) in tests { + assert_eq!(policy.to_string(), expected_string); + } + } +} diff --git a/rtc/src/peer_connection/policy/sdp_semantics.rs b/rtc/src/peer_connection/policy/sdp_semantics.rs new file mode 100644 index 0000000..4fe510f --- /dev/null +++ b/rtc/src/peer_connection/policy/sdp_semantics.rs @@ -0,0 +1,112 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// SDPSemantics determines which style of SDP offers and answers +/// can be used. +/// +/// This is unused, we only support UnifiedPlan. +#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub enum RTCSdpSemantics { + Unspecified = 0, + + /// UnifiedPlan uses unified-plan offers and answers + /// (the default in Chrome since M72) + /// + #[serde(rename = "unified-plan")] + #[default] + UnifiedPlan = 1, + + /// PlanB uses plan-b offers and answers + /// NB: This format should be considered deprecated + /// + #[serde(rename = "plan-b")] + PlanB = 2, + + /// UnifiedPlanWithFallback prefers unified-plan + /// offers and answers, but will respond to a plan-b offer + /// with a plan-b answer + #[serde(rename = "unified-plan-with-fallback")] + UnifiedPlanWithFallback = 3, +} + +const SDP_SEMANTICS_UNIFIED_PLAN_WITH_FALLBACK: &str = "unified-plan-with-fallback"; +const SDP_SEMANTICS_UNIFIED_PLAN: &str = "unified-plan"; +const SDP_SEMANTICS_PLAN_B: &str = "plan-b"; + +impl From<&str> for RTCSdpSemantics { + fn from(raw: &str) -> Self { + match raw { + SDP_SEMANTICS_UNIFIED_PLAN_WITH_FALLBACK => RTCSdpSemantics::UnifiedPlanWithFallback, + SDP_SEMANTICS_UNIFIED_PLAN => RTCSdpSemantics::UnifiedPlan, + SDP_SEMANTICS_PLAN_B => RTCSdpSemantics::PlanB, + _ => RTCSdpSemantics::Unspecified, + } + } +} + +impl fmt::Display for RTCSdpSemantics { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCSdpSemantics::UnifiedPlanWithFallback => SDP_SEMANTICS_UNIFIED_PLAN_WITH_FALLBACK, + RTCSdpSemantics::UnifiedPlan => SDP_SEMANTICS_UNIFIED_PLAN, + RTCSdpSemantics::PlanB => SDP_SEMANTICS_PLAN_B, + RTCSdpSemantics::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[cfg(test)] +mod test { + use std::collections::HashSet; + + use sdp::description::media::MediaDescription; + use sdp::description::session::{SessionDescription, ATTR_KEY_SSRC}; + + use super::*; + + #[test] + fn test_sdp_semantics_string() { + let tests = vec![ + (RTCSdpSemantics::Unspecified, "Unspecified"), + ( + RTCSdpSemantics::UnifiedPlanWithFallback, + "unified-plan-with-fallback", + ), + (RTCSdpSemantics::PlanB, "plan-b"), + (RTCSdpSemantics::UnifiedPlan, "unified-plan"), + ]; + + for (value, expected_string) in tests { + assert_eq!(value.to_string(), expected_string); + } + } + + // The following tests are for non-standard SDP semantics + // (i.e. not unified-unified) + fn get_md_names(sdp: &SessionDescription) -> Vec { + sdp.media_descriptions + .iter() + .map(|md| md.media_name.media.clone()) + .collect() + } + + fn extract_ssrc_list(md: &MediaDescription) -> Vec { + let mut ssrcs = HashSet::new(); + for attr in &md.attributes { + if attr.key == ATTR_KEY_SSRC { + if let Some(value) = &attr.value { + let fields: Vec<&str> = value.split_whitespace().collect(); + if let Some(ssrc) = fields.first() { + ssrcs.insert(*ssrc); + } + } + } + } + ssrcs + .into_iter() + .map(|ssrc| ssrc.to_owned()) + .collect::>() + } +} diff --git a/rtc/src/peer_connection/sdp/mod.rs b/rtc/src/peer_connection/sdp/mod.rs new file mode 100644 index 0000000..21b8275 --- /dev/null +++ b/rtc/src/peer_connection/sdp/mod.rs @@ -0,0 +1,1096 @@ +#[cfg(test)] +mod sdp_test; + +use crate::api::media_engine::MediaEngine; +use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; +use crate::error::{Error, Result}; +use crate::ice_transport::ice_candidate::RTCIceCandidate; +use crate::ice_transport::ice_gatherer::RTCIceGatherer; +use crate::ice_transport::ice_gathering_state::RTCIceGatheringState; +use crate::ice_transport::ice_parameters::RTCIceParameters; +use crate::rtp_transceiver::rtp_codec::{ + RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType, +}; +use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; +use crate::rtp_transceiver::{PayloadType, RTCPFeedback, RTCRtpTransceiver, SSRC}; + +pub mod sdp_type; +pub mod session_description; + +use std::collections::HashMap; +use std::convert::From; +use std::io::BufReader; +use std::sync::Arc; + +use ice::candidate::candidate_base::unmarshal_candidate; +use ice::candidate::Candidate; +use sdp::description::common::{Address, ConnectionInformation}; +use sdp::description::media::{MediaDescription, MediaName, RangedPort}; +use sdp::description::session::*; +use sdp::extmap::ExtMap; +use sdp::util::ConnectionRole; +use smol_str::SmolStr; +use url::Url; + +use crate::peer_connection::MEDIA_SECTION_APPLICATION; +use crate::{SDP_ATTRIBUTE_RID, SDP_ATTRIBUTE_SIMULCAST}; + +/// TrackDetails represents any media source that can be represented in a SDP +/// This isn't keyed by SSRC because it also needs to support rid based sources +#[derive(Default, Debug, Clone)] +pub(crate) struct TrackDetails { + pub(crate) mid: SmolStr, + pub(crate) kind: RTPCodecType, + pub(crate) stream_id: String, + pub(crate) id: String, + pub(crate) ssrcs: Vec, + pub(crate) repair_ssrc: SSRC, + pub(crate) rids: Vec, +} + +pub(crate) fn track_details_for_ssrc( + track_details: &[TrackDetails], + ssrc: SSRC, +) -> Option<&TrackDetails> { + track_details.iter().find(|x| x.ssrcs.contains(&ssrc)) +} + +pub(crate) fn track_details_for_rid( + track_details: &[TrackDetails], + rid: SmolStr, +) -> Option<&TrackDetails> { + track_details.iter().find(|x| x.rids.contains(&rid)) +} + +pub(crate) fn filter_track_with_ssrc(incoming_tracks: &mut Vec, ssrc: SSRC) { + incoming_tracks.retain(|x| !x.ssrcs.contains(&ssrc)); +} + +/// extract all TrackDetails from an SDP. +pub(crate) fn track_details_from_sdp( + s: &SessionDescription, + exclude_inactive: bool, +) -> Vec { + let mut incoming_tracks = vec![]; + + for media in &s.media_descriptions { + let mut tracks_in_media_section = vec![]; + let mut rtx_repair_flows = HashMap::new(); + + let mut stream_id = ""; + let mut track_id = ""; + + // If media section is recvonly or inactive skip + if media.attribute(ATTR_KEY_RECV_ONLY).is_some() + || (exclude_inactive && media.attribute(ATTR_KEY_INACTIVE).is_some()) + { + continue; + } + + let mid_value = match get_mid_value(media) { + Some(mid_value) => mid_value, + None => continue, + }; + + let codec_type = RTPCodecType::from(media.media_name.media.as_str()); + if codec_type == RTPCodecType::Unspecified { + continue; + } + + for attr in &media.attributes { + match attr.key.as_str() { + ATTR_KEY_SSRCGROUP => { + if let Some(value) = &attr.value { + let split: Vec<&str> = value.split(' ').collect(); + if split[0] == SEMANTIC_TOKEN_FLOW_IDENTIFICATION { + // Add rtx ssrcs to blacklist, to avoid adding them as tracks + // Essentially lines like `a=ssrc-group:FID 2231627014 632943048` are processed by this section + // as this declares that the second SSRC (632943048) is a rtx repair flow (RFC4588) for the first + // (2231627014) as specified in RFC5576 + if split.len() == 3 { + let base_ssrc = match split[1].parse::() { + Ok(ssrc) => ssrc, + Err(err) => { + log::warn!("Failed to parse SSRC: {}", err); + continue; + } + }; + let rtx_repair_flow = match split[2].parse::() { + Ok(n) => n, + Err(err) => { + log::warn!("Failed to parse SSRC: {}", err); + continue; + } + }; + rtx_repair_flows.insert(rtx_repair_flow, base_ssrc); + // Remove if rtx was added as track before + filter_track_with_ssrc( + &mut tracks_in_media_section, + rtx_repair_flow as SSRC, + ); + } + } + } + } + + // Handle `a=msid: ` The first value is the same as MediaStream.id + // in the browser and can be used to figure out which tracks belong to the same stream. The browser should + // figure this out automatically when an ontrack event is emitted on RTCPeerConnection. + ATTR_KEY_MSID => { + if let Some(value) = &attr.value { + let mut split = value.split(' '); + + if let (Some(sid), Some(tid), None) = + (split.next(), split.next(), split.next()) + { + stream_id = sid; + track_id = tid; + } + } + } + + ATTR_KEY_SSRC => { + if let Some(value) = &attr.value { + let split: Vec<&str> = value.split(' ').collect(); + let ssrc = match split[0].parse::() { + Ok(ssrc) => ssrc, + Err(err) => { + log::warn!("Failed to parse SSRC: {}", err); + continue; + } + }; + + if rtx_repair_flows.contains_key(&ssrc) { + continue; // This ssrc is a RTX repair flow, ignore + } + + if split.len() == 3 && split[1].starts_with("msid:") { + stream_id = &split[1]["msid:".len()..]; + track_id = split[2]; + } + + let mut track_idx = tracks_in_media_section.len(); + + for (i, t) in tracks_in_media_section.iter().enumerate() { + if t.ssrcs.contains(&ssrc) { + track_idx = i; + //TODO: no break? + } + } + + let mut repair_ssrc = 0; + for (repair, base) in &rtx_repair_flows { + if *base == ssrc { + repair_ssrc = *repair; + //TODO: no break? + } + } + + if track_idx < tracks_in_media_section.len() { + tracks_in_media_section[track_idx].mid = SmolStr::from(mid_value); + tracks_in_media_section[track_idx].kind = codec_type; + tracks_in_media_section[track_idx].stream_id = stream_id.to_owned(); + tracks_in_media_section[track_idx].id = track_id.to_owned(); + tracks_in_media_section[track_idx].ssrcs = vec![ssrc]; + tracks_in_media_section[track_idx].repair_ssrc = repair_ssrc; + } else { + let track_details = TrackDetails { + mid: SmolStr::from(mid_value), + kind: codec_type, + stream_id: stream_id.to_owned(), + id: track_id.to_owned(), + ssrcs: vec![ssrc], + repair_ssrc, + ..Default::default() + }; + tracks_in_media_section.push(track_details); + } + } + } + _ => {} + }; + } + + let rids = get_rids(media); + if !rids.is_empty() && !track_id.is_empty() && !stream_id.is_empty() { + let mut simulcast_track = TrackDetails { + mid: SmolStr::from(mid_value), + kind: codec_type, + stream_id: stream_id.to_owned(), + id: track_id.to_owned(), + rids: vec![], + ..Default::default() + }; + for rid in &rids { + simulcast_track.rids.push(SmolStr::from(&rid.id)); + } + if simulcast_track.rids.len() == tracks_in_media_section.len() { + for track in &tracks_in_media_section { + simulcast_track.ssrcs.extend(&track.ssrcs) + } + } + + tracks_in_media_section = vec![simulcast_track]; + } + + incoming_tracks.extend(tracks_in_media_section); + } + + incoming_tracks +} + +pub(crate) fn get_rids(media: &MediaDescription) -> Vec { + let mut rids = vec![]; + let mut simulcast_attr: Option = None; + for attr in &media.attributes { + if attr.key.as_str() == SDP_ATTRIBUTE_RID { + if let Err(err) = attr + .value + .as_ref() + .ok_or(SimulcastRidParseError::SyntaxIdDirSplit) + .and_then(SimulcastRid::try_from) + .map(|rid| rids.push(rid)) + { + log::warn!("Failed to parse RID: {}", err); + } + } else if attr.key.as_str() == SDP_ATTRIBUTE_SIMULCAST { + simulcast_attr = attr.value.clone(); + } + } + + if let Some(attr) = simulcast_attr { + let mut split = attr.split(' '); + loop { + let _dir = split.next(); + let sc_str_list = split.next(); + if let Some(list) = sc_str_list { + let sc_list: Vec<&str> = list.split(';').flat_map(|alt| alt.split(',')).collect(); + for sc_id in sc_list { + let (sc_id, paused) = if let Some(sc_id) = sc_id.strip_prefix('~') { + (sc_id, true) + } else { + (sc_id, false) + }; + + if let Some(rid) = rids.iter_mut().find(|f| f.id == sc_id) { + rid.paused = paused; + } + } + } else { + break; + } + } + } + + rids +} + +pub(crate) async fn add_candidates_to_media_descriptions( + candidates: &[RTCIceCandidate], + mut m: MediaDescription, + ice_gathering_state: RTCIceGatheringState, +) -> Result { + let append_candidate_if_new = |c: &dyn Candidate, m: MediaDescription| -> MediaDescription { + let marshaled = c.marshal(); + for a in &m.attributes { + if let Some(value) = &a.value { + if &marshaled == value { + return m; + } + } + } + + m.with_value_attribute("candidate".to_owned(), marshaled) + }; + + for c in candidates { + let candidate = c.to_ice()?; + + candidate.set_component(1); + m = append_candidate_if_new(&candidate, m); + + candidate.set_component(2); + m = append_candidate_if_new(&candidate, m); + } + + if ice_gathering_state != RTCIceGatheringState::Complete { + return Ok(m); + } + for a in &m.attributes { + if &a.key == "end-of-candidates" { + return Ok(m); + } + } + + Ok(m.with_property_attribute("end-of-candidates".to_owned())) +} + +pub(crate) struct AddDataMediaSectionParams { + should_add_candidates: bool, + mid_value: String, + ice_params: RTCIceParameters, + dtls_role: ConnectionRole, + ice_gathering_state: RTCIceGatheringState, +} + +pub(crate) async fn add_data_media_section( + d: SessionDescription, + dtls_fingerprints: &[RTCDtlsFingerprint], + candidates: &[RTCIceCandidate], + params: AddDataMediaSectionParams, +) -> Result { + let mut media = MediaDescription { + media_name: MediaName { + media: MEDIA_SECTION_APPLICATION.to_owned(), + port: RangedPort { + value: 9, + range: None, + }, + protos: vec!["UDP".to_owned(), "DTLS".to_owned(), "SCTP".to_owned()], + formats: vec!["webrtc-datachannel".to_owned()], + }, + media_title: None, + connection_information: Some(ConnectionInformation { + network_type: "IN".to_owned(), + address_type: "IP4".to_owned(), + address: Some(Address { + address: "0.0.0.0".to_owned(), + ttl: None, + range: None, + }), + }), + bandwidth: vec![], + encryption_key: None, + attributes: vec![], + } + .with_value_attribute( + ATTR_KEY_CONNECTION_SETUP.to_owned(), + params.dtls_role.to_string(), + ) + .with_value_attribute(ATTR_KEY_MID.to_owned(), params.mid_value) + .with_property_attribute(RTCRtpTransceiverDirection::Sendrecv.to_string()) + .with_property_attribute("sctp-port:5000".to_owned()) + .with_ice_credentials( + params.ice_params.username_fragment, + params.ice_params.password, + ); + + for f in dtls_fingerprints { + media = media.with_fingerprint(f.algorithm.clone(), f.value.to_uppercase()); + } + + if params.should_add_candidates { + media = add_candidates_to_media_descriptions(candidates, media, params.ice_gathering_state) + .await?; + } + + Ok(d.with_media(media)) +} + +pub(crate) async fn populate_local_candidates( + session_description: Option<&session_description::RTCSessionDescription>, + ice_gatherer: Option<&Arc>, + ice_gathering_state: RTCIceGatheringState, +) -> Option { + if session_description.is_none() || ice_gatherer.is_none() { + return session_description.cloned(); + } + + if let (Some(sd), Some(ice)) = (session_description, ice_gatherer) { + let candidates = match ice.get_local_candidates().await { + Ok(candidates) => candidates, + Err(_) => return Some(sd.clone()), + }; + + let mut parsed = match sd.unmarshal() { + Ok(parsed) => parsed, + Err(_) => return Some(sd.clone()), + }; + + if !parsed.media_descriptions.is_empty() { + let mut m = parsed.media_descriptions.remove(0); + m = match add_candidates_to_media_descriptions(&candidates, m, ice_gathering_state) + .await + { + Ok(m) => m, + Err(_) => return Some(sd.clone()), + }; + parsed.media_descriptions.insert(0, m); + } + + Some(session_description::RTCSessionDescription { + sdp_type: sd.sdp_type, + sdp: parsed.marshal(), + parsed: Some(parsed), + }) + } else { + None + } +} + +pub(crate) struct AddTransceiverSdpParams { + should_add_candidates: bool, + mid_value: String, + dtls_role: ConnectionRole, + ice_gathering_state: RTCIceGatheringState, + offered_direction: Option, +} + +pub(crate) async fn add_transceiver_sdp( + mut d: SessionDescription, + dtls_fingerprints: &[RTCDtlsFingerprint], + media_engine: &Arc, + ice_params: &RTCIceParameters, + candidates: &[RTCIceCandidate], + media_section: &MediaSection, + params: AddTransceiverSdpParams, +) -> Result<(SessionDescription, bool)> { + if media_section.transceivers.is_empty() { + return Err(Error::ErrSDPZeroTransceivers); + } + let (should_add_candidates, mid_value, dtls_role, ice_gathering_state) = ( + params.should_add_candidates, + params.mid_value, + params.dtls_role, + params.ice_gathering_state, + ); + + let transceivers = &media_section.transceivers; + // Use the first transceiver to generate the section attributes + let t = &transceivers[0]; + let mut media = MediaDescription::new_jsep_media_description(t.kind.to_string(), vec![]) + .with_value_attribute(ATTR_KEY_CONNECTION_SETUP.to_owned(), dtls_role.to_string()) + .with_value_attribute(ATTR_KEY_MID.to_owned(), mid_value.clone()) + .with_ice_credentials( + ice_params.username_fragment.clone(), + ice_params.password.clone(), + ) + .with_property_attribute(ATTR_KEY_RTCPMUX.to_owned()) + .with_property_attribute(ATTR_KEY_RTCPRSIZE.to_owned()); + + let codecs = t.get_codecs().await; + for codec in &codecs { + let name = codec + .capability + .mime_type + .trim_start_matches("audio/") + .trim_start_matches("video/") + .to_owned(); + media = media.with_codec( + codec.payload_type, + name, + codec.capability.clock_rate, + codec.capability.channels, + codec.capability.sdp_fmtp_line.clone(), + ); + + for feedback in &codec.capability.rtcp_feedback { + media = media.with_value_attribute( + "rtcp-fb".to_owned(), + format!( + "{} {} {}", + codec.payload_type, feedback.typ, feedback.parameter + ), + ); + } + } + if codecs.is_empty() { + // If we are sender and we have no codecs throw an error early + if t.sender().await.track().await.is_some() { + return Err(Error::ErrSenderWithNoCodecs); + } + + // Explicitly reject track if we don't have the codec + d = d.with_media(MediaDescription { + media_name: sdp::description::media::MediaName { + media: t.kind.to_string(), + port: RangedPort { + value: 0, + range: None, + }, + protos: vec![ + "UDP".to_owned(), + "TLS".to_owned(), + "RTP".to_owned(), + "SAVPF".to_owned(), + ], + formats: vec!["0".to_owned()], + }, + media_title: None, + // We need to include connection information even if we're rejecting a track, otherwise Firefox will fail to + // parse the SDP with an error like: + // SIPCC Failed to parse SDP: SDP Parse Error on line 50: c= connection line not specified for every media level, validation failed. + // In addition this makes our SDP compliant with RFC 4566 Section 5.7: https://datatracker.ietf.org/doc/html/rfc4566#section-5.7 + connection_information: Some(ConnectionInformation { + network_type: "IN".to_owned(), + address_type: "IP4".to_owned(), + address: Some(Address { + address: "0.0.0.0".to_owned(), + ttl: None, + range: None, + }), + }), + bandwidth: vec![], + encryption_key: None, + attributes: vec![], + }); + return Ok((d, false)); + } + + let parameters = media_engine.get_rtp_parameters_by_kind(t.kind, t.direction()); + for rtp_extension in ¶meters.header_extensions { + let ext_url = Url::parse(rtp_extension.uri.as_str())?; + media = media.with_extmap(sdp::extmap::ExtMap { + value: rtp_extension.id, + uri: Some(ext_url), + ..Default::default() + }); + } + + if !media_section.rid_map.is_empty() { + let mut recv_sc_list: Vec = vec![]; + let mut send_sc_list: Vec = vec![]; + + for rid in &media_section.rid_map { + let rid_syntax = match rid.direction { + SimulcastDirection::Send => { + // If Send rid, then reply with a recv rid + if rid.paused { + recv_sc_list.push(format!("~{}", rid.id)); + } else { + recv_sc_list.push(rid.id.to_owned()); + } + format!("{} recv", rid.id) + } + SimulcastDirection::Recv => { + // If Recv rid, then reply with a send rid + if rid.paused { + send_sc_list.push(format!("~{}", rid.id)); + } else { + send_sc_list.push(rid.id.to_owned()); + } + format!("{} send", rid.id) + } + }; + media = media.with_value_attribute(SDP_ATTRIBUTE_RID.to_owned(), rid_syntax); + } + + // Simulcast + let mut sc_attr = String::new(); + if !recv_sc_list.is_empty() { + sc_attr.push_str(&format!("recv {}", recv_sc_list.join(";"))); + } + if !send_sc_list.is_empty() { + sc_attr.push_str(&format!("send {}", send_sc_list.join(";"))); + } + media = media.with_value_attribute(SDP_ATTRIBUTE_SIMULCAST.to_owned(), sc_attr); + } + + for mt in transceivers { + let sender = mt.sender().await; + if let Some(track) = sender.track().await { + media = media.with_media_source( + sender.ssrc, + track.stream_id().to_owned(), /* cname */ + track.stream_id().to_owned(), /* streamLabel */ + track.id().to_owned(), + ); + + // Send msid based on the configured track if we haven't already + // sent on this sender. If we have sent we must keep the msid line consistent, this + // is handled below. + if sender.initial_track_id().is_none() { + for stream_id in sender.associated_media_stream_ids() { + media = + media.with_property_attribute(format!("msid:{} {}", stream_id, track.id())); + } + + sender.set_initial_track_id(track.id().to_string())?; + break; + } + } + + if let Some(track_id) = sender.initial_track_id() { + // After we have include an msid attribute in an offer it must stay the same for + // all subsequent offer even if the track or transceiver direction changes. + // + // [RFC 8829 Section 5.2.2](https://datatracker.ietf.org/doc/html/rfc8829#section-5.2.2) + // + // For RtpTransceivers that are not stopped, the "a=msid" line or + // lines MUST stay the same if they are present in the current + // description, regardless of changes to the transceiver's direction + // or track. If no "a=msid" line is present in the current + // description, "a=msid" line(s) MUST be generated according to the + // same rules as for an initial offer. + for stream_id in sender.associated_media_stream_ids() { + media = media.with_property_attribute(format!("msid:{stream_id} {track_id}")); + } + + break; + } + } + + let direction = match params.offered_direction { + Some(offered_direction) => { + use RTCRtpTransceiverDirection::*; + let transceiver_direction = t.direction(); + + match offered_direction { + Sendonly | Recvonly => { + // If a stream is offered as sendonly, the corresponding stream MUST be + // marked as recvonly or inactive in the answer. + + // If a media stream is + // listed as recvonly in the offer, the answer MUST be marked as + // sendonly or inactive in the answer. + offered_direction.reverse().intersect(transceiver_direction) + } + // If an offered media stream is + // listed as sendrecv (or if there is no direction attribute at the + // media or session level, in which case the stream is sendrecv by + // default), the corresponding stream in the answer MAY be marked as + // sendonly, recvonly, sendrecv, or inactive + Sendrecv | Unspecified => t.direction(), + // If an offered media + // stream is listed as inactive, it MUST be marked as inactive in the + // answer. + Inactive => Inactive, + } + } + None => { + // If don't have an offered direction to intersect with just use the transceivers + // current direction. + // + // https://datatracker.ietf.org/doc/html/rfc8829#section-4.2.3 + // + // When creating offers, the transceiver direction is directly reflected + // in the output, even for re-offers. + t.direction() + } + }; + media = media.with_property_attribute(direction.to_string()); + + for fingerprint in dtls_fingerprints { + media = media.with_fingerprint( + fingerprint.algorithm.to_owned(), + fingerprint.value.to_uppercase(), + ); + } + + if should_add_candidates { + media = + add_candidates_to_media_descriptions(candidates, media, ice_gathering_state).await?; + } + + Ok((d.with_media(media), true)) +} + +#[derive(thiserror::Error, Debug, PartialEq)] +pub(crate) enum SimulcastRidParseError { + /// SyntaxIdDirSplit indicates rid-syntax could not be parsed. + #[error("RFC8851 mandates rid-syntax = %s\"a=rid:\" rid-id SP rid-dir")] + SyntaxIdDirSplit, + /// UnknownDirection indicates rid-dir was not parsed. Should be "send" or "recv". + #[error("RFC8851 mandates rid-dir = %s\"send\" / %s\"recv\"")] + UnknownDirection, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) enum SimulcastDirection { + Send, + Recv, +} + +impl TryFrom<&str> for SimulcastDirection { + type Error = SimulcastRidParseError; + fn try_from(value: &str) -> std::result::Result { + match value.to_lowercase().as_str() { + "send" => Ok(SimulcastDirection::Send), + "recv" => Ok(SimulcastDirection::Recv), + _ => Err(SimulcastRidParseError::UnknownDirection), + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct SimulcastRid { + pub(crate) id: String, + pub(crate) direction: SimulcastDirection, + pub(crate) params: String, + pub(crate) paused: bool, +} + +impl TryFrom<&String> for SimulcastRid { + type Error = SimulcastRidParseError; + fn try_from(value: &String) -> std::result::Result { + let mut split = value.split(' '); + let id = split + .next() + .ok_or(SimulcastRidParseError::SyntaxIdDirSplit)? + .to_owned(); + let direction = SimulcastDirection::try_from( + split + .next() + .ok_or(SimulcastRidParseError::SyntaxIdDirSplit)?, + )?; + let params = split.collect(); + + Ok(Self { + id, + direction, + params, + paused: false, + }) + } +} + +#[derive(Default)] +pub(crate) struct MediaSection { + pub(crate) id: String, + pub(crate) transceivers: Vec>, + pub(crate) data: bool, + pub(crate) rid_map: Vec, + pub(crate) offered_direction: Option, +} + +pub(crate) struct PopulateSdpParams { + pub(crate) media_description_fingerprint: bool, + pub(crate) is_icelite: bool, + pub(crate) connection_role: ConnectionRole, + pub(crate) ice_gathering_state: RTCIceGatheringState, +} + +/// populate_sdp serializes a PeerConnections state into an SDP +pub(crate) async fn populate_sdp( + mut d: SessionDescription, + dtls_fingerprints: &[RTCDtlsFingerprint], + media_engine: &Arc, + candidates: &[RTCIceCandidate], + ice_params: &RTCIceParameters, + media_sections: &[MediaSection], + params: PopulateSdpParams, +) -> Result { + let media_dtls_fingerprints = if params.media_description_fingerprint { + dtls_fingerprints.to_vec() + } else { + vec![] + }; + + let mut bundle_value = "BUNDLE".to_owned(); + let mut bundle_count = 0; + let append_bundle = |mid_value: &str, value: &mut String, count: &mut i32| { + *value = value.clone() + " " + mid_value; + *count += 1; + }; + + for (i, m) in media_sections.iter().enumerate() { + if m.data && !m.transceivers.is_empty() { + return Err(Error::ErrSDPMediaSectionMediaDataChanInvalid); + } else if m.transceivers.len() > 1 { + return Err(Error::ErrSDPMediaSectionMultipleTrackInvalid); + } + + let should_add_candidates = i == 0; + + let should_add_id = if m.data { + let params = AddDataMediaSectionParams { + should_add_candidates, + mid_value: m.id.clone(), + ice_params: ice_params.clone(), + dtls_role: params.connection_role, + ice_gathering_state: params.ice_gathering_state, + }; + d = add_data_media_section(d, &media_dtls_fingerprints, candidates, params).await?; + true + } else { + let params = AddTransceiverSdpParams { + should_add_candidates, + mid_value: m.id.clone(), + dtls_role: params.connection_role, + ice_gathering_state: params.ice_gathering_state, + offered_direction: m.offered_direction, + }; + let (d1, should_add_id) = add_transceiver_sdp( + d, + &media_dtls_fingerprints, + media_engine, + ice_params, + candidates, + m, + params, + ) + .await?; + d = d1; + should_add_id + }; + + if should_add_id { + append_bundle(&m.id, &mut bundle_value, &mut bundle_count); + } + } + + if !params.media_description_fingerprint { + for fingerprint in dtls_fingerprints { + d = d.with_fingerprint( + fingerprint.algorithm.clone(), + fingerprint.value.to_uppercase(), + ); + } + } + + if params.is_icelite { + // RFC 5245 S15.3 + d = d.with_value_attribute(ATTR_KEY_ICELITE.to_owned(), ATTR_KEY_ICELITE.to_owned()); + } + + Ok(d.with_value_attribute(ATTR_KEY_GROUP.to_owned(), bundle_value)) +} + +pub(crate) fn get_mid_value(media: &MediaDescription) -> Option<&String> { + for attr in &media.attributes { + if attr.key == "mid" { + return attr.value.as_ref(); + } + } + None +} + +pub(crate) fn get_peer_direction(media: &MediaDescription) -> RTCRtpTransceiverDirection { + for a in &media.attributes { + let direction = RTCRtpTransceiverDirection::from(a.key.as_str()); + if direction != RTCRtpTransceiverDirection::Unspecified { + return direction; + } + } + RTCRtpTransceiverDirection::Unspecified +} + +pub(crate) fn extract_fingerprint(desc: &SessionDescription) -> Result<(String, String)> { + let mut fingerprints = vec![]; + + if let Some(fingerprint) = desc.attribute("fingerprint") { + fingerprints.push(fingerprint.clone()); + } + + for m in &desc.media_descriptions { + if let Some(fingerprint) = m.attribute("fingerprint").and_then(|o| o) { + fingerprints.push(fingerprint.to_owned()); + } + } + + if fingerprints.is_empty() { + return Err(Error::ErrSessionDescriptionNoFingerprint); + } + + for m in 1..fingerprints.len() { + if fingerprints[m] != fingerprints[0] { + return Err(Error::ErrSessionDescriptionConflictingFingerprints); + } + } + + let parts: Vec<&str> = fingerprints[0].split(' ').collect(); + if parts.len() != 2 { + return Err(Error::ErrSessionDescriptionInvalidFingerprint); + } + + Ok((parts[1].to_owned(), parts[0].to_owned())) +} + +pub(crate) async fn extract_ice_details( + desc: &SessionDescription, +) -> Result<(String, String, Vec)> { + let mut candidates = vec![]; + let mut remote_pwds = vec![]; + let mut remote_ufrags = vec![]; + + if let Some(ufrag) = desc.attribute("ice-ufrag") { + remote_ufrags.push(ufrag.clone()); + } + if let Some(pwd) = desc.attribute("ice-pwd") { + remote_pwds.push(pwd.clone()); + } + + for m in &desc.media_descriptions { + if let Some(ufrag) = m.attribute("ice-ufrag").and_then(|o| o) { + remote_ufrags.push(ufrag.to_owned()); + } + if let Some(pwd) = m.attribute("ice-pwd").and_then(|o| o) { + remote_pwds.push(pwd.to_owned()); + } + + for a in &m.attributes { + if a.is_ice_candidate() { + if let Some(value) = &a.value { + let c: Arc = Arc::new(unmarshal_candidate(value)?); + let candidate = RTCIceCandidate::from(&c); + candidates.push(candidate); + } + } + } + } + + if remote_ufrags.is_empty() { + return Err(Error::ErrSessionDescriptionMissingIceUfrag); + } else if remote_pwds.is_empty() { + return Err(Error::ErrSessionDescriptionMissingIcePwd); + } + + for m in 1..remote_ufrags.len() { + if remote_ufrags[m] != remote_ufrags[0] { + return Err(Error::ErrSessionDescriptionConflictingIceUfrag); + } + } + + for m in 1..remote_pwds.len() { + if remote_pwds[m] != remote_pwds[0] { + return Err(Error::ErrSessionDescriptionConflictingIcePwd); + } + } + + Ok((remote_ufrags[0].clone(), remote_pwds[0].clone(), candidates)) +} + +pub(crate) fn have_application_media_section(desc: &SessionDescription) -> bool { + for m in &desc.media_descriptions { + if m.media_name.media == MEDIA_SECTION_APPLICATION { + return true; + } + } + + false +} + +pub(crate) fn get_by_mid<'a>( + search_mid: &str, + desc: &'a session_description::RTCSessionDescription, +) -> Option<&'a MediaDescription> { + if let Some(parsed) = &desc.parsed { + for m in &parsed.media_descriptions { + if let Some(mid) = m.attribute(ATTR_KEY_MID).flatten() { + if mid == search_mid { + return Some(m); + } + } + } + } + None +} + +/// have_data_channel return MediaDescription with MediaName equal application +pub(crate) fn have_data_channel( + desc: &session_description::RTCSessionDescription, +) -> Option<&MediaDescription> { + if let Some(parsed) = &desc.parsed { + for d in &parsed.media_descriptions { + if d.media_name.media == MEDIA_SECTION_APPLICATION { + return Some(d); + } + } + } + None +} + +pub(crate) fn codecs_from_media_description( + m: &MediaDescription, +) -> Result> { + let s = SessionDescription { + media_descriptions: vec![m.clone()], + ..Default::default() + }; + + let mut out = vec![]; + for payload_str in &m.media_name.formats { + let payload_type: PayloadType = payload_str.parse::()?; + let codec = match s.get_codec_for_payload_type(payload_type) { + Ok(codec) => codec, + Err(err) => { + if payload_type == 0 { + continue; + } + return Err(err.into()); + } + }; + + let channels = codec.encoding_parameters.parse::().unwrap_or(0); + + let mut feedback = vec![]; + for raw in &codec.rtcp_feedback { + let split: Vec<&str> = raw.split(' ').collect(); + + let entry = if split.len() == 2 { + RTCPFeedback { + typ: split[0].to_string(), + parameter: split[1].to_string(), + } + } else { + RTCPFeedback { + typ: split[0].to_string(), + parameter: String::new(), + } + }; + + feedback.push(entry); + } + + out.push(RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: m.media_name.media.clone() + "/" + codec.name.as_str(), + clock_rate: codec.clock_rate, + channels, + sdp_fmtp_line: codec.fmtp.clone(), + rtcp_feedback: feedback, + }, + payload_type, + stats_id: String::new(), + }) + } + + Ok(out) +} + +pub(crate) fn rtp_extensions_from_media_description( + m: &MediaDescription, +) -> Result> { + let mut out = HashMap::new(); + + for a in &m.attributes { + if a.key == ATTR_KEY_EXT_MAP { + let a_str = a.to_string(); + let mut reader = BufReader::new(a_str.as_bytes()); + let e = ExtMap::unmarshal(&mut reader)?; + + if let Some(uri) = e.uri { + out.insert(uri.to_string(), e.value); + } + } + } + + Ok(out) +} + +/// update_sdp_origin saves sdp.Origin in PeerConnection when creating 1st local SDP; +/// for subsequent calling, it updates Origin for SessionDescription from saved one +/// and increments session version by one. +/// +pub(crate) fn update_sdp_origin(origin: &mut Origin, d: &mut SessionDescription) { + //TODO: if atomic.CompareAndSwapUint64(&origin.SessionVersion, 0, d.Origin.SessionVersion) + if origin.session_version == 0 { + // store + origin.session_version = d.origin.session_version; + //atomic.StoreUint64(&origin.SessionID, d.Origin.SessionID) + origin.session_id = d.origin.session_id; + } else { + // load + /*for { // awaiting for saving session id + d.Origin.SessionID = atomic.LoadUint64(&origin.SessionID) + if d.Origin.SessionID != 0 { + break + } + }*/ + d.origin.session_id = origin.session_id; + + //d.Origin.SessionVersion = atomic.AddUint64(&origin.SessionVersion, 1) + origin.session_version += 1; + d.origin.session_version += 1; + } +} diff --git a/rtc/src/peer_connection/sdp/sdp_test.rs b/rtc/src/peer_connection/sdp/sdp_test.rs new file mode 100644 index 0000000..753007f --- /dev/null +++ b/rtc/src/peer_connection/sdp/sdp_test.rs @@ -0,0 +1,1150 @@ +use rcgen::KeyPair; +use sdp::description::common::Attribute; + +use super::*; +use crate::api::media_engine::{MIME_TYPE_OPUS, MIME_TYPE_VP8}; +use crate::api::setting_engine::SettingEngine; +use crate::api::APIBuilder; +use crate::dtls_transport::dtls_role::DEFAULT_DTLS_ROLE_OFFER; +use crate::dtls_transport::RTCDtlsTransport; +use crate::peer_connection::certificate::RTCCertificate; +use crate::rtp_transceiver::rtp_sender::RTCRtpSender; +use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; +use crate::track::track_local::TrackLocal; + +#[test] +fn test_extract_fingerprint() -> Result<()> { + //"Good Session Fingerprint" + { + let s = SessionDescription { + attributes: vec![Attribute { + key: "fingerprint".to_owned(), + value: Some("foo bar".to_owned()), + }], + ..Default::default() + }; + + let (fingerprint, hash) = extract_fingerprint(&s)?; + assert_eq!(fingerprint, "bar"); + assert_eq!(hash, "foo"); + } + + //"Good Media Fingerprint" + { + let s = SessionDescription { + media_descriptions: vec![MediaDescription { + attributes: vec![Attribute { + key: "fingerprint".to_owned(), + value: Some("foo bar".to_owned()), + }], + ..Default::default() + }], + ..Default::default() + }; + + let (fingerprint, hash) = extract_fingerprint(&s)?; + assert_eq!(fingerprint, "bar"); + assert_eq!(hash, "foo"); + } + + //"No Fingerprint" + { + let s = SessionDescription::default(); + + if let Err(err) = extract_fingerprint(&s) { + assert_eq!(err, Error::ErrSessionDescriptionNoFingerprint); + } else { + panic!(); + } + } + + //"Invalid Fingerprint" + { + let s = SessionDescription { + attributes: vec![Attribute { + key: "fingerprint".to_owned(), + value: Some("foo".to_owned()), + }], + ..Default::default() + }; + + if let Err(err) = extract_fingerprint(&s) { + assert_eq!(err, Error::ErrSessionDescriptionInvalidFingerprint); + } else { + panic!(); + } + } + + //"Conflicting Fingerprint" + { + let s = SessionDescription { + attributes: vec![Attribute { + key: "fingerprint".to_owned(), + value: Some("foo".to_owned()), + }], + media_descriptions: vec![MediaDescription { + attributes: vec![Attribute { + key: "fingerprint".to_owned(), + value: Some("foo bar".to_owned()), + }], + ..Default::default() + }], + ..Default::default() + }; + + if let Err(err) = extract_fingerprint(&s) { + assert_eq!(err, Error::ErrSessionDescriptionConflictingFingerprints); + } else { + panic!(); + } + } + + Ok(()) +} + +#[tokio::test] +async fn test_extract_ice_details() -> Result<()> { + const DEFAULT_UFRAG: &str = "DEFAULT_PWD"; + const DEFAULT_PWD: &str = "DEFAULT_UFRAG"; + + //"Missing ice-pwd" + { + let s = SessionDescription { + media_descriptions: vec![MediaDescription { + attributes: vec![Attribute { + key: "ice-ufrag".to_owned(), + value: Some(DEFAULT_UFRAG.to_owned()), + }], + ..Default::default() + }], + ..Default::default() + }; + + if let Err(err) = extract_ice_details(&s).await { + assert_eq!(err, Error::ErrSessionDescriptionMissingIcePwd); + } else { + panic!(); + } + } + + //"Missing ice-ufrag" + { + let s = SessionDescription { + media_descriptions: vec![MediaDescription { + attributes: vec![Attribute { + key: "ice-pwd".to_owned(), + value: Some(DEFAULT_PWD.to_owned()), + }], + ..Default::default() + }], + ..Default::default() + }; + + if let Err(err) = extract_ice_details(&s).await { + assert_eq!(err, Error::ErrSessionDescriptionMissingIceUfrag); + } else { + panic!(); + } + } + + //"ice details at session level" + { + let s = SessionDescription { + attributes: vec![ + Attribute { + key: "ice-ufrag".to_owned(), + value: Some(DEFAULT_UFRAG.to_owned()), + }, + Attribute { + key: "ice-pwd".to_owned(), + value: Some(DEFAULT_PWD.to_owned()), + }, + ], + media_descriptions: vec![], + ..Default::default() + }; + + let (ufrag, pwd, _) = extract_ice_details(&s).await?; + assert_eq!(ufrag, DEFAULT_UFRAG); + assert_eq!(pwd, DEFAULT_PWD); + } + + //"ice details at media level" + { + let s = SessionDescription { + media_descriptions: vec![MediaDescription { + attributes: vec![ + Attribute { + key: "ice-ufrag".to_owned(), + value: Some(DEFAULT_UFRAG.to_owned()), + }, + Attribute { + key: "ice-pwd".to_owned(), + value: Some(DEFAULT_PWD.to_owned()), + }, + ], + ..Default::default() + }], + ..Default::default() + }; + + let (ufrag, pwd, _) = extract_ice_details(&s).await?; + assert_eq!(ufrag, DEFAULT_UFRAG); + assert_eq!(pwd, DEFAULT_PWD); + } + + //"Conflict ufrag" + { + let s = SessionDescription { + attributes: vec![Attribute { + key: "ice-ufrag".to_owned(), + value: Some("invalidUfrag".to_owned()), + }], + media_descriptions: vec![MediaDescription { + attributes: vec![ + Attribute { + key: "ice-ufrag".to_owned(), + value: Some(DEFAULT_UFRAG.to_owned()), + }, + Attribute { + key: "ice-pwd".to_owned(), + value: Some(DEFAULT_PWD.to_owned()), + }, + ], + ..Default::default() + }], + ..Default::default() + }; + + if let Err(err) = extract_ice_details(&s).await { + assert_eq!(err, Error::ErrSessionDescriptionConflictingIceUfrag); + } else { + panic!(); + } + } + + //"Conflict pwd" + { + let s = SessionDescription { + attributes: vec![Attribute { + key: "ice-pwd".to_owned(), + value: Some("invalidPwd".to_owned()), + }], + media_descriptions: vec![MediaDescription { + attributes: vec![ + Attribute { + key: "ice-ufrag".to_owned(), + value: Some(DEFAULT_UFRAG.to_owned()), + }, + Attribute { + key: "ice-pwd".to_owned(), + value: Some(DEFAULT_PWD.to_owned()), + }, + ], + ..Default::default() + }], + ..Default::default() + }; + + if let Err(err) = extract_ice_details(&s).await { + assert_eq!(err, Error::ErrSessionDescriptionConflictingIcePwd); + } else { + panic!(); + } + } + + Ok(()) +} + +#[test] +fn test_track_details_from_sdp() -> Result<()> { + //"Tracks unknown, audio and video with RTX" + { + let s = SessionDescription { + media_descriptions: vec![ + MediaDescription { + media_name: MediaName { + media: "foobar".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "mid".to_owned(), + value: Some("0".to_owned()), + }, + Attribute { + key: "sendrecv".to_owned(), + value: None, + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("1000 msid:unknown_trk_label unknown_trk_guid".to_owned()), + }, + ], + ..Default::default() + }, + MediaDescription { + media_name: MediaName { + media: "audio".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "mid".to_owned(), + value: Some("1".to_owned()), + }, + Attribute { + key: "sendrecv".to_owned(), + value: None, + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("2000 msid:audio_trk_label audio_trk_guid".to_owned()), + }, + ], + ..Default::default() + }, + MediaDescription { + media_name: MediaName { + media: "video".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "mid".to_owned(), + value: Some("2".to_owned()), + }, + Attribute { + key: "sendrecv".to_owned(), + value: None, + }, + Attribute { + key: "ssrc-group".to_owned(), + value: Some("FID 3000 4000".to_owned()), + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("3000 msid:video_trk_label video_trk_guid".to_owned()), + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("4000 msid:rtx_trk_label rtx_trck_guid".to_owned()), + }, + ], + ..Default::default() + }, + MediaDescription { + media_name: MediaName { + media: "video".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "mid".to_owned(), + value: Some("3".to_owned()), + }, + Attribute { + key: "sendonly".to_owned(), + value: None, + }, + Attribute { + key: "msid".to_owned(), + value: Some("video_stream_id video_trk_id".to_owned()), + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("5000".to_owned()), + }, + ], + ..Default::default() + }, + MediaDescription { + media_name: MediaName { + media: "video".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "sendonly".to_owned(), + value: None, + }, + Attribute { + key: SDP_ATTRIBUTE_RID.to_owned(), + value: Some("f send pt=97;max-width=1280;max-height=720".to_owned()), + }, + ], + ..Default::default() + }, + ], + ..Default::default() + }; + + let tracks = track_details_from_sdp(&s, true); + assert_eq!(tracks.len(), 3); + if track_details_for_ssrc(&tracks, 1000).is_some() { + panic!("got the unknown track ssrc:1000 which should have been skipped"); + } + if let Some(track) = track_details_for_ssrc(&tracks, 2000) { + assert_eq!(track.kind, RTPCodecType::Audio); + assert_eq!(track.ssrcs[0], 2000); + assert_eq!(track.stream_id, "audio_trk_label"); + } else { + panic!("missing audio track with ssrc:2000"); + } + if let Some(track) = track_details_for_ssrc(&tracks, 3000) { + assert_eq!(track.kind, RTPCodecType::Video); + assert_eq!(track.ssrcs[0], 3000); + assert_eq!(track.stream_id, "video_trk_label"); + } else { + panic!("missing video track with ssrc:3000"); + } + if track_details_for_ssrc(&tracks, 4000).is_some() { + panic!("got the rtx track ssrc:3000 which should have been skipped"); + } + if let Some(track) = track_details_for_ssrc(&tracks, 5000) { + assert_eq!(track.kind, RTPCodecType::Video); + assert_eq!(track.ssrcs[0], 5000); + assert_eq!(track.id, "video_trk_id"); + assert_eq!(track.stream_id, "video_stream_id"); + } else { + panic!("missing video track with ssrc:5000"); + } + } + + { + let s = SessionDescription { + media_descriptions: vec![ + MediaDescription { + media_name: MediaName { + media: "video".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "mid".to_owned(), + value: Some("1".to_owned()), + }, + Attribute { + key: "inactive".to_owned(), + value: None, + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("6000".to_owned()), + }, + ], + ..Default::default() + }, + MediaDescription { + media_name: MediaName { + media: "video".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "mid".to_owned(), + value: Some("1".to_owned()), + }, + Attribute { + key: "recvonly".to_owned(), + value: None, + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("7000".to_owned()), + }, + ], + ..Default::default() + }, + ], + ..Default::default() + }; + assert_eq!( + track_details_from_sdp(&s, true).len(), + 0, + "inactive and recvonly tracks should be ignored when passing exclude_inactive: true" + ); + assert_eq!( + track_details_from_sdp(&s, false).len(), + 1, + "Inactive tracks should not be ignored when passing exclude_inactive: false" + ); + } + + Ok(()) +} + +#[test] +fn test_have_application_media_section() -> Result<()> { + //"Audio only" + { + let s = SessionDescription { + media_descriptions: vec![MediaDescription { + media_name: MediaName { + media: "audio".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "sendrecv".to_owned(), + value: None, + }, + Attribute { + key: "ssrc".to_owned(), + value: Some("2000".to_owned()), + }, + ], + ..Default::default() + }], + ..Default::default() + }; + + assert!(!have_application_media_section(&s)); + } + + //"Application" + { + let s = SessionDescription { + media_descriptions: vec![MediaDescription { + media_name: MediaName { + media: MEDIA_SECTION_APPLICATION.to_owned(), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }; + + assert!(have_application_media_section(&s)); + } + + Ok(()) +} + +async fn fingerprint_test( + certificate: &RTCCertificate, + engine: &Arc, + media: &[MediaSection], + sdpmedia_description_fingerprints: bool, + expected_fingerprint_count: usize, +) -> Result<()> { + let s = SessionDescription::default(); + + let dtls_fingerprints = certificate.get_fingerprints(); + + let params = PopulateSdpParams { + media_description_fingerprint: sdpmedia_description_fingerprints, + is_icelite: false, + connection_role: ConnectionRole::Active, + ice_gathering_state: RTCIceGatheringState::New, + }; + + let s = populate_sdp( + s, + &dtls_fingerprints, + engine, + &[], + &RTCIceParameters::default(), + media, + params, + ) + .await?; + + let sdparray = s.marshal(); + + assert_eq!( + sdparray.matches("sha-256").count(), + expected_fingerprint_count + ); + + Ok(()) +} + +#[tokio::test] +async fn test_media_description_fingerprints() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + let interceptor = api.interceptor_registry.build("")?; + + let kp = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?; + let certificate = RTCCertificate::from_key_pair(kp)?; + + let transport = Arc::new(RTCDtlsTransport::default()); + + let video_receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Video, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + let audio_receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Audio, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + + let video_sender = Arc::new( + api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) + .await, + ); + + let audio_sender = Arc::new( + api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) + .await, + ); + + let media = vec![ + MediaSection { + id: "video".to_owned(), + transceivers: vec![ + RTCRtpTransceiver::new( + video_receiver, + video_sender, + RTCRtpTransceiverDirection::Inactive, + RTPCodecType::Video, + api.media_engine.get_codecs_by_kind(RTPCodecType::Video), + Arc::clone(&api.media_engine), + None, + ) + .await, + ], + ..Default::default() + }, + MediaSection { + id: "audio".to_owned(), + transceivers: vec![ + RTCRtpTransceiver::new( + audio_receiver, + audio_sender, + RTCRtpTransceiverDirection::Inactive, + RTPCodecType::Audio, + api.media_engine.get_codecs_by_kind(RTPCodecType::Audio), + Arc::clone(&api.media_engine), + None, + ) + .await, + ], + ..Default::default() + }, + MediaSection { + id: "application".to_owned(), + data: true, + ..Default::default() + }, + ]; + + #[allow(clippy::needless_range_loop)] + for i in 0..2 { + let track: Arc = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: "video/vp8".to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + media[i].transceivers[0] + .set_sender(Arc::new( + RTCRtpSender::new( + api.setting_engine.get_receive_mtu(), + Some(track), + Arc::new(RTCDtlsTransport::default()), + Arc::clone(&api.media_engine), + Arc::clone(&interceptor), + false, + ) + .await, + )) + .await; + media[i].transceivers[0].set_direction_internal(RTCRtpTransceiverDirection::Sendonly); + } + + //"Per-Media Description Fingerprints", + fingerprint_test(&certificate, &api.media_engine, &media, true, 3).await?; + + //"Per-Session Description Fingerprints", + fingerprint_test(&certificate, &api.media_engine, &media, false, 1).await?; + + Ok(()) +} + +#[tokio::test] +async fn test_populate_sdp() -> Result<()> { + //"Rid" + { + let se = SettingEngine::default(); + let mut me = MediaEngine::default(); + me.register_default_codecs()?; + + let api = APIBuilder::new().with_media_engine(me).build(); + let interceptor = api.interceptor_registry.build("")?; + let transport = Arc::new(RTCDtlsTransport::default()); + + let receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Video, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + + let sender = Arc::new( + api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) + .await, + ); + + let tr = RTCRtpTransceiver::new( + receiver, + sender, + RTCRtpTransceiverDirection::Recvonly, + RTPCodecType::Video, + api.media_engine.video_codecs.clone(), + Arc::clone(&api.media_engine), + None, + ) + .await; + + let rid_map = vec![ + SimulcastRid { + id: "ridkey".to_owned(), + direction: SimulcastDirection::Recv, + params: "some".to_owned(), + paused: false, + }, + SimulcastRid { + id: "ridpaused".to_owned(), + direction: SimulcastDirection::Recv, + params: "some2".to_owned(), + paused: true, + }, + ]; + let media_sections = vec![MediaSection { + id: "video".to_owned(), + transceivers: vec![tr], + data: false, + rid_map, + ..Default::default() + }]; + + let d = SessionDescription::default(); + + let params = PopulateSdpParams { + media_description_fingerprint: se.sdp_media_level_fingerprints, + is_icelite: se.candidates.ice_lite, + connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), + ice_gathering_state: RTCIceGatheringState::Complete, + }; + let offer_sdp = populate_sdp( + d, + &[], + &api.media_engine, + &[], + &RTCIceParameters::default(), + &media_sections, + params, + ) + .await?; + + // Test contains rid map keys + let mut found = 0; + for desc in &offer_sdp.media_descriptions { + if desc.media_name.media != "video" { + continue; + } + + let rid_map = get_rids(desc); + if let Some(rid) = rid_map.iter().find(|rid| rid.id == "ridkey") { + assert!(!rid.paused, "Rid should be active"); + assert_eq!( + rid.direction, + SimulcastDirection::Send, + "Rid should be send" + ); + found += 1; + } + if let Some(rid) = rid_map.iter().find(|rid| rid.id == "ridpaused") { + assert!(rid.paused, "Rid should be paused"); + assert_eq!( + rid.direction, + SimulcastDirection::Send, + "Rid should be send" + ); + found += 1; + } + } + assert_eq!(found, 2, "All Rid key should be present"); + } + + //"SetCodecPreferences" + { + let se = SettingEngine::default(); + let mut me = MediaEngine::default(); + me.register_default_codecs()?; + me.push_codecs(me.video_codecs.clone(), RTPCodecType::Video) + .await; + me.push_codecs(me.audio_codecs.clone(), RTPCodecType::Audio) + .await; + + let api = APIBuilder::new().with_media_engine(me).build(); + let interceptor = api.interceptor_registry.build("")?; + let transport = Arc::new(RTCDtlsTransport::default()); + let receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Video, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + + let sender = Arc::new( + api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) + .await, + ); + + let tr = RTCRtpTransceiver::new( + receiver, + sender, + RTCRtpTransceiverDirection::Recvonly, + RTPCodecType::Video, + api.media_engine.video_codecs.clone(), + Arc::clone(&api.media_engine), + None, + ) + .await; + tr.set_codec_preferences(vec![RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }]) + .await?; + + let media_sections = vec![MediaSection { + id: "video".to_owned(), + transceivers: vec![tr], + data: false, + rid_map: vec![], + ..Default::default() + }]; + + let d = SessionDescription::default(); + + let params = PopulateSdpParams { + media_description_fingerprint: se.sdp_media_level_fingerprints, + is_icelite: se.candidates.ice_lite, + connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), + ice_gathering_state: RTCIceGatheringState::Complete, + }; + let offer_sdp = populate_sdp( + d, + &[], + &api.media_engine, + &[], + &RTCIceParameters::default(), + &media_sections, + params, + ) + .await?; + + // Test codecs + let mut found_vp8 = false; + for desc in &offer_sdp.media_descriptions { + if desc.media_name.media != "video" { + continue; + } + for a in &desc.attributes { + if a.key.contains("rtpmap") { + if let Some(value) = &a.value { + if value == "98 VP9/90000" { + panic!("vp9 should not be present in sdp"); + } else if value == "96 VP8/90000" { + found_vp8 = true; + } + } + } + } + } + assert!(found_vp8, "vp8 should be present in sdp"); + } + + Ok(()) +} + +#[tokio::test] +async fn test_populate_sdp_reject() -> Result<()> { + let se = SettingEngine::default(); + let mut me = MediaEngine::default(); + me.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90_000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 2, + stats_id: "id".to_owned(), + }, + RTPCodecType::Video, + )?; + + let api = APIBuilder::new().with_media_engine(me).build(); + let interceptor = api.interceptor_registry.build("")?; + let transport = Arc::new(RTCDtlsTransport::default()); + let video_receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Video, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + + let video_sender = Arc::new( + api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) + .await, + ); + + let trv = RTCRtpTransceiver::new( + video_receiver, + video_sender, + RTCRtpTransceiverDirection::Recvonly, + RTPCodecType::Video, + api.media_engine.video_codecs.clone(), + Arc::clone(&api.media_engine), + None, + ) + .await; + + let audio_receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Audio, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + + let audio_sender = Arc::new( + api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) + .await, + ); + + let tra = RTCRtpTransceiver::new( + audio_receiver, + audio_sender, + RTCRtpTransceiverDirection::Recvonly, + RTPCodecType::Audio, + api.media_engine.audio_codecs.clone(), + Arc::clone(&api.media_engine), + None, + ) + .await; + + let media_sections = vec![ + MediaSection { + id: "video".to_owned(), + transceivers: vec![trv], + data: false, + rid_map: vec![], + ..Default::default() + }, + MediaSection { + id: "audio".to_owned(), + transceivers: vec![tra], + data: false, + rid_map: vec![], + ..Default::default() + }, + ]; + + let d = SessionDescription::default(); + + let params = PopulateSdpParams { + media_description_fingerprint: se.sdp_media_level_fingerprints, + is_icelite: se.candidates.ice_lite, + connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), + ice_gathering_state: RTCIceGatheringState::Complete, + }; + let offer_sdp = populate_sdp( + d, + &[], + &api.media_engine, + &[], + &RTCIceParameters::default(), + &media_sections, + params, + ) + .await?; + + let mut found_rejected_track = false; + + for desc in offer_sdp.media_descriptions { + if desc.media_name.media != "audio" { + continue; + } + found_rejected_track = true; + + assert!( + desc.connection_information.is_some(), + "connection_information should not be None, even for rejected tracks" + ); + assert_eq!( + desc.media_name.formats, + vec!["0"], + "Format for rejected track should be 0" + ); + assert_eq!( + desc.media_name.port.value, 0, + "Port for rejected track should be 0" + ); + } + + assert!( + found_rejected_track, + "There should've been a rejected track" + ); + + Ok(()) +} + +#[test] +fn test_get_rids() { + let m = vec![MediaDescription { + media_name: MediaName { + media: "video".to_owned(), + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "sendonly".to_owned(), + value: None, + }, + Attribute { + key: SDP_ATTRIBUTE_RID.to_owned(), + value: Some("f send pt=97;max-width=1280;max-height=720".to_owned()), + }, + ], + ..Default::default() + }]; + + let rids = get_rids(&m[0]); + + assert!(!rids.is_empty(), "Rid mapping should be present"); + + let f = rids.iter().find(|rid| rid.id == "f"); + assert!(f.is_some(), "rid values should contain 'f'"); +} + +#[test] +fn test_codecs_from_media_description() -> Result<()> { + //"Codec Only" + { + let codecs = codecs_from_media_description(&MediaDescription { + media_name: MediaName { + media: "audio".to_owned(), + formats: vec!["111".to_owned()], + ..Default::default() + }, + attributes: vec![Attribute { + key: "rtpmap".to_owned(), + value: Some("111 opus/48000/2".to_owned()), + }], + ..Default::default() + })?; + + assert_eq!( + codecs, + vec![RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + clock_rate: 48000, + channels: 2, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }], + ); + } + + //"Codec with fmtp/rtcp-fb" + { + let codecs = codecs_from_media_description(&MediaDescription { + media_name: MediaName { + media: "audio".to_owned(), + formats: vec!["111".to_owned()], + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "rtpmap".to_owned(), + value: Some("111 opus/48000/2".to_owned()), + }, + Attribute { + key: "fmtp".to_owned(), + value: Some("111 minptime=10;useinbandfec=1".to_owned()), + }, + Attribute { + key: "rtcp-fb".to_owned(), + value: Some("111 goog-remb".to_owned()), + }, + Attribute { + key: "rtcp-fb".to_owned(), + value: Some("111 ccm fir".to_owned()), + }, + ], + ..Default::default() + })?; + + assert_eq!( + codecs, + vec![RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + clock_rate: 48000, + channels: 2, + sdp_fmtp_line: "minptime=10;useinbandfec=1".to_owned(), + rtcp_feedback: vec![ + RTCPFeedback { + typ: "goog-remb".to_owned(), + parameter: "".to_owned() + }, + RTCPFeedback { + typ: "ccm".to_owned(), + parameter: "fir".to_owned() + } + ] + }, + payload_type: 111, + ..Default::default() + }], + ); + } + + Ok(()) +} + +#[test] +fn test_rtp_extensions_from_media_description() -> Result<()> { + let extensions = rtp_extensions_from_media_description(&MediaDescription { + media_name: MediaName { + media: "audio".to_owned(), + formats: vec!["111".to_owned()], + ..Default::default() + }, + attributes: vec![ + Attribute { + key: "extmap".to_owned(), + value: Some("1 ".to_owned() + sdp::extmap::ABS_SEND_TIME_URI), + }, + Attribute { + key: "extmap".to_owned(), + value: Some("3 ".to_owned() + sdp::extmap::SDES_MID_URI), + }, + ], + ..Default::default() + })?; + + assert_eq!(extensions[sdp::extmap::ABS_SEND_TIME_URI], 1); + assert_eq!(extensions[sdp::extmap::SDES_MID_URI], 3); + + Ok(()) +} diff --git a/rtc/src/peer_connection/sdp/sdp_type.rs b/rtc/src/peer_connection/sdp/sdp_type.rs new file mode 100644 index 0000000..830864f --- /dev/null +++ b/rtc/src/peer_connection/sdp/sdp_type.rs @@ -0,0 +1,101 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// SDPType describes the type of an SessionDescription. +#[derive(Default, Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub enum RTCSdpType { + #[default] + Unspecified = 0, + + /// indicates that a description MUST be treated as an SDP offer. + #[serde(rename = "offer")] + Offer, + + /// indicates that a description MUST be treated as an + /// SDP answer, but not a final answer. A description used as an SDP + /// pranswer may be applied as a response to an SDP offer, or an update to + /// a previously sent SDP pranswer. + #[serde(rename = "pranswer")] + Pranswer, + + /// indicates that a description MUST be treated as an SDP + /// final answer, and the offer-answer exchange MUST be considered complete. + /// A description used as an SDP answer may be applied as a response to an + /// SDP offer or as an update to a previously sent SDP pranswer. + #[serde(rename = "answer")] + Answer, + + /// indicates that a description MUST be treated as + /// canceling the current SDP negotiation and moving the SDP offer and + /// answer back to what it was in the previous stable state. Note the + /// local or remote SDP descriptions in the previous stable state could be + /// null if there has not yet been a successful offer-answer negotiation. + #[serde(rename = "rollback")] + Rollback, +} + +const SDP_TYPE_OFFER_STR: &str = "offer"; +const SDP_TYPE_PRANSWER_STR: &str = "pranswer"; +const SDP_TYPE_ANSWER_STR: &str = "answer"; +const SDP_TYPE_ROLLBACK_STR: &str = "rollback"; + +/// creates an SDPType from a string +impl From<&str> for RTCSdpType { + fn from(raw: &str) -> Self { + match raw { + SDP_TYPE_OFFER_STR => RTCSdpType::Offer, + SDP_TYPE_PRANSWER_STR => RTCSdpType::Pranswer, + SDP_TYPE_ANSWER_STR => RTCSdpType::Answer, + SDP_TYPE_ROLLBACK_STR => RTCSdpType::Rollback, + _ => RTCSdpType::Unspecified, + } + } +} + +impl fmt::Display for RTCSdpType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCSdpType::Offer => write!(f, "{SDP_TYPE_OFFER_STR}"), + RTCSdpType::Pranswer => write!(f, "{SDP_TYPE_PRANSWER_STR}"), + RTCSdpType::Answer => write!(f, "{SDP_TYPE_ANSWER_STR}"), + RTCSdpType::Rollback => write!(f, "{SDP_TYPE_ROLLBACK_STR}"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_sdp_type() { + let tests = vec![ + ("Unspecified", RTCSdpType::Unspecified), + ("offer", RTCSdpType::Offer), + ("pranswer", RTCSdpType::Pranswer), + ("answer", RTCSdpType::Answer), + ("rollback", RTCSdpType::Rollback), + ]; + + for (sdp_type_string, expected_sdp_type) in tests { + assert_eq!(RTCSdpType::from(sdp_type_string), expected_sdp_type); + } + } + + #[test] + fn test_sdp_type_string() { + let tests = vec![ + (RTCSdpType::Unspecified, "Unspecified"), + (RTCSdpType::Offer, "offer"), + (RTCSdpType::Pranswer, "pranswer"), + (RTCSdpType::Answer, "answer"), + (RTCSdpType::Rollback, "rollback"), + ]; + + for (sdp_type, expected_string) in tests { + assert_eq!(sdp_type.to_string(), expected_string); + } + } +} diff --git a/rtc/src/peer_connection/sdp/session_description.rs b/rtc/src/peer_connection/sdp/session_description.rs new file mode 100644 index 0000000..7085f34 --- /dev/null +++ b/rtc/src/peer_connection/sdp/session_description.rs @@ -0,0 +1,238 @@ +use std::io::Cursor; + +use sdp::description::session::SessionDescription; +use serde::{Deserialize, Serialize}; + +use super::sdp_type::RTCSdpType; +use crate::error::Result; + +/// SessionDescription is used to expose local and remote session descriptions. +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct RTCSessionDescription { + #[serde(rename = "type")] + pub sdp_type: RTCSdpType, + + pub sdp: String, + + /// This will never be initialized by callers, internal use only + #[serde(skip)] + pub(crate) parsed: Option, +} + +impl RTCSessionDescription { + /// Given SDP representing an answer, wrap it in an RTCSessionDescription + /// that can be given to an RTCPeerConnection. + pub fn answer(sdp: String) -> Result { + let mut desc = RTCSessionDescription { + sdp, + sdp_type: RTCSdpType::Answer, + parsed: None, + }; + + let parsed = desc.unmarshal()?; + desc.parsed = Some(parsed); + + Ok(desc) + } + + /// Given SDP representing an offer, wrap it in an RTCSessionDescription + /// that can be given to an RTCPeerConnection. + pub fn offer(sdp: String) -> Result { + let mut desc = RTCSessionDescription { + sdp, + sdp_type: RTCSdpType::Offer, + parsed: None, + }; + + let parsed = desc.unmarshal()?; + desc.parsed = Some(parsed); + + Ok(desc) + } + + /// Given SDP representing an answer, wrap it in an RTCSessionDescription + /// that can be given to an RTCPeerConnection. `pranswer` is used when the + /// answer may not be final, or when updating a previously sent pranswer. + pub fn pranswer(sdp: String) -> Result { + let mut desc = RTCSessionDescription { + sdp, + sdp_type: RTCSdpType::Pranswer, + parsed: None, + }; + + let parsed = desc.unmarshal()?; + desc.parsed = Some(parsed); + + Ok(desc) + } + + /// Unmarshal is a helper to deserialize the sdp + pub fn unmarshal(&self) -> Result { + let mut reader = Cursor::new(self.sdp.as_bytes()); + let parsed = SessionDescription::unmarshal(&mut reader)?; + Ok(parsed) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::api::media_engine::MediaEngine; + use crate::api::APIBuilder; + use crate::peer_connection::configuration::RTCConfiguration; + + #[test] + fn test_session_description_json() { + let tests = vec![ + ( + RTCSessionDescription { + sdp_type: RTCSdpType::Offer, + sdp: "sdp".to_owned(), + parsed: None, + }, + r#"{"type":"offer","sdp":"sdp"}"#, + ), + ( + RTCSessionDescription { + sdp_type: RTCSdpType::Pranswer, + sdp: "sdp".to_owned(), + parsed: None, + }, + r#"{"type":"pranswer","sdp":"sdp"}"#, + ), + ( + RTCSessionDescription { + sdp_type: RTCSdpType::Answer, + sdp: "sdp".to_owned(), + parsed: None, + }, + r#"{"type":"answer","sdp":"sdp"}"#, + ), + ( + RTCSessionDescription { + sdp_type: RTCSdpType::Rollback, + sdp: "sdp".to_owned(), + parsed: None, + }, + r#"{"type":"rollback","sdp":"sdp"}"#, + ), + ( + RTCSessionDescription { + sdp_type: RTCSdpType::Unspecified, + sdp: "sdp".to_owned(), + parsed: None, + }, + r#"{"type":"Unspecified","sdp":"sdp"}"#, + ), + ]; + + for (desc, expected_string) in tests { + let result = serde_json::to_string(&desc); + assert!(result.is_ok(), "testCase: marshal err: {result:?}"); + let desc_data = result.unwrap(); + assert_eq!(desc_data, expected_string, "string is not expected"); + + let result = serde_json::from_str::(&desc_data); + assert!(result.is_ok(), "testCase: unmarshal err: {result:?}"); + if let Ok(sd) = result { + assert!(sd.sdp == desc.sdp && sd.sdp_type == desc.sdp_type); + } + } + } + + #[tokio::test] + async fn test_session_description_answer() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let offer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; + let answer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; + + let _ = offer_pc.create_data_channel("foo", None).await?; + let offer = offer_pc.create_offer(None).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + + let desc = RTCSessionDescription::answer(answer.sdp.clone())?; + + assert!(desc.sdp_type == RTCSdpType::Answer); + assert!(desc.parsed.is_some()); + + assert_eq!(answer.unmarshal()?.marshal(), desc.unmarshal()?.marshal()); + + Ok(()) + } + + #[tokio::test] + async fn test_session_description_offer() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let pc = api.new_peer_connection(RTCConfiguration::default()).await?; + let offer = pc.create_offer(None).await?; + + let desc = RTCSessionDescription::offer(offer.sdp.clone())?; + + assert!(desc.sdp_type == RTCSdpType::Offer); + assert!(desc.parsed.is_some()); + + assert_eq!(offer.unmarshal()?.marshal(), desc.unmarshal()?.marshal()); + + Ok(()) + } + + #[tokio::test] + async fn test_session_description_pranswer() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let offer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; + let answer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; + + let _ = offer_pc.create_data_channel("foo", None).await?; + let offer = offer_pc.create_offer(None).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + + let desc = RTCSessionDescription::pranswer(answer.sdp)?; + + assert!(desc.sdp_type == RTCSdpType::Pranswer); + assert!(desc.parsed.is_some()); + + Ok(()) + } + + #[tokio::test] + async fn test_session_description_unmarshal() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let pc = api.new_peer_connection(RTCConfiguration::default()).await?; + + let offer = pc.create_offer(None).await?; + + let desc = RTCSessionDescription { + sdp_type: offer.sdp_type, + sdp: offer.sdp, + ..Default::default() + }; + + assert!(desc.parsed.is_none()); + + let parsed1 = desc.unmarshal()?; + let parsed2 = desc.unmarshal()?; + + pc.close().await?; + + // check if the two parsed results _really_ match, could be affected by internal caching + assert_eq!(parsed1.marshal(), parsed2.marshal()); + + Ok(()) + } +} diff --git a/rtc/src/peer_connection/signaling_state.rs b/rtc/src/peer_connection/signaling_state.rs new file mode 100644 index 0000000..2edd613 --- /dev/null +++ b/rtc/src/peer_connection/signaling_state.rs @@ -0,0 +1,365 @@ +use std::fmt; + +use crate::error::{Error, Result}; +use crate::peer_connection::sdp::sdp_type::RTCSdpType; + +#[derive(Default, Debug, Copy, Clone, PartialEq)] +pub(crate) enum StateChangeOp { + #[default] + SetLocal, + SetRemote, +} + +impl fmt::Display for StateChangeOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + StateChangeOp::SetLocal => write!(f, "SetLocal"), + StateChangeOp::SetRemote => write!(f, "SetRemote"), + //_ => write!(f, UNSPECIFIED_STR), + } + } +} + +/// SignalingState indicates the signaling state of the offer/answer process. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCSignalingState { + #[default] + Unspecified = 0, + + /// SignalingStateStable indicates there is no offer/answer exchange in + /// progress. This is also the initial state, in which case the local and + /// remote descriptions are nil. + Stable, + + /// SignalingStateHaveLocalOffer indicates that a local description, of + /// type "offer", has been successfully applied. + HaveLocalOffer, + + /// SignalingStateHaveRemoteOffer indicates that a remote description, of + /// type "offer", has been successfully applied. + HaveRemoteOffer, + + /// SignalingStateHaveLocalPranswer indicates that a remote description + /// of type "offer" has been successfully applied and a local description + /// of type "pranswer" has been successfully applied. + HaveLocalPranswer, + + /// SignalingStateHaveRemotePranswer indicates that a local description + /// of type "offer" has been successfully applied and a remote description + /// of type "pranswer" has been successfully applied. + HaveRemotePranswer, + + /// SignalingStateClosed indicates The PeerConnection has been closed. + Closed, +} + +const SIGNALING_STATE_STABLE_STR: &str = "stable"; +const SIGNALING_STATE_HAVE_LOCAL_OFFER_STR: &str = "have-local-offer"; +const SIGNALING_STATE_HAVE_REMOTE_OFFER_STR: &str = "have-remote-offer"; +const SIGNALING_STATE_HAVE_LOCAL_PRANSWER_STR: &str = "have-local-pranswer"; +const SIGNALING_STATE_HAVE_REMOTE_PRANSWER_STR: &str = "have-remote-pranswer"; +const SIGNALING_STATE_CLOSED_STR: &str = "closed"; + +impl From<&str> for RTCSignalingState { + fn from(raw: &str) -> Self { + match raw { + SIGNALING_STATE_STABLE_STR => RTCSignalingState::Stable, + SIGNALING_STATE_HAVE_LOCAL_OFFER_STR => RTCSignalingState::HaveLocalOffer, + SIGNALING_STATE_HAVE_REMOTE_OFFER_STR => RTCSignalingState::HaveRemoteOffer, + SIGNALING_STATE_HAVE_LOCAL_PRANSWER_STR => RTCSignalingState::HaveLocalPranswer, + SIGNALING_STATE_HAVE_REMOTE_PRANSWER_STR => RTCSignalingState::HaveRemotePranswer, + SIGNALING_STATE_CLOSED_STR => RTCSignalingState::Closed, + _ => RTCSignalingState::Unspecified, + } + } +} + +impl fmt::Display for RTCSignalingState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCSignalingState::Stable => write!(f, "{SIGNALING_STATE_STABLE_STR}"), + RTCSignalingState::HaveLocalOffer => { + write!(f, "{SIGNALING_STATE_HAVE_LOCAL_OFFER_STR}") + } + RTCSignalingState::HaveRemoteOffer => { + write!(f, "{SIGNALING_STATE_HAVE_REMOTE_OFFER_STR}") + } + RTCSignalingState::HaveLocalPranswer => { + write!(f, "{SIGNALING_STATE_HAVE_LOCAL_PRANSWER_STR}") + } + RTCSignalingState::HaveRemotePranswer => { + write!(f, "{SIGNALING_STATE_HAVE_REMOTE_PRANSWER_STR}") + } + RTCSignalingState::Closed => write!(f, "{SIGNALING_STATE_CLOSED_STR}"), + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +impl From for RTCSignalingState { + fn from(v: u8) -> Self { + match v { + 1 => RTCSignalingState::Stable, + 2 => RTCSignalingState::HaveLocalOffer, + 3 => RTCSignalingState::HaveRemoteOffer, + 4 => RTCSignalingState::HaveLocalPranswer, + 5 => RTCSignalingState::HaveRemotePranswer, + 6 => RTCSignalingState::Closed, + _ => RTCSignalingState::Unspecified, + } + } +} + +pub(crate) fn check_next_signaling_state( + cur: RTCSignalingState, + next: RTCSignalingState, + op: StateChangeOp, + sdp_type: RTCSdpType, +) -> Result { + // Special case for rollbacks + if sdp_type == RTCSdpType::Rollback && cur == RTCSignalingState::Stable { + return Err(Error::ErrSignalingStateCannotRollback); + } + + // 4.3.1 valid state transitions + match cur { + RTCSignalingState::Stable => { + match op { + StateChangeOp::SetLocal => { + // stable->SetLocal(offer)->have-local-offer + if sdp_type == RTCSdpType::Offer && next == RTCSignalingState::HaveLocalOffer { + return Ok(next); + } + } + StateChangeOp::SetRemote => { + // stable->SetRemote(offer)->have-remote-offer + if sdp_type == RTCSdpType::Offer && next == RTCSignalingState::HaveRemoteOffer { + return Ok(next); + } + } + } + } + RTCSignalingState::HaveLocalOffer => { + if op == StateChangeOp::SetRemote { + match sdp_type { + // have-local-offer->SetRemote(answer)->stable + RTCSdpType::Answer => { + if next == RTCSignalingState::Stable { + return Ok(next); + } + } + // have-local-offer->SetRemote(pranswer)->have-remote-pranswer + RTCSdpType::Pranswer => { + if next == RTCSignalingState::HaveRemotePranswer { + return Ok(next); + } + } + _ => {} + } + } else if op == StateChangeOp::SetLocal + && sdp_type == RTCSdpType::Offer + && next == RTCSignalingState::HaveLocalOffer + { + return Ok(next); + } + } + RTCSignalingState::HaveRemotePranswer => { + if op == StateChangeOp::SetRemote && sdp_type == RTCSdpType::Answer { + // have-remote-pranswer->SetRemote(answer)->stable + if next == RTCSignalingState::Stable { + return Ok(next); + } + } + } + RTCSignalingState::HaveRemoteOffer => { + if op == StateChangeOp::SetLocal { + match sdp_type { + // have-remote-offer->SetLocal(answer)->stable + RTCSdpType::Answer => { + if next == RTCSignalingState::Stable { + return Ok(next); + } + } + // have-remote-offer->SetLocal(pranswer)->have-local-pranswer + RTCSdpType::Pranswer => { + if next == RTCSignalingState::HaveLocalPranswer { + return Ok(next); + } + } + _ => {} + } + } + } + RTCSignalingState::HaveLocalPranswer => { + if op == StateChangeOp::SetLocal && sdp_type == RTCSdpType::Answer { + // have-local-pranswer->SetLocal(answer)->stable + if next == RTCSignalingState::Stable { + return Ok(next); + } + } + } + _ => { + return Err(Error::ErrSignalingStateProposedTransitionInvalid { + from: cur, + applying: sdp_type, + is_local: op == StateChangeOp::SetLocal, + }); + } + }; + + Err(Error::ErrSignalingStateProposedTransitionInvalid { + from: cur, + is_local: op == StateChangeOp::SetLocal, + applying: sdp_type, + }) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_signaling_state() { + let tests = vec![ + ("Unspecified", RTCSignalingState::Unspecified), + ("stable", RTCSignalingState::Stable), + ("have-local-offer", RTCSignalingState::HaveLocalOffer), + ("have-remote-offer", RTCSignalingState::HaveRemoteOffer), + ("have-local-pranswer", RTCSignalingState::HaveLocalPranswer), + ( + "have-remote-pranswer", + RTCSignalingState::HaveRemotePranswer, + ), + ("closed", RTCSignalingState::Closed), + ]; + + for (state_string, expected_state) in tests { + assert_eq!(RTCSignalingState::from(state_string), expected_state); + } + } + + #[test] + fn test_signaling_state_string() { + let tests = vec![ + (RTCSignalingState::Unspecified, "Unspecified"), + (RTCSignalingState::Stable, "stable"), + (RTCSignalingState::HaveLocalOffer, "have-local-offer"), + (RTCSignalingState::HaveRemoteOffer, "have-remote-offer"), + (RTCSignalingState::HaveLocalPranswer, "have-local-pranswer"), + ( + RTCSignalingState::HaveRemotePranswer, + "have-remote-pranswer", + ), + (RTCSignalingState::Closed, "closed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string); + } + } + + #[test] + fn test_signaling_state_transitions() { + let tests = vec![ + ( + "stable->SetLocal(offer)->have-local-offer", + RTCSignalingState::Stable, + RTCSignalingState::HaveLocalOffer, + StateChangeOp::SetLocal, + RTCSdpType::Offer, + None, + ), + ( + "stable->SetRemote(offer)->have-remote-offer", + RTCSignalingState::Stable, + RTCSignalingState::HaveRemoteOffer, + StateChangeOp::SetRemote, + RTCSdpType::Offer, + None, + ), + ( + "have-local-offer->SetRemote(answer)->stable", + RTCSignalingState::HaveLocalOffer, + RTCSignalingState::Stable, + StateChangeOp::SetRemote, + RTCSdpType::Answer, + None, + ), + ( + "have-local-offer->SetRemote(pranswer)->have-remote-pranswer", + RTCSignalingState::HaveLocalOffer, + RTCSignalingState::HaveRemotePranswer, + StateChangeOp::SetRemote, + RTCSdpType::Pranswer, + None, + ), + ( + "have-remote-pranswer->SetRemote(answer)->stable", + RTCSignalingState::HaveRemotePranswer, + RTCSignalingState::Stable, + StateChangeOp::SetRemote, + RTCSdpType::Answer, + None, + ), + ( + "have-remote-offer->SetLocal(answer)->stable", + RTCSignalingState::HaveRemoteOffer, + RTCSignalingState::Stable, + StateChangeOp::SetLocal, + RTCSdpType::Answer, + None, + ), + ( + "have-remote-offer->SetLocal(pranswer)->have-local-pranswer", + RTCSignalingState::HaveRemoteOffer, + RTCSignalingState::HaveLocalPranswer, + StateChangeOp::SetLocal, + RTCSdpType::Pranswer, + None, + ), + ( + "have-local-pranswer->SetLocal(answer)->stable", + RTCSignalingState::HaveLocalPranswer, + RTCSignalingState::Stable, + StateChangeOp::SetLocal, + RTCSdpType::Answer, + None, + ), + ( + "(invalid) stable->SetRemote(pranswer)->have-remote-pranswer", + RTCSignalingState::Stable, + RTCSignalingState::HaveRemotePranswer, + StateChangeOp::SetRemote, + RTCSdpType::Pranswer, + Some(Error::ErrSignalingStateProposedTransitionInvalid { + from: RTCSignalingState::Stable, + is_local: false, + applying: RTCSdpType::Pranswer, + }), + ), + ( + "(invalid) stable->SetRemote(rollback)->have-local-offer", + RTCSignalingState::Stable, + RTCSignalingState::HaveLocalOffer, + StateChangeOp::SetRemote, + RTCSdpType::Rollback, + Some(Error::ErrSignalingStateCannotRollback), + ), + ]; + + for (desc, cur, next, op, sdp_type, expected_err) in tests { + let result = check_next_signaling_state(cur, next, op, sdp_type); + match (&result, &expected_err) { + (Ok(got), None) => { + assert_eq!(*got, next, "{desc} state mismatch"); + } + (Err(got), Some(err)) => { + assert_eq!(got.to_string(), err.to_string(), "{desc} error mismatch"); + } + _ => { + panic!("{desc}: expected {expected_err:?}, but got {result:?}"); + } + }; + } + } +} diff --git a/rtc/src/rtp_transceiver/fmtp/generic/generic_test.rs b/rtc/src/rtp_transceiver/fmtp/generic/generic_test.rs new file mode 100644 index 0000000..f37c4af --- /dev/null +++ b/rtc/src/rtp_transceiver/fmtp/generic/generic_test.rs @@ -0,0 +1,160 @@ +use super::*; + +#[test] +fn test_generic_fmtp_parse() { + let tests: Vec<(&str, &str, Box)> = vec![ + ( + "OneParam", + "key-name=value", + Box::new(GenericFmtp { + mime_type: "generic".to_owned(), + parameters: [("key-name".to_owned(), "value".to_owned())] + .iter() + .cloned() + .collect(), + }), + ), + ( + "OneParamWithWhiteSpeces", + "\tkey-name=value ", + Box::new(GenericFmtp { + mime_type: "generic".to_owned(), + parameters: [("key-name".to_owned(), "value".to_owned())] + .iter() + .cloned() + .collect(), + }), + ), + ( + "TwoParams", + "key-name=value;key2=value2", + Box::new(GenericFmtp { + mime_type: "generic".to_owned(), + parameters: [ + ("key-name".to_owned(), "value".to_owned()), + ("key2".to_owned(), "value2".to_owned()), + ] + .iter() + .cloned() + .collect(), + }), + ), + ( + "TwoParamsWithWhiteSpeces", + "key-name=value; \n\tkey2=value2 ", + Box::new(GenericFmtp { + mime_type: "generic".to_owned(), + parameters: [ + ("key-name".to_owned(), "value".to_owned()), + ("key2".to_owned(), "value2".to_owned()), + ] + .iter() + .cloned() + .collect(), + }), + ), + ]; + + for (name, input, expected) in tests { + let f = parse("generic", input); + assert_eq!(&f, &expected, "{name} failed"); + + assert_eq!(f.mime_type(), "generic"); + } +} + +#[test] +fn test_generic_fmtp_compare() { + let consist_string: HashMap = [ + (true, "consist".to_owned()), + (false, "inconsist".to_owned()), + ] + .iter() + .cloned() + .collect(); + + let tests = vec![ + ( + "Equal", + "key1=value1;key2=value2;key3=value3", + "key1=value1;key2=value2;key3=value3", + true, + ), + ( + "EqualWithWhitespaceVariants", + "key1=value1;key2=value2;key3=value3", + " key1=value1; \nkey2=value2;\t\nkey3=value3", + true, + ), + ( + "EqualWithCase", + "key1=value1;key2=value2;key3=value3", + "key1=value1;key2=Value2;Key3=value3", + true, + ), + ( + "OneHasExtraParam", + "key1=value1;key2=value2;key3=value3", + "key1=value1;key2=value2;key3=value3;key4=value4", + true, + ), + ( + "Inconsistent", + "key1=value1;key2=value2;key3=value3", + "key1=value1;key2=different_value;key3=value3", + false, + ), + ( + "Inconsistent_OneHasExtraParam", + "key1=value1;key2=value2;key3=value3;key4=value4", + "key1=value1;key2=different_value;key3=value3", + false, + ), + ]; + + for (name, a, b, consist) in tests { + let check = |a, b| { + let aa = parse("", a); + let bb = parse("", b); + + // test forward case here + let c = aa.match_fmtp(&*bb); + assert_eq!( + c, + consist, + "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", + name, + a, + b, + consist_string.get(&consist), + consist_string.get(&c), + ); + + // test reverse case here + let c = bb.match_fmtp(&*aa); + assert_eq!( + c, + consist, + "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", + name, + a, + b, + consist_string.get(&consist), + consist_string.get(&c), + ); + }; + + check(a, b); + } +} + +#[test] +fn test_generic_fmtp_compare_mime_type_case_mismatch() { + let a = parse("video/vp8", ""); + let b = parse("video/VP8", ""); + + assert!( + b.match_fmtp(&*a), + "fmtp lines should match even if they use different casing" + ); +} diff --git a/rtc/src/rtp_transceiver/fmtp/generic/mod.rs b/rtc/src/rtp_transceiver/fmtp/generic/mod.rs new file mode 100644 index 0000000..4b75ce3 --- /dev/null +++ b/rtc/src/rtp_transceiver/fmtp/generic/mod.rs @@ -0,0 +1,65 @@ +#[cfg(test)] +mod generic_test; + +use super::*; + +/// fmtp_consist checks that two FMTP parameters are not inconsistent. +fn fmtp_consist(a: &HashMap, b: &HashMap) -> bool { + //TODO: add unicode case-folding equal support + for (k, v) in a { + if let Some(vb) = b.get(k) { + if vb.to_uppercase() != v.to_uppercase() { + return false; + } + } + } + for (k, v) in b { + if let Some(va) = a.get(k) { + if va.to_uppercase() != v.to_uppercase() { + return false; + } + } + } + true +} + +#[derive(Debug, PartialEq)] +pub(crate) struct GenericFmtp { + pub(crate) mime_type: String, + pub(crate) parameters: HashMap, +} + +impl Fmtp for GenericFmtp { + fn mime_type(&self) -> &str { + self.mime_type.as_str() + } + + /// Match returns true if g and b are compatible fmtp descriptions + /// The generic implementation is used for MimeTypes that are not defined + fn match_fmtp(&self, f: &(dyn Fmtp)) -> bool { + if let Some(c) = f.as_any().downcast_ref::() { + if self.mime_type.to_lowercase() != c.mime_type().to_lowercase() { + return false; + } + + fmtp_consist(&self.parameters, &c.parameters) + } else { + false + } + } + + fn parameter(&self, key: &str) -> Option<&String> { + self.parameters.get(key) + } + + fn equal(&self, other: &(dyn Fmtp)) -> bool { + other + .as_any() + .downcast_ref::() + .map_or(false, |a| self == a) + } + + fn as_any(&self) -> &(dyn Any) { + self + } +} diff --git a/rtc/src/rtp_transceiver/fmtp/h264/h264_test.rs b/rtc/src/rtp_transceiver/fmtp/h264/h264_test.rs new file mode 100644 index 0000000..fe97dbe --- /dev/null +++ b/rtc/src/rtp_transceiver/fmtp/h264/h264_test.rs @@ -0,0 +1,163 @@ +use super::*; + +#[test] +fn test_h264_fmtp_parse() { + let tests: Vec<(&str, &str, Box)> = vec![ + ( + "OneParam", + "key-name=value", + Box::new(H264Fmtp { + parameters: [("key-name".to_owned(), "value".to_owned())] + .iter() + .cloned() + .collect(), + }), + ), + ( + "OneParamWithWhiteSpeces", + "\tkey-name=value ", + Box::new(H264Fmtp { + parameters: [("key-name".to_owned(), "value".to_owned())] + .iter() + .cloned() + .collect(), + }), + ), + ( + "TwoParams", + "key-name=value;key2=value2", + Box::new(H264Fmtp { + parameters: [ + ("key-name".to_owned(), "value".to_owned()), + ("key2".to_owned(), "value2".to_owned()), + ] + .iter() + .cloned() + .collect(), + }), + ), + ( + "TwoParamsWithWhiteSpeces", + "key-name=value; \n\tkey2=value2 ", + Box::new(H264Fmtp { + parameters: [ + ("key-name".to_owned(), "value".to_owned()), + ("key2".to_owned(), "value2".to_owned()), + ] + .iter() + .cloned() + .collect(), + }), + ), + ]; + + for (name, input, expected) in tests { + let f = parse("video/h264", input); + assert_eq!(&f, &expected, "{name} failed"); + + assert_eq!(f.mime_type(), "video/h264"); + } +} + +#[test] +fn test_h264_fmtp_compare() { + let consist_string: HashMap = [ + (true, "consist".to_owned()), + (false, "inconsist".to_owned()), + ] + .iter() + .cloned() + .collect(); + + let tests = vec![ + ( + "Equal", + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + true, + ), + ( + "EqualWithWhitespaceVariants", + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + " level-asymmetry-allowed=1; \npacketization-mode=1;\t\nprofile-level-id=42e01f", + true, + ), + ( + "EqualWithCase", + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + "level-asymmetry-allowed=1;packetization-mode=1;PROFILE-LEVEL-ID=42e01f", + true, + ), + ( + "OneHasExtraParam", + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + "packetization-mode=1;profile-level-id=42e01f", + true, + ), + ( + "DifferentProfileLevelIDVersions", + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + "packetization-mode=1;profile-level-id=42e029", + true, + ), + ( + "Inconsistent", + "packetization-mode=1;profile-level-id=42e029", + "packetization-mode=0;profile-level-id=42e029", + false, + ), + ( + "Inconsistent_MissingPacketizationMode", + "packetization-mode=1;profile-level-id=42e029", + "profile-level-id=42e029", + false, + ), + ( + "Inconsistent_MissingProfileLevelID", + "packetization-mode=1;profile-level-id=42e029", + "packetization-mode=1", + false, + ), + ( + "Inconsistent_InvalidProfileLevelID", + "packetization-mode=1;profile-level-id=42e029", + "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=41e029", + false, + ), + ]; + + for (name, a, b, consist) in tests { + let check = |a, b| { + let aa = parse("video/h264", a); + let bb = parse("video/h264", b); + + // test forward case here + let c = aa.match_fmtp(&*bb); + assert_eq!( + c, + consist, + "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", + name, + a, + b, + consist_string.get(&consist), + consist_string.get(&c), + ); + + // test reverse case here + let c = bb.match_fmtp(&*aa); + assert_eq!( + c, + consist, + "{}: '{}' and '{}' are expected to be {:?}, but treated as {:?}", + name, + a, + b, + consist_string.get(&consist), + consist_string.get(&c), + ); + }; + + check(a, b); + } +} diff --git a/rtc/src/rtp_transceiver/fmtp/h264/mod.rs b/rtc/src/rtp_transceiver/fmtp/h264/mod.rs new file mode 100644 index 0000000..22a5bf5 --- /dev/null +++ b/rtc/src/rtp_transceiver/fmtp/h264/mod.rs @@ -0,0 +1,102 @@ +#[cfg(test)] +mod h264_test; + +use super::*; + +fn profile_level_id_matches(a: &str, b: &str) -> bool { + let aa = match hex::decode(a) { + Ok(aa) => { + if aa.len() < 2 { + return false; + } + aa + } + Err(_) => return false, + }; + + let bb = match hex::decode(b) { + Ok(bb) => { + if bb.len() < 2 { + return false; + } + bb + } + Err(_) => return false, + }; + + aa[0] == bb[0] && aa[1] == bb[1] +} + +#[derive(Debug, PartialEq)] +pub(crate) struct H264Fmtp { + pub(crate) parameters: HashMap, +} + +impl Fmtp for H264Fmtp { + fn mime_type(&self) -> &str { + "video/h264" + } + + /// Match returns true if h and b are compatible fmtp descriptions + /// Based on RFC6184 Section 8.2.2: + /// The parameters identifying a media format configuration for H.264 + /// are profile-level-id and packetization-mode. These media format + /// configuration parameters (except for the level part of profile- + /// level-id) MUST be used symmetrically; that is, the answerer MUST + /// either maintain all configuration parameters or remove the media + /// format (payload type) completely if one or more of the parameter + /// values are not supported. + /// Informative note: The requirement for symmetric use does not + /// apply for the level part of profile-level-id and does not apply + /// for the other stream properties and capability parameters. + fn match_fmtp(&self, f: &(dyn Fmtp)) -> bool { + if let Some(c) = f.as_any().downcast_ref::() { + // test packetization-mode + let hpmode = match self.parameters.get("packetization-mode") { + Some(s) => s, + None => return false, + }; + let cpmode = match c.parameters.get("packetization-mode") { + Some(s) => s, + None => return false, + }; + + if hpmode != cpmode { + return false; + } + + // test profile-level-id + let hplid = match self.parameters.get("profile-level-id") { + Some(s) => s, + None => return false, + }; + let cplid = match c.parameters.get("profile-level-id") { + Some(s) => s, + None => return false, + }; + + if !profile_level_id_matches(hplid, cplid) { + return false; + } + + true + } else { + false + } + } + + fn parameter(&self, key: &str) -> Option<&String> { + self.parameters.get(key) + } + + fn equal(&self, other: &(dyn Fmtp)) -> bool { + other + .as_any() + .downcast_ref::() + .map_or(false, |a| self == a) + } + + fn as_any(&self) -> &(dyn Any) { + self + } +} diff --git a/rtc/src/rtp_transceiver/fmtp/mod.rs b/rtc/src/rtp_transceiver/fmtp/mod.rs new file mode 100644 index 0000000..ea6e83b --- /dev/null +++ b/rtc/src/rtp_transceiver/fmtp/mod.rs @@ -0,0 +1,58 @@ +pub(crate) mod generic; +pub(crate) mod h264; + +use std::any::Any; +use std::collections::HashMap; +use std::fmt; + +use crate::rtp_transceiver::fmtp::generic::GenericFmtp; +use crate::rtp_transceiver::fmtp::h264::H264Fmtp; + +/// Fmtp interface for implementing custom +/// Fmtp parsers based on mime_type +pub trait Fmtp: fmt::Debug { + /// mime_type returns the mime_type associated with + /// the fmtp + fn mime_type(&self) -> &str; + + /// match_fmtp compares two fmtp descriptions for + /// compatibility based on the mime_type + fn match_fmtp(&self, f: &(dyn Fmtp)) -> bool; + + /// parameter returns a value for the associated key + /// if contained in the parsed fmtp string + fn parameter(&self, key: &str) -> Option<&String>; + + fn equal(&self, other: &(dyn Fmtp)) -> bool; + fn as_any(&self) -> &(dyn Any); +} + +impl PartialEq for dyn Fmtp { + fn eq(&self, other: &Self) -> bool { + self.equal(other) + } +} + +/// parse parses an fmtp string based on the MimeType +pub fn parse(mime_type: &str, line: &str) -> Box { + let mut parameters = HashMap::new(); + for p in line.split(';').collect::>() { + let pp: Vec<&str> = p.trim().splitn(2, '=').collect(); + let key = pp[0].to_lowercase(); + let value = if pp.len() > 1 { + pp[1].to_owned() + } else { + String::new() + }; + parameters.insert(key, value); + } + + if mime_type.to_uppercase() == "video/h264".to_uppercase() { + Box::new(H264Fmtp { parameters }) + } else { + Box::new(GenericFmtp { + mime_type: mime_type.to_owned(), + parameters, + }) + } +} diff --git a/rtc/src/rtp_transceiver/mod.rs b/rtc/src/rtp_transceiver/mod.rs new file mode 100644 index 0000000..1edb5fc --- /dev/null +++ b/rtc/src/rtp_transceiver/mod.rs @@ -0,0 +1,560 @@ +#[cfg(test)] +mod rtp_transceiver_test; + +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::Arc; + +use interceptor::stream_info::{RTPHeaderExtension, StreamInfo}; +use interceptor::Attributes; +use log::trace; +use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; +use tokio::sync::{Mutex, OnceCell}; +use util::Unmarshal; + +use crate::api::media_engine::MediaEngine; +use crate::error::{Error, Result}; +use crate::rtp_transceiver::rtp_codec::*; +use crate::rtp_transceiver::rtp_receiver::{RTCRtpReceiver, RTPReceiverInternal}; +use crate::rtp_transceiver::rtp_sender::RTCRtpSender; +use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; +use crate::track::track_local::TrackLocal; + +pub(crate) mod fmtp; +pub mod rtp_codec; +pub mod rtp_receiver; +pub mod rtp_sender; +pub mod rtp_transceiver_direction; +pub(crate) mod srtp_writer_future; + +/// SSRC represents a synchronization source +/// A synchronization source is a randomly chosen +/// value meant to be globally unique within a particular +/// RTP session. Used to identify a single stream of media. +/// +#[allow(clippy::upper_case_acronyms)] +pub type SSRC = u32; + +/// PayloadType identifies the format of the RTP payload and determines +/// its interpretation by the application. Each codec in a RTP Session +/// will have a different PayloadType +/// +pub type PayloadType = u8; + +/// TYPE_RTCP_FBT_RANSPORT_CC .. +pub const TYPE_RTCP_FB_TRANSPORT_CC: &str = "transport-cc"; + +/// TYPE_RTCP_FB_GOOG_REMB .. +pub const TYPE_RTCP_FB_GOOG_REMB: &str = "goog-remb"; + +/// TYPE_RTCP_FB_ACK .. +pub const TYPE_RTCP_FB_ACK: &str = "ack"; + +/// TYPE_RTCP_FB_CCM .. +pub const TYPE_RTCP_FB_CCM: &str = "ccm"; + +/// TYPE_RTCP_FB_NACK .. +pub const TYPE_RTCP_FB_NACK: &str = "nack"; + +/// rtcpfeedback signals the connection to use additional RTCP packet types. +/// +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct RTCPFeedback { + /// Type is the type of feedback. + /// see: + /// valid: ack, ccm, nack, goog-remb, transport-cc + pub typ: String, + + /// The parameter value depends on the type. + /// For example, type="nack" parameter="pli" will send Picture Loss Indicator packets. + pub parameter: String, +} + +/// RTPCapabilities represents the capabilities of a transceiver +/// +#[derive(Default, Debug, Clone)] +pub struct RTCRtpCapabilities { + pub codecs: Vec, + pub header_extensions: Vec, +} + +/// RTPRtxParameters dictionary contains information relating to retransmission (RTX) settings. +/// +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct RTCRtpRtxParameters { + pub ssrc: SSRC, +} + +/// RTPCodingParameters provides information relating to both encoding and decoding. +/// This is a subset of the RFC since Pion WebRTC doesn't implement encoding/decoding itself +/// +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct RTCRtpCodingParameters { + pub rid: SmolStr, + pub ssrc: SSRC, + pub payload_type: PayloadType, + pub rtx: RTCRtpRtxParameters, +} + +/// RTPDecodingParameters provides information relating to both encoding and decoding. +/// This is a subset of the RFC since Pion WebRTC doesn't implement decoding itself +/// +pub type RTCRtpDecodingParameters = RTCRtpCodingParameters; + +/// RTPEncodingParameters provides information relating to both encoding and decoding. +/// This is a subset of the RFC since Pion WebRTC doesn't implement encoding itself +/// +pub type RTCRtpEncodingParameters = RTCRtpCodingParameters; + +/// RTPReceiveParameters contains the RTP stack settings used by receivers +#[derive(Debug)] +pub struct RTCRtpReceiveParameters { + pub encodings: Vec, +} + +/// RTPSendParameters contains the RTP stack settings used by receivers +#[derive(Debug)] +pub struct RTCRtpSendParameters { + pub rtp_parameters: RTCRtpParameters, + pub encodings: Vec, +} + +/// RTPTransceiverInit dictionary is used when calling the WebRTC function addTransceiver() to provide configuration options for the new transceiver. +pub struct RTCRtpTransceiverInit { + pub direction: RTCRtpTransceiverDirection, + pub send_encodings: Vec, + // Streams []*Track +} + +pub(crate) fn create_stream_info( + id: String, + ssrc: SSRC, + payload_type: PayloadType, + codec: RTCRtpCodecCapability, + webrtc_header_extensions: &[RTCRtpHeaderExtensionParameters], +) -> StreamInfo { + let header_extensions: Vec = webrtc_header_extensions + .iter() + .map(|h| RTPHeaderExtension { + id: h.id, + uri: h.uri.clone(), + }) + .collect(); + + let feedbacks: Vec<_> = codec + .rtcp_feedback + .iter() + .map(|f| interceptor::stream_info::RTCPFeedback { + typ: f.typ.clone(), + parameter: f.parameter.clone(), + }) + .collect(); + + StreamInfo { + id, + attributes: Attributes::new(), + ssrc, + payload_type, + rtp_header_extensions: header_extensions, + mime_type: codec.mime_type, + clock_rate: codec.clock_rate, + channels: codec.channels, + sdp_fmtp_line: codec.sdp_fmtp_line, + rtcp_feedback: feedbacks, + } +} + +pub type TriggerNegotiationNeededFnOption = + Option Pin + Send + Sync>> + Send + Sync>>; + +/// RTPTransceiver represents a combination of an RTPSender and an RTPReceiver that share a common mid. +pub struct RTCRtpTransceiver { + mid: OnceCell, //atomic.Value + sender: Mutex>, //atomic.Value + receiver: Mutex>, //atomic.Value + + direction: AtomicU8, //RTPTransceiverDirection + current_direction: AtomicU8, //RTPTransceiverDirection + + codecs: Arc>>, // User provided codecs via set_codec_preferences + + pub(crate) stopped: AtomicBool, + pub(crate) kind: RTPCodecType, + + media_engine: Arc, + + trigger_negotiation_needed: Mutex, +} + +impl RTCRtpTransceiver { + pub async fn new( + receiver: Arc, + sender: Arc, + direction: RTCRtpTransceiverDirection, + kind: RTPCodecType, + codecs: Vec, + media_engine: Arc, + trigger_negotiation_needed: TriggerNegotiationNeededFnOption, + ) -> Arc { + let codecs = Arc::new(Mutex::new(codecs)); + receiver.set_transceiver_codecs(Some(Arc::clone(&codecs))); + + let t = Arc::new(RTCRtpTransceiver { + mid: OnceCell::new(), + sender: Mutex::new(sender), + receiver: Mutex::new(receiver), + + direction: AtomicU8::new(direction as u8), + current_direction: AtomicU8::new(RTCRtpTransceiverDirection::Unspecified as u8), + + codecs, + stopped: AtomicBool::new(false), + kind, + media_engine, + trigger_negotiation_needed: Mutex::new(trigger_negotiation_needed), + }); + t.sender() + .await + .set_rtp_transceiver(Some(Arc::downgrade(&t))); + + t + } + + /// set_codec_preferences sets preferred list of supported codecs + /// if codecs is empty or nil we reset to default from MediaEngine + pub async fn set_codec_preferences(&self, codecs: Vec) -> Result<()> { + for codec in &codecs { + let media_engine_codecs = self.media_engine.get_codecs_by_kind(self.kind); + let (_, match_type) = codec_parameters_fuzzy_search(codec, &media_engine_codecs); + if match_type == CodecMatch::None { + return Err(Error::ErrRTPTransceiverCodecUnsupported); + } + } + + { + let mut c = self.codecs.lock().await; + *c = codecs; + } + Ok(()) + } + + /// Codecs returns list of supported codecs + pub(crate) async fn get_codecs(&self) -> Vec { + let mut codecs = self.codecs.lock().await; + RTPReceiverInternal::get_codecs(&mut codecs, self.kind, &self.media_engine) + } + + /// sender returns the RTPTransceiver's RTPSender if it has one + pub async fn sender(&self) -> Arc { + let sender = self.sender.lock().await; + sender.clone() + } + + /// set_sender_track sets the RTPSender and Track to current transceiver + pub async fn set_sender_track( + self: &Arc, + sender: Arc, + track: Option>, + ) -> Result<()> { + self.set_sender(sender).await; + self.set_sending_track(track).await + } + + pub async fn set_sender(self: &Arc, s: Arc) { + s.set_rtp_transceiver(Some(Arc::downgrade(self))); + + let prev_sender = self.sender().await; + prev_sender.set_rtp_transceiver(None); + + { + let mut sender = self.sender.lock().await; + *sender = s; + } + } + + /// receiver returns the RTPTransceiver's RTPReceiver if it has one + pub async fn receiver(&self) -> Arc { + let receiver = self.receiver.lock().await; + receiver.clone() + } + + pub(crate) async fn set_receiver(&self, r: Arc) { + r.set_transceiver_codecs(Some(Arc::clone(&self.codecs))); + + { + let mut receiver = self.receiver.lock().await; + (*receiver).set_transceiver_codecs(None); + + *receiver = r; + } + } + + /// set_mid sets the RTPTransceiver's mid. If it was already set, will return an error. + pub(crate) fn set_mid(&self, mid: SmolStr) -> Result<()> { + self.mid + .set(mid) + .map_err(|_| Error::ErrRTPTransceiverCannotChangeMid) + } + + /// mid gets the Transceiver's mid value. When not already set, this value will be set in CreateOffer or create_answer. + pub fn mid(&self) -> Option { + self.mid.get().map(Clone::clone) + } + + /// kind returns RTPTransceiver's kind. + pub fn kind(&self) -> RTPCodecType { + self.kind + } + + /// direction returns the RTPTransceiver's desired direction. + pub fn direction(&self) -> RTCRtpTransceiverDirection { + self.direction.load(Ordering::SeqCst).into() + } + + /// Set the direction of this transceiver. This might trigger a renegotiation. + pub async fn set_direction(&self, d: RTCRtpTransceiverDirection) { + let changed = self.set_direction_internal(d); + + if changed { + let lock = self.trigger_negotiation_needed.lock().await; + if let Some(trigger) = &*lock { + (trigger)().await; + } + } + } + + pub(crate) fn set_direction_internal(&self, d: RTCRtpTransceiverDirection) -> bool { + let previous: RTCRtpTransceiverDirection = + self.direction.swap(d as u8, Ordering::SeqCst).into(); + + let changed = d != previous; + + if changed { + trace!( + "Changing direction of transceiver from {} to {}", + previous, + d + ); + } + + changed + } + + /// current_direction returns the RTPTransceiver's current direction as negotiated. + /// + /// If this transceiver has never been negotiated or if it's stopped this returns [`RTCRtpTransceiverDirection::Unspecified`]. + pub fn current_direction(&self) -> RTCRtpTransceiverDirection { + if self.stopped.load(Ordering::SeqCst) { + return RTCRtpTransceiverDirection::Unspecified; + } + + self.current_direction.load(Ordering::SeqCst).into() + } + + pub(crate) fn set_current_direction(&self, d: RTCRtpTransceiverDirection) { + let previous: RTCRtpTransceiverDirection = self + .current_direction + .swap(d as u8, Ordering::SeqCst) + .into(); + + if d != previous { + trace!( + "Changing current direction of transceiver from {} to {}", + previous, + d, + ); + } + } + + /// Perform any subsequent actions after altering the transceiver's direction. + /// + /// After changing the transceiver's direction this method should be called to perform any + /// side-effects that results from the new direction, such as pausing/resuming the RTP receiver. + pub(crate) async fn process_new_current_direction( + &self, + previous_direction: RTCRtpTransceiverDirection, + ) -> Result<()> { + if self.stopped.load(Ordering::SeqCst) { + return Ok(()); + } + + let current_direction = self.current_direction(); + if previous_direction != current_direction { + let mid = self.mid(); + trace!( + "Processing transceiver({:?}) direction change from {} to {}", + mid, + previous_direction, + current_direction + ); + } else { + // no change. + return Ok(()); + } + + { + let receiver = self.receiver.lock().await; + let pause_receiver = !current_direction.has_recv(); + + if pause_receiver { + receiver.pause().await?; + } else { + receiver.resume().await?; + } + } + + let pause_sender = !current_direction.has_send(); + { + let sender = &*self.sender.lock().await; + sender.set_paused(pause_sender); + } + + Ok(()) + } + + /// stop irreversibly stops the RTPTransceiver + pub async fn stop(&self) -> Result<()> { + if self.stopped.load(Ordering::SeqCst) { + return Ok(()); + } + + self.stopped.store(true, Ordering::SeqCst); + + { + let sender = self.sender.lock().await; + sender.stop().await?; + } + { + let r = self.receiver.lock().await; + r.stop().await?; + } + + self.set_direction_internal(RTCRtpTransceiverDirection::Inactive); + + Ok(()) + } + + pub(crate) async fn set_sending_track( + &self, + track: Option>, + ) -> Result<()> { + let track_is_none = track.is_none(); + { + let sender = self.sender.lock().await; + sender.replace_track(track).await?; + } + + let direction = self.direction(); + let should_send = !track_is_none; + let should_recv = direction.has_recv(); + self.set_direction_internal(RTCRtpTransceiverDirection::from_send_recv( + should_send, + should_recv, + )); + + Ok(()) + } +} + +impl fmt::Debug for RTCRtpTransceiver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RTCRtpTransceiver") + .field("mid", &self.mid) + .field("sender", &self.sender) + .field("receiver", &self.receiver) + .field("direction", &self.direction) + .field("current_direction", &self.current_direction) + .field("codecs", &self.codecs) + .field("stopped", &self.stopped) + .field("kind", &self.kind) + .finish() + } +} + +pub(crate) async fn find_by_mid( + mid: &str, + local_transceivers: &mut Vec>, +) -> Option> { + for (i, t) in local_transceivers.iter().enumerate() { + if t.mid() == Some(SmolStr::from(mid)) { + return Some(local_transceivers.remove(i)); + } + } + + None +} + +/// Given a direction+type pluck a transceiver from the passed list +/// if no entry satisfies the requested type+direction return a inactive Transceiver +pub(crate) async fn satisfy_type_and_direction( + remote_kind: RTPCodecType, + remote_direction: RTCRtpTransceiverDirection, + local_transceivers: &mut Vec>, +) -> Option> { + // Get direction order from most preferred to least + let get_preferred_directions = || -> Vec { + match remote_direction { + RTCRtpTransceiverDirection::Sendrecv => vec![ + RTCRtpTransceiverDirection::Recvonly, + RTCRtpTransceiverDirection::Sendrecv, + ], + RTCRtpTransceiverDirection::Sendonly => vec![RTCRtpTransceiverDirection::Recvonly], + RTCRtpTransceiverDirection::Recvonly => vec![ + RTCRtpTransceiverDirection::Sendonly, + RTCRtpTransceiverDirection::Sendrecv, + ], + _ => vec![], + } + }; + + for possible_direction in get_preferred_directions() { + for (i, t) in local_transceivers.iter().enumerate() { + if t.mid().is_none() && t.kind == remote_kind && possible_direction == t.direction() { + return Some(local_transceivers.remove(i)); + } + } + } + + None +} + +/// handle_unknown_rtp_packet consumes a single RTP Packet and returns information that is helpful +/// for demuxing and handling an unknown SSRC (usually for Simulcast) +pub(crate) fn handle_unknown_rtp_packet( + buf: &[u8], + mid_extension_id: u8, + sid_extension_id: u8, + rsid_extension_id: u8, +) -> Result<(String, String, String, PayloadType)> { + let mut reader = buf; + let rp = rtp::packet::Packet::unmarshal(&mut reader)?; + + if !rp.header.extension { + return Ok((String::new(), String::new(), String::new(), 0)); + } + + let payload_type = rp.header.payload_type; + + let mid = if let Some(payload) = rp.header.get_extension(mid_extension_id) { + String::from_utf8(payload.to_vec())? + } else { + String::new() + }; + + let rid = if let Some(payload) = rp.header.get_extension(sid_extension_id) { + String::from_utf8(payload.to_vec())? + } else { + String::new() + }; + + let srid = if let Some(payload) = rp.header.get_extension(rsid_extension_id) { + String::from_utf8(payload.to_vec())? + } else { + String::new() + }; + + Ok((mid, rid, srid, payload_type)) +} diff --git a/rtc/src/rtp_transceiver/rtp_codec.rs b/rtc/src/rtp_transceiver/rtp_codec.rs new file mode 100644 index 0000000..afaa8ef --- /dev/null +++ b/rtc/src/rtp_transceiver/rtp_codec.rs @@ -0,0 +1,163 @@ +use std::fmt; + +use super::*; +use crate::api::media_engine::*; +use crate::error::{Error, Result}; +use crate::rtp_transceiver::fmtp; + +/// RTPCodecType determines the type of a codec +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTPCodecType { + #[default] + Unspecified = 0, + + /// RTPCodecTypeAudio indicates this is an audio codec + Audio = 1, + + /// RTPCodecTypeVideo indicates this is a video codec + Video = 2, +} + +impl From<&str> for RTPCodecType { + fn from(raw: &str) -> Self { + match raw { + "audio" => RTPCodecType::Audio, + "video" => RTPCodecType::Video, + _ => RTPCodecType::Unspecified, + } + } +} + +impl From for RTPCodecType { + fn from(v: u8) -> Self { + match v { + 1 => RTPCodecType::Audio, + 2 => RTPCodecType::Video, + _ => RTPCodecType::Unspecified, + } + } +} + +impl fmt::Display for RTPCodecType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTPCodecType::Audio => "audio", + RTPCodecType::Video => "video", + RTPCodecType::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +/// RTPCodecCapability provides information about codec capabilities. +/// +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct RTCRtpCodecCapability { + pub mime_type: String, + pub clock_rate: u32, + pub channels: u16, + pub sdp_fmtp_line: String, + pub rtcp_feedback: Vec, +} + +impl RTCRtpCodecCapability { + /// Turn codec capability into a `packetizer::Payloader` + pub fn payloader_for_codec(&self) -> Result> { + let mime_type = self.mime_type.to_lowercase(); + if mime_type == MIME_TYPE_H264.to_lowercase() { + Ok(Box::::default()) + } else if mime_type == MIME_TYPE_VP8.to_lowercase() { + let mut vp8_payloader = rtp::codecs::vp8::Vp8Payloader::default(); + vp8_payloader.enable_picture_id = true; + Ok(Box::new(vp8_payloader)) + } else if mime_type == MIME_TYPE_VP9.to_lowercase() { + Ok(Box::::default()) + } else if mime_type == MIME_TYPE_OPUS.to_lowercase() { + Ok(Box::::default()) + } else if mime_type == MIME_TYPE_G722.to_lowercase() + || mime_type == MIME_TYPE_PCMU.to_lowercase() + || mime_type == MIME_TYPE_PCMA.to_lowercase() + || mime_type == MIME_TYPE_TELEPHONE_EVENT.to_lowercase() + { + Ok(Box::::default()) + } else if mime_type == MIME_TYPE_AV1.to_lowercase() { + Ok(Box::::default()) + } else { + Err(Error::ErrNoPayloaderForCodec) + } + } +} + +/// RTPHeaderExtensionCapability is used to define a RFC5285 RTP header extension supported by the codec. +/// +#[derive(Default, Debug, Clone)] +pub struct RTCRtpHeaderExtensionCapability { + pub uri: String, +} + +/// RTPHeaderExtensionParameter represents a negotiated RFC5285 RTP header extension. +/// +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct RTCRtpHeaderExtensionParameters { + pub uri: String, + pub id: isize, +} + +/// RTPCodecParameters is a sequence containing the media codecs that an RtpSender +/// will choose from, as well as entries for RTX, RED and FEC mechanisms. This also +/// includes the PayloadType that has been negotiated +/// +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct RTCRtpCodecParameters { + pub capability: RTCRtpCodecCapability, + pub payload_type: PayloadType, + pub stats_id: String, +} + +/// RTPParameters is a list of negotiated codecs and header extensions +/// +#[derive(Default, Debug, Clone)] +pub struct RTCRtpParameters { + pub header_extensions: Vec, + pub codecs: Vec, +} + +#[derive(Default, Debug, Copy, Clone, PartialEq)] +pub(crate) enum CodecMatch { + #[default] + None = 0, + Partial = 1, + Exact = 2, +} + +/// Do a fuzzy find for a codec in the list of codecs +/// Used for lookup up a codec in an existing list to find a match +/// Returns codecMatchExact, codecMatchPartial, or codecMatchNone +pub(crate) fn codec_parameters_fuzzy_search( + needle: &RTCRtpCodecParameters, + haystack: &[RTCRtpCodecParameters], +) -> (RTCRtpCodecParameters, CodecMatch) { + let needle_fmtp = fmtp::parse( + &needle.capability.mime_type, + &needle.capability.sdp_fmtp_line, + ); + + //TODO: add unicode case-folding equal support + + // First attempt to match on mime_type + sdpfmtp_line + for c in haystack { + let cfmpt = fmtp::parse(&c.capability.mime_type, &c.capability.sdp_fmtp_line); + if needle_fmtp.match_fmtp(&*cfmpt) { + return (c.clone(), CodecMatch::Exact); + } + } + + // Fallback to just mime_type + for c in haystack { + if c.capability.mime_type.to_uppercase() == needle.capability.mime_type.to_uppercase() { + return (c.clone(), CodecMatch::Partial); + } + } + + (RTCRtpCodecParameters::default(), CodecMatch::None) +} diff --git a/rtc/src/rtp_transceiver/rtp_receiver/mod.rs b/rtc/src/rtp_transceiver/rtp_receiver/mod.rs new file mode 100644 index 0000000..6d19839 --- /dev/null +++ b/rtc/src/rtp_transceiver/rtp_receiver/mod.rs @@ -0,0 +1,861 @@ +#[cfg(test)] +mod rtp_receiver_test; + +use std::fmt; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use interceptor::stream_info::RTPHeaderExtension; +use interceptor::{Attributes, Interceptor}; +use log::trace; +use smol_str::SmolStr; +use tokio::sync::{watch, Mutex, RwLock}; + +use crate::api::media_engine::MediaEngine; +use crate::dtls_transport::RTCDtlsTransport; +use crate::error::{flatten_errs, Error, Result}; +use crate::peer_connection::sdp::TrackDetails; +use crate::rtp_transceiver::rtp_codec::{ + codec_parameters_fuzzy_search, CodecMatch, RTCRtpCodecCapability, RTCRtpCodecParameters, + RTCRtpParameters, RTPCodecType, +}; +use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; +use crate::rtp_transceiver::{ + create_stream_info, RTCRtpDecodingParameters, RTCRtpReceiveParameters, SSRC, +}; +use crate::track::track_remote::TrackRemote; +use crate::track::{TrackStream, TrackStreams}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum State { + /// We haven't started yet. + Unstarted = 0, + /// We haven't started yet and additionally we've been paused. + UnstartedPaused = 1, + + /// We have started and are running. + Started = 2, + + /// We have been paused after starting. + Paused = 3, + + /// We have been stopped. + Stopped = 4, +} + +impl From for State { + fn from(value: u8) -> Self { + match value { + v if v == State::Unstarted as u8 => State::Unstarted, + v if v == State::UnstartedPaused as u8 => State::UnstartedPaused, + v if v == State::Started as u8 => State::Started, + v if v == State::Paused as u8 => State::Paused, + v if v == State::Stopped as u8 => State::Stopped, + _ => unreachable!( + "Invalid serialization of {}: {}", + std::any::type_name::(), + value + ), + } + } +} + +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + State::Unstarted => write!(f, "Unstarted"), + State::UnstartedPaused => write!(f, "UnstartedPaused"), + State::Started => write!(f, "Running"), + State::Paused => write!(f, "Paused"), + State::Stopped => write!(f, "Closed"), + } + } +} + +impl State { + fn transition(to: Self, tx: &watch::Sender) -> Result<()> { + let current = *tx.borrow(); + if current == to { + // Already in this state + return Ok(()); + } + + match current { + Self::Unstarted + if matches!(to, Self::Started | Self::Stopped | Self::UnstartedPaused) => + { + let _ = tx.send(to); + return Ok(()); + } + Self::UnstartedPaused + if matches!(to, Self::Unstarted | Self::Stopped | Self::Paused) => + { + let _ = tx.send(to); + return Ok(()); + } + State::Started if matches!(to, Self::Paused | Self::Stopped) => { + let _ = tx.send(to); + return Ok(()); + } + State::Paused if matches!(to, Self::Started | Self::Stopped) => { + let _ = tx.send(to); + return Ok(()); + } + _ => {} + } + + Err(Error::ErrRTPReceiverStateChangeInvalid { from: current, to }) + } + + async fn wait_for(rx: &mut watch::Receiver, states: &[State]) -> Result<()> { + loop { + let state = *rx.borrow(); + + match state { + _ if states.contains(&state) => return Ok(()), + State::Stopped => { + return Err(Error::ErrClosedPipe); + } + _ => {} + } + + if rx.changed().await.is_err() { + return Err(Error::ErrClosedPipe); + } + } + } + + async fn error_on_close(rx: &mut watch::Receiver) -> Result<()> { + if rx.changed().await.is_err() { + return Err(Error::ErrClosedPipe); + } + + let state = *rx.borrow(); + if state == State::Stopped { + return Err(Error::ErrClosedPipe); + } + + Ok(()) + } + + fn is_started(&self) -> bool { + matches!(self, Self::Started | Self::Paused) + } +} + +pub struct RTPReceiverInternal { + pub(crate) kind: RTPCodecType, + + // State is stored within the channel + state_tx: watch::Sender, + state_rx: watch::Receiver, + + tracks: RwLock>, + + transceiver_codecs: ArcSwapOption>>, + + transport: Arc, + media_engine: Arc, + interceptor: Arc, +} + +impl RTPReceiverInternal { + /// read reads incoming RTCP for this RTPReceiver + async fn read( + &self, + b: &mut [u8], + ) -> Result<(Vec>, Attributes)> { + let mut state_watch_rx = self.state_tx.subscribe(); + // Ensure we are running or paused. When paused we still receive RTCP even if RTP traffic + // isn't flowing. + State::wait_for(&mut state_watch_rx, &[State::Started, State::Paused]).await?; + + let tracks = self.tracks.read().await; + if let Some(t) = tracks.first() { + if let Some(rtcp_interceptor) = &t.stream.rtcp_interceptor { + let a = Attributes::new(); + loop { + tokio::select! { + res = State::error_on_close(&mut state_watch_rx) => { + res? + } + result = rtcp_interceptor.read(b, &a) => { + return Ok(result?) + } + } + } + } else { + Err(Error::ErrInterceptorNotBind) + } + } else { + Err(Error::ErrExistingTrack) + } + } + + /// read_simulcast reads incoming RTCP for this RTPReceiver for given rid + async fn read_simulcast( + &self, + b: &mut [u8], + rid: &str, + ) -> Result<(Vec>, Attributes)> { + let mut state_watch_rx = self.state_tx.subscribe(); + + // Ensure we are running or paused. When paused we still receive RTCP even if RTP traffic + // isn't flowing. + State::wait_for(&mut state_watch_rx, &[State::Started, State::Paused]).await?; + + let tracks = self.tracks.read().await; + for t in &*tracks { + if t.track.rid() == rid { + if let Some(rtcp_interceptor) = &t.stream.rtcp_interceptor { + let a = Attributes::new(); + + loop { + tokio::select! { + res = State::error_on_close(&mut state_watch_rx) => { + res? + } + result = rtcp_interceptor.read(b, &a) => { + return Ok(result?); + } + } + } + } else { + return Err(Error::ErrInterceptorNotBind); + } + } + } + Err(Error::ErrRTPReceiverForRIDTrackStreamNotFound) + } + + /// read_rtcp is a convenience method that wraps Read and unmarshal for you. + /// It also runs any configured interceptors. + async fn read_rtcp( + &self, + receive_mtu: usize, + ) -> Result<(Vec>, Attributes)> { + let mut b = vec![0u8; receive_mtu]; + let (pkts, attributes) = self.read(&mut b).await?; + + Ok((pkts, attributes)) + } + + /// read_simulcast_rtcp is a convenience method that wraps ReadSimulcast and unmarshal for you + async fn read_simulcast_rtcp( + &self, + rid: &str, + receive_mtu: usize, + ) -> Result<(Vec>, Attributes)> { + let mut b = vec![0u8; receive_mtu]; + let (pkts, attributes) = self.read_simulcast(&mut b, rid).await?; + + Ok((pkts, attributes)) + } + + pub(crate) async fn read_rtp( + &self, + b: &mut [u8], + tid: usize, + ) -> Result<(rtp::packet::Packet, Attributes)> { + let mut state_watch_rx = self.state_tx.subscribe(); + + // Ensure we are running. + State::wait_for(&mut state_watch_rx, &[State::Started]).await?; + + //log::debug!("read_rtp enter tracks tid {}", tid); + let mut rtp_interceptor = None; + //let mut ssrc = 0; + { + let tracks = self.tracks.read().await; + for t in &*tracks { + if t.track.tid() == tid { + rtp_interceptor = t.stream.rtp_interceptor.clone(); + //ssrc = t.track.ssrc(); + break; + } + } + }; + /*log::debug!( + "read_rtp exit tracks with rtp_interceptor {} with tid {}", + rtp_interceptor.is_some(), + tid, + );*/ + + if let Some(rtp_interceptor) = rtp_interceptor { + let a = Attributes::new(); + //println!( + // "read_rtp rtp_interceptor.read enter with tid {} ssrc {}", + // tid, ssrc + //); + let mut current_state = *state_watch_rx.borrow(); + loop { + tokio::select! { + _ = state_watch_rx.changed() => { + let new_state = *state_watch_rx.borrow(); + + if new_state == State::Stopped { + return Err(Error::ErrClosedPipe); + } + current_state = new_state; + } + result = rtp_interceptor.read(b, &a) => { + let result = result?; + + if current_state == State::Paused { + trace!("Dropping {} read bytes received while RTPReceiver was paused", result.0); + continue; + } + return Ok(result); + } + } + } + } else { + //log::debug!("read_rtp exit tracks with ErrRTPReceiverWithSSRCTrackStreamNotFound"); + Err(Error::ErrRTPReceiverWithSSRCTrackStreamNotFound) + } + } + + async fn get_parameters(&self) -> RTCRtpParameters { + let mut parameters = self + .media_engine + .get_rtp_parameters_by_kind(self.kind, RTCRtpTransceiverDirection::Recvonly); + + let transceiver_codecs = self.transceiver_codecs.load(); + if let Some(codecs) = &*transceiver_codecs { + let mut c = codecs.lock().await; + parameters.codecs = + RTPReceiverInternal::get_codecs(&mut c, self.kind, &self.media_engine); + } + + parameters + } + + pub(crate) fn get_codecs( + codecs: &mut [RTCRtpCodecParameters], + kind: RTPCodecType, + media_engine: &Arc, + ) -> Vec { + let media_engine_codecs = media_engine.get_codecs_by_kind(kind); + if codecs.is_empty() { + return media_engine_codecs; + } + let mut filtered_codecs = vec![]; + for codec in codecs { + let (c, match_type) = codec_parameters_fuzzy_search(codec, &media_engine_codecs); + if match_type != CodecMatch::None { + if codec.payload_type == 0 { + codec.payload_type = c.payload_type; + } + filtered_codecs.push(codec.clone()); + } + } + + filtered_codecs + } + + // State + + /// Get the current state and a receiver for the next state change. + pub(crate) fn current_state(&self) -> State { + *self.state_rx.borrow() + } + + pub(crate) fn start(&self) -> Result<()> { + State::transition(State::Started, &self.state_tx) + } + + pub(crate) fn pause(&self) -> Result<()> { + let current = self.current_state(); + + match current { + State::Unstarted => State::transition(State::UnstartedPaused, &self.state_tx), + State::Started => State::transition(State::Paused, &self.state_tx), + _ => Ok(()), + } + } + + pub(crate) fn resume(&self) -> Result<()> { + let current = self.current_state(); + + match current { + State::UnstartedPaused => State::transition(State::Unstarted, &self.state_tx), + State::Paused => State::transition(State::Started, &self.state_tx), + _ => Ok(()), + } + } + + pub(crate) fn close(&self) -> Result<()> { + State::transition(State::Stopped, &self.state_tx) + } +} + +/// RTPReceiver allows an application to inspect the receipt of a TrackRemote +pub struct RTCRtpReceiver { + receive_mtu: usize, + kind: RTPCodecType, + transport: Arc, + + pub internal: Arc, +} + +impl std::fmt::Debug for RTCRtpReceiver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RTCRtpReceiver") + .field("kind", &self.kind) + .finish() + } +} + +impl RTCRtpReceiver { + pub fn new( + receive_mtu: usize, + kind: RTPCodecType, + transport: Arc, + media_engine: Arc, + interceptor: Arc, + ) -> Self { + let (state_tx, state_rx) = watch::channel(State::Unstarted); + + RTCRtpReceiver { + receive_mtu, + kind, + transport: Arc::clone(&transport), + + internal: Arc::new(RTPReceiverInternal { + kind, + + tracks: RwLock::new(vec![]), + transport, + media_engine, + interceptor, + + state_tx, + state_rx, + + transceiver_codecs: ArcSwapOption::new(None), + }), + } + } + + pub fn kind(&self) -> RTPCodecType { + self.kind + } + + pub(crate) fn set_transceiver_codecs( + &self, + codecs: Option>>>, + ) { + self.internal.transceiver_codecs.store(codecs); + } + + /// transport returns the currently-configured *DTLSTransport or nil + /// if one has not yet been configured + pub fn transport(&self) -> Arc { + Arc::clone(&self.transport) + } + + /// get_parameters describes the current configuration for the encoding and + /// transmission of media on the receiver's track. + pub async fn get_parameters(&self) -> RTCRtpParameters { + self.internal.get_parameters().await + } + + /// SetRTPParameters applies provided RTPParameters the RTPReceiver's tracks. + /// This method is part of the ORTC API. It is not + /// meant to be used together with the basic WebRTC API. + /// The amount of provided codecs must match the number of tracks on the receiver. + pub async fn set_rtp_parameters(&self, params: RTCRtpParameters) { + let mut header_extensions = vec![]; + for h in ¶ms.header_extensions { + header_extensions.push(RTPHeaderExtension { + id: h.id, + uri: h.uri.clone(), + }); + } + + let mut tracks = self.internal.tracks.write().await; + for (idx, codec) in params.codecs.iter().enumerate() { + let t = &mut tracks[idx]; + if let Some(stream_info) = &mut t.stream.stream_info { + stream_info.rtp_header_extensions = header_extensions.clone(); + } + + let current_track = &t.track; + current_track.set_codec(codec.clone()); + current_track.set_params(params.clone()); + } + } + + /// track returns the RtpTransceiver TrackRemote + pub async fn track(&self) -> Option> { + let tracks = self.internal.tracks.read().await; + if tracks.len() != 1 { + None + } else { + tracks.first().map(|t| Arc::clone(&t.track)) + } + } + + /// tracks returns the RtpTransceiver traclockks + /// A RTPReceiver to support Simulcast may now have multiple tracks + pub async fn tracks(&self) -> Vec> { + let tracks = self.internal.tracks.read().await; + tracks.iter().map(|t| Arc::clone(&t.track)).collect() + } + + /// receive initialize the track and starts all the transports + pub async fn receive(&self, parameters: &RTCRtpReceiveParameters) -> Result<()> { + let receiver = Arc::downgrade(&self.internal); + + let current_state = self.internal.current_state(); + if current_state.is_started() { + return Err(Error::ErrRTPReceiverReceiveAlreadyCalled); + } + self.internal.start()?; + + let (global_params, interceptor, media_engine) = { + ( + self.internal.get_parameters().await, + Arc::clone(&self.internal.interceptor), + Arc::clone(&self.internal.media_engine), + ) + }; + + let codec = if let Some(codec) = global_params.codecs.first() { + codec.capability.clone() + } else { + RTCRtpCodecCapability::default() + }; + + for encoding in ¶meters.encodings { + let (stream_info, rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = + if encoding.ssrc != 0 { + let stream_info = create_stream_info( + "".to_owned(), + encoding.ssrc, + 0, + codec.clone(), + &global_params.header_extensions, + ); + let (rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = + self.transport + .streams_for_ssrc(encoding.ssrc, &stream_info, &interceptor) + .await?; + + ( + Some(stream_info), + Some(rtp_read_stream), + Some(rtp_interceptor), + Some(rtcp_read_stream), + Some(rtcp_interceptor), + ) + } else { + (None, None, None, None, None) + }; + + let t = TrackStreams { + track: Arc::new(TrackRemote::new( + self.receive_mtu, + self.kind, + encoding.ssrc, + encoding.rid.clone(), + receiver.clone(), + Arc::clone(&media_engine), + Arc::clone(&interceptor), + )), + stream: TrackStream { + stream_info, + rtp_read_stream, + rtp_interceptor, + rtcp_read_stream, + rtcp_interceptor, + }, + + repair_stream: TrackStream { + stream_info: None, + rtp_read_stream: None, + rtp_interceptor: None, + rtcp_read_stream: None, + rtcp_interceptor: None, + }, + }; + + { + let mut tracks = self.internal.tracks.write().await; + tracks.push(t); + }; + + let rtx_ssrc = encoding.rtx.ssrc; + if rtx_ssrc != 0 { + let stream_info = create_stream_info( + "".to_owned(), + rtx_ssrc, + 0, + codec.clone(), + &global_params.header_extensions, + ); + let (rtp_read_stream, rtp_interceptor, rtcp_read_stream, rtcp_interceptor) = self + .transport + .streams_for_ssrc(rtx_ssrc, &stream_info, &interceptor) + .await?; + + self.receive_for_rtx( + rtx_ssrc, + "".to_owned(), + TrackStream { + stream_info: Some(stream_info), + rtp_read_stream: Some(rtp_read_stream), + rtp_interceptor: Some(rtp_interceptor), + rtcp_read_stream: Some(rtcp_read_stream), + rtcp_interceptor: Some(rtcp_interceptor), + }, + ) + .await?; + } + } + + Ok(()) + } + + /// read reads incoming RTCP for this RTPReceiver + pub async fn read( + &self, + b: &mut [u8], + ) -> Result<(Vec>, Attributes)> { + self.internal.read(b).await + } + + /// read_simulcast reads incoming RTCP for this RTPReceiver for given rid + pub async fn read_simulcast( + &self, + b: &mut [u8], + rid: &str, + ) -> Result<(Vec>, Attributes)> { + self.internal.read_simulcast(b, rid).await + } + + /// read_rtcp is a convenience method that wraps Read and unmarshal for you. + /// It also runs any configured interceptors. + pub async fn read_rtcp( + &self, + ) -> Result<(Vec>, Attributes)> { + self.internal.read_rtcp(self.receive_mtu).await + } + + /// read_simulcast_rtcp is a convenience method that wraps ReadSimulcast and unmarshal for you + pub async fn read_simulcast_rtcp( + &self, + rid: &str, + ) -> Result<(Vec>, Attributes)> { + self.internal + .read_simulcast_rtcp(rid, self.receive_mtu) + .await + } + + pub(crate) async fn have_received(&self) -> bool { + self.internal.current_state().is_started() + } + + pub(crate) async fn start(&self, incoming: &TrackDetails) { + let mut encoding_size = incoming.ssrcs.len(); + if incoming.rids.len() >= encoding_size { + encoding_size = incoming.rids.len(); + }; + + let mut encodings = vec![RTCRtpDecodingParameters::default(); encoding_size]; + for (i, encoding) in encodings.iter_mut().enumerate() { + if incoming.rids.len() > i { + encoding.rid = incoming.rids[i].clone(); + } + if incoming.ssrcs.len() > i { + encoding.ssrc = incoming.ssrcs[i]; + } + + encoding.rtx.ssrc = incoming.repair_ssrc; + } + + if let Err(err) = self.receive(&RTCRtpReceiveParameters { encodings }).await { + log::warn!("RTPReceiver Receive failed {}", err); + return; + } + + // set track id and label early so they can be set as new track information + // is received from the SDP. + let is_unpaused = self.current_state() == State::Started; + for track_remote in &self.tracks().await { + track_remote.set_id(incoming.id.clone()); + track_remote.set_stream_id(incoming.stream_id.clone()); + + if is_unpaused { + track_remote.fire_onunmute().await; + } + } + } + + /// Stop irreversibly stops the RTPReceiver + pub async fn stop(&self) -> Result<()> { + let previous_state = self.internal.current_state(); + self.internal.close()?; + + let mut errs = vec![]; + let was_ever_started = previous_state.is_started(); + if was_ever_started { + let tracks = self.internal.tracks.write().await; + for t in &*tracks { + if let Some(rtcp_read_stream) = &t.stream.rtcp_read_stream { + if let Err(err) = rtcp_read_stream.close().await { + errs.push(err); + } + } + + if let Some(rtp_read_stream) = &t.stream.rtp_read_stream { + if let Err(err) = rtp_read_stream.close().await { + errs.push(err); + } + } + + if let Some(repair_rtcp_read_stream) = &t.repair_stream.rtcp_read_stream { + if let Err(err) = repair_rtcp_read_stream.close().await { + errs.push(err); + } + } + + if let Some(repair_rtp_read_stream) = &t.repair_stream.rtp_read_stream { + if let Err(err) = repair_rtp_read_stream.close().await { + errs.push(err); + } + } + + if let Some(stream_info) = &t.stream.stream_info { + self.internal + .interceptor + .unbind_remote_stream(stream_info) + .await; + } + + if let Some(repair_stream_info) = &t.repair_stream.stream_info { + self.internal + .interceptor + .unbind_remote_stream(repair_stream_info) + .await; + } + } + } + + flatten_errs(errs) + } + + /// read_rtp should only be called by a track, this only exists so we can keep state in one place + pub(crate) async fn read_rtp( + &self, + b: &mut [u8], + tid: usize, + ) -> Result<(rtp::packet::Packet, Attributes)> { + self.internal.read_rtp(b, tid).await + } + + /// receive_for_rid is the sibling of Receive expect for RIDs instead of SSRCs + /// It populates all the internal state for the given RID + pub(crate) async fn receive_for_rid( + &self, + rid: SmolStr, + params: RTCRtpParameters, + stream: TrackStream, + ) -> Result> { + let mut tracks = self.internal.tracks.write().await; + for t in &mut *tracks { + if *t.track.rid() == rid { + t.track.set_kind(self.kind); + if let Some(codec) = params.codecs.first() { + t.track.set_codec(codec.clone()); + } + t.track.set_params(params.clone()); + t.track + .set_ssrc(stream.stream_info.as_ref().map_or(0, |s| s.ssrc)); + t.stream = stream; + return Ok(Arc::clone(&t.track)); + } + } + + Err(Error::ErrRTPReceiverForRIDTrackStreamNotFound) + } + + /// receiveForRtx starts a routine that processes the repair stream + /// These packets aren't exposed to the user yet, but we need to process them for + /// TWCC + pub(crate) async fn receive_for_rtx( + &self, + ssrc: SSRC, + rsid: String, + repair_stream: TrackStream, + ) -> Result<()> { + let mut tracks = self.internal.tracks.write().await; + let l = tracks.len(); + for t in &mut *tracks { + if (ssrc != 0 && l == 1) || t.track.rid() == rsid { + t.repair_stream = repair_stream; + + let receive_mtu = self.receive_mtu; + let track = t.clone(); + tokio::spawn(async move { + let a = Attributes::new(); + let mut b = vec![0u8; receive_mtu]; + while let Some(repair_rtp_interceptor) = &track.repair_stream.rtp_interceptor { + //TODO: cancel repair_rtp_interceptor.read gracefully + //println!("repair_rtp_interceptor read begin with ssrc={}", ssrc); + if repair_rtp_interceptor.read(&mut b, &a).await.is_err() { + break; + } + } + }); + + return Ok(()); + } + } + + Err(Error::ErrRTPReceiverForRIDTrackStreamNotFound) + } + + // State + + pub(crate) fn current_state(&self) -> State { + self.internal.current_state() + } + + pub(crate) async fn pause(&self) -> Result<()> { + self.internal.pause()?; + + if !self.internal.current_state().is_started() { + return Ok(()); + } + + let streams = self.internal.tracks.read().await; + + for stream in streams.iter() { + // TODO: If we introduce futures as a direct dependency this and other futures could be + // ran concurrently with [`join_all`](https://docs.rs/futures/0.3.21/futures/future/fn.join_all.html) + stream.track.fire_onmute().await; + } + + Ok(()) + } + + pub(crate) async fn resume(&self) -> Result<()> { + self.internal.resume()?; + + if !self.internal.current_state().is_started() { + return Ok(()); + } + + let streams = self.internal.tracks.read().await; + + for stream in streams.iter() { + // TODO: If we introduce futures as a direct dependency this and other futures could be + // ran concurrently with [`join_all`](https://docs.rs/futures/0.3.21/futures/future/fn.join_all.html) + stream.track.fire_onunmute().await; + } + + Ok(()) + } +} diff --git a/rtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs b/rtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs new file mode 100644 index 0000000..304bff8 --- /dev/null +++ b/rtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs @@ -0,0 +1,231 @@ +use bytes::Bytes; +use media::Sample; +use tokio::sync::mpsc; +use tokio::time::Duration; +use waitgroup::WaitGroup; + +use super::*; +use crate::api::media_engine::{MIME_TYPE_OPUS, MIME_TYPE_VP8}; +use crate::error::Result; +use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; +use crate::peer_connection::peer_connection_test::{ + close_pair_now, create_vnet_pair, signal_pair, until_connection_state, +}; +use crate::rtp_transceiver::rtp_codec::RTCRtpHeaderExtensionParameters; +use crate::rtp_transceiver::RTCPFeedback; +use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; +use crate::track::track_local::TrackLocal; + +lazy_static! { + static ref P: RTCRtpParameters = RTCRtpParameters { + codecs: vec![RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_string(), + clock_rate: 48000, + channels: 2, + sdp_fmtp_line: "minptime=10;useinbandfec=1".to_string(), + rtcp_feedback: vec![RTCPFeedback { + typ: "nack".to_owned(), + parameter: "".to_owned(), + }], + }, + payload_type: 111, + ..Default::default() + }], + header_extensions: vec![ + RTCRtpHeaderExtensionParameters { + uri: "urn:ietf:params:rtp-hdrext:sdes:mid".to_owned(), + ..Default::default() + }, + RTCRtpHeaderExtensionParameters { + uri: "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id".to_owned(), + ..Default::default() + }, + RTCRtpHeaderExtensionParameters { + uri: "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id".to_owned(), + ..Default::default() + }, + ], + }; +} + +//use log::LevelFilter; +//use std::io::Write; + +#[tokio::test] +async fn test_set_rtp_parameters() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + let (mut sender, mut receiver, wan) = create_vnet_pair().await?; + + let outgoing_track: Arc = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + sender.add_track(Arc::clone(&outgoing_track)).await?; + + // Those parameters wouldn't make sense in a real application, + // but for the sake of the test we just need different values. + + let (seen_packet_tx, mut seen_packet_rx) = mpsc::channel::<()>(1); + let seen_packet_tx = Arc::new(Mutex::new(Some(seen_packet_tx))); + receiver.on_track(Box::new(move |_, receiver, _| { + let seen_packet_tx2 = Arc::clone(&seen_packet_tx); + Box::pin(async move { + receiver.set_rtp_parameters(P.clone()).await; + + if let Some(t) = receiver.track().await { + let incoming_track_codecs = t.codec(); + + assert_eq!(P.header_extensions, t.params().header_extensions); + assert_eq!( + P.codecs[0].capability.mime_type, + incoming_track_codecs.capability.mime_type + ); + assert_eq!( + P.codecs[0].capability.clock_rate, + incoming_track_codecs.capability.clock_rate + ); + assert_eq!( + P.codecs[0].capability.channels, + incoming_track_codecs.capability.channels + ); + assert_eq!( + P.codecs[0].capability.sdp_fmtp_line, + incoming_track_codecs.capability.sdp_fmtp_line + ); + assert_eq!( + P.codecs[0].capability.rtcp_feedback, + incoming_track_codecs.capability.rtcp_feedback + ); + assert_eq!(P.codecs[0].payload_type, incoming_track_codecs.payload_type); + + { + let mut done = seen_packet_tx2.lock().await; + done.take(); + } + } + }) + })); + + let wg = WaitGroup::new(); + + until_connection_state(&mut sender, &wg, RTCPeerConnectionState::Connected).await; + until_connection_state(&mut receiver, &wg, RTCPeerConnectionState::Connected).await; + + signal_pair(&mut sender, &mut receiver).await?; + + wg.wait().await; + + if let Some(v) = outgoing_track + .as_any() + .downcast_ref::() + { + v.write_sample(&Sample { + data: Bytes::from_static(&[0xAA]), + duration: Duration::from_secs(1), + ..Default::default() + }) + .await?; + } else { + panic!(); + } + + let _ = seen_packet_rx.recv().await; + { + let mut w = wan.lock().await; + w.stop().await?; + } + close_pair_now(&sender, &receiver).await; + + Ok(()) +} + +// Assert that SetReadDeadline works as expected +// This test uses VNet since we must have zero loss +#[tokio::test] +async fn test_rtp_receiver_set_read_deadline() -> Result<()> { + let (mut sender, mut receiver, wan) = create_vnet_pair().await?; + + let track: Arc = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + sender.add_track(Arc::clone(&track)).await?; + + let (seen_packet_tx, mut seen_packet_rx) = mpsc::channel::<()>(1); + let seen_packet_tx = Arc::new(Mutex::new(Some(seen_packet_tx))); + receiver.on_track(Box::new(move |track, receiver, _| { + let seen_packet_tx2 = Arc::clone(&seen_packet_tx); + Box::pin(async move { + // First call will not error because we cache for probing + let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; + assert!( + result.is_ok(), + " First call will not error because we cache for probing" + ); + + let result = tokio::time::timeout(Duration::from_secs(1), track.read_rtp()).await; + assert!(result.is_err()); + + let result = tokio::time::timeout(Duration::from_secs(1), receiver.read_rtcp()).await; + assert!(result.is_err()); + + { + let mut done = seen_packet_tx2.lock().await; + done.take(); + } + }) + })); + + let wg = WaitGroup::new(); + until_connection_state(&mut sender, &wg, RTCPeerConnectionState::Connected).await; + until_connection_state(&mut receiver, &wg, RTCPeerConnectionState::Connected).await; + + signal_pair(&mut sender, &mut receiver).await?; + + wg.wait().await; + + if let Some(v) = track.as_any().downcast_ref::() { + v.write_sample(&Sample { + data: Bytes::from_static(&[0xAA]), + duration: Duration::from_secs(1), + ..Default::default() + }) + .await?; + } else { + panic!(); + } + + let _ = seen_packet_rx.recv().await; + { + let mut w = wan.lock().await; + w.stop().await?; + } + close_pair_now(&sender, &receiver).await; + + Ok(()) +} diff --git a/rtc/src/rtp_transceiver/rtp_sender/mod.rs b/rtc/src/rtp_transceiver/rtp_sender/mod.rs new file mode 100644 index 0000000..2f90fb2 --- /dev/null +++ b/rtc/src/rtp_transceiver/rtp_sender/mod.rs @@ -0,0 +1,544 @@ +#[cfg(test)] +mod rtp_sender_test; + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Weak}; + +use ice::rand::generate_crypto_random_string; +use interceptor::stream_info::StreamInfo; +use interceptor::{Attributes, Interceptor, RTCPReader, RTPWriter}; +use tokio::sync::{mpsc, Mutex, Notify}; +use util::sync::Mutex as SyncMutex; + +use super::srtp_writer_future::SequenceTransformer; +use crate::api::media_engine::MediaEngine; +use crate::dtls_transport::RTCDtlsTransport; +use crate::error::{Error, Result}; +use crate::rtp_transceiver::rtp_codec::{RTCRtpCodecParameters, RTPCodecType}; +use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; +use crate::rtp_transceiver::srtp_writer_future::SrtpWriterFuture; +use crate::rtp_transceiver::{ + create_stream_info, PayloadType, RTCRtpEncodingParameters, RTCRtpSendParameters, + RTCRtpTransceiver, SSRC, +}; +use crate::track::track_local::{ + InterceptorToTrackLocalWriter, TrackLocal, TrackLocalContext, TrackLocalWriter, +}; + +pub(crate) struct RTPSenderInternal { + pub(crate) send_called_rx: Mutex>, + pub(crate) stop_called_rx: Arc, + pub(crate) stop_called_signal: Arc, + pub(crate) rtcp_interceptor: Mutex>>, +} + +impl RTPSenderInternal { + /// read reads incoming RTCP for this RTPReceiver + async fn read( + &self, + b: &mut [u8], + ) -> Result<(Vec>, Attributes)> { + let mut send_called_rx = self.send_called_rx.lock().await; + + tokio::select! { + _ = send_called_rx.recv() =>{ + let rtcp_interceptor = { + let rtcp_interceptor = self.rtcp_interceptor.lock().await; + rtcp_interceptor.clone() + }; + if let Some(rtcp_interceptor) = rtcp_interceptor{ + let a = Attributes::new(); + tokio::select! { + _ = self.stop_called_rx.notified() => { + Err(Error::ErrClosedPipe) + } + result = rtcp_interceptor.read(b, &a) => { + Ok(result?) + } + } + }else{ + Err(Error::ErrInterceptorNotBind) + } + } + _ = self.stop_called_rx.notified() =>{ + Err(Error::ErrClosedPipe) + } + } + } + + /// read_rtcp is a convenience method that wraps Read and unmarshals for you. + async fn read_rtcp( + &self, + receive_mtu: usize, + ) -> Result<(Vec>, Attributes)> { + let mut b = vec![0u8; receive_mtu]; + let (pkts, attributes) = self.read(&mut b).await?; + + Ok((pkts, attributes)) + } +} + +/// RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer +pub struct RTCRtpSender { + pub(crate) track: Mutex>>, + + pub(crate) srtp_stream: Arc, + pub(crate) stream_info: Mutex, + seq_trans: Arc, + + pub(crate) context: Mutex, + + pub(crate) transport: Arc, + + pub(crate) payload_type: PayloadType, + pub(crate) ssrc: SSRC, + receive_mtu: usize, + + /// a transceiver sender since we can just check the + /// transceiver negotiation status + pub(crate) negotiated: AtomicBool, + + pub(crate) media_engine: Arc, + pub(crate) interceptor: Arc, + + pub(crate) id: String, + + /// The id of the initial track, even if we later change to a different + /// track id should be use when negotiating. + pub(crate) initial_track_id: std::sync::Mutex>, + /// AssociatedMediaStreamIds from the WebRTC specifications + pub(crate) associated_media_stream_ids: std::sync::Mutex>, + + rtp_transceiver: SyncMutex>>, + + send_called_tx: SyncMutex>>, + stop_called_tx: Arc, + stop_called_signal: Arc, + + pub(crate) paused: Arc, + + internal: Arc, +} + +impl std::fmt::Debug for RTCRtpSender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RTCRtpSender") + .field("id", &self.id) + .finish() + } +} + +impl RTCRtpSender { + pub async fn new( + receive_mtu: usize, + track: Option>, + transport: Arc, + media_engine: Arc, + interceptor: Arc, + start_paused: bool, + ) -> Self { + let id = generate_crypto_random_string( + 32, + b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", + ); + let (send_called_tx, send_called_rx) = mpsc::channel(1); + let stop_called_tx = Arc::new(Notify::new()); + let stop_called_rx = stop_called_tx.clone(); + let ssrc = rand::random::(); + let stop_called_signal = Arc::new(AtomicBool::new(false)); + + let internal = Arc::new(RTPSenderInternal { + send_called_rx: Mutex::new(send_called_rx), + stop_called_rx, + stop_called_signal: Arc::clone(&stop_called_signal), + rtcp_interceptor: Mutex::new(None), + }); + + let seq_trans = Arc::new(SequenceTransformer::new()); + let srtp_stream = Arc::new(SrtpWriterFuture { + closed: AtomicBool::new(false), + ssrc, + rtp_sender: Arc::downgrade(&internal), + rtp_transport: Arc::clone(&transport), + rtcp_read_stream: Mutex::new(None), + rtp_write_session: Mutex::new(None), + seq_trans: Arc::clone(&seq_trans), + }); + + let srtp_rtcp_reader = Arc::clone(&srtp_stream) as Arc; + let rtcp_interceptor = interceptor.bind_rtcp_reader(srtp_rtcp_reader).await; + { + let mut internal_rtcp_interceptor = internal.rtcp_interceptor.lock().await; + *internal_rtcp_interceptor = Some(rtcp_interceptor); + } + + let stream_ids = track + .as_ref() + .map(|track| vec![track.stream_id().to_string()]) + .unwrap_or_default(); + Self { + track: Mutex::new(track), + + srtp_stream, + stream_info: Mutex::new(StreamInfo::default()), + seq_trans, + + context: Mutex::new(TrackLocalContext::default()), + transport, + + payload_type: 0, + ssrc, + receive_mtu, + + negotiated: AtomicBool::new(false), + + media_engine, + interceptor, + + id, + initial_track_id: std::sync::Mutex::new(None), + associated_media_stream_ids: std::sync::Mutex::new(stream_ids), + + rtp_transceiver: SyncMutex::new(None), + + send_called_tx: SyncMutex::new(Some(send_called_tx)), + stop_called_tx, + stop_called_signal, + + paused: Arc::new(AtomicBool::new(start_paused)), + + internal, + } + } + + pub(crate) fn is_negotiated(&self) -> bool { + self.negotiated.load(Ordering::SeqCst) + } + + pub(crate) fn set_negotiated(&self) { + self.negotiated.store(true, Ordering::SeqCst); + } + + pub(crate) fn set_rtp_transceiver(&self, rtp_transceiver: Option>) { + if let Some(t) = rtp_transceiver.as_ref().and_then(|t| t.upgrade()) { + self.set_paused(!t.direction().has_send()); + } + let mut tr = self.rtp_transceiver.lock(); + *tr = rtp_transceiver; + } + + pub(crate) fn set_paused(&self, paused: bool) { + self.paused.store(paused, Ordering::SeqCst); + } + + /// transport returns the currently-configured DTLSTransport + /// if one has not yet been configured + pub fn transport(&self) -> Arc { + Arc::clone(&self.transport) + } + + /// get_parameters describes the current configuration for the encoding and + /// transmission of media on the sender's track. + pub async fn get_parameters(&self) -> RTCRtpSendParameters { + let kind = { + let track = self.track.lock().await; + if let Some(t) = &*track { + t.kind() + } else { + RTPCodecType::default() + } + }; + + let mut send_parameters = { + RTCRtpSendParameters { + rtp_parameters: self + .media_engine + .get_rtp_parameters_by_kind(kind, RTCRtpTransceiverDirection::Sendonly), + encodings: vec![RTCRtpEncodingParameters { + ssrc: self.ssrc, + payload_type: self.payload_type, + ..Default::default() + }], + } + }; + + let codecs = { + let tr = self.rtp_transceiver.lock().clone(); + if let Some(t) = &tr { + if let Some(t) = t.upgrade() { + t.get_codecs().await + } else { + self.media_engine.get_codecs_by_kind(kind) + } + } else { + self.media_engine.get_codecs_by_kind(kind) + } + }; + send_parameters.rtp_parameters.codecs = codecs; + + send_parameters + } + + /// track returns the RTCRtpTransceiver track, or nil + pub async fn track(&self) -> Option> { + let track = self.track.lock().await; + track.clone() + } + + /// replace_track replaces the track currently being used as the sender's source with a new TrackLocal. + /// The new track must be of the same media kind (audio, video, etc) and switching the track should not + /// require negotiation. + pub async fn replace_track( + &self, + track: Option>, + ) -> Result<()> { + if let Some(t) = &track { + let tr = self.rtp_transceiver.lock(); + if let Some(r) = &*tr { + if let Some(r) = r.upgrade() { + if r.kind != t.kind() { + return Err(Error::ErrRTPSenderNewTrackHasIncorrectKind); + } + } else { + //TODO: what about None arc? + } + } else { + //TODO: what about None tr? + } + } + + if self.has_sent() { + let t = { + let t = self.track.lock().await; + t.clone() + }; + if let Some(t) = t { + let context = self.context.lock().await; + t.unbind(&context).await?; + } + } + + if !self.has_sent() || track.is_none() { + let mut t = self.track.lock().await; + *t = track; + return Ok(()); + } + + let context = { + let context = self.context.lock().await; + context.clone() + }; + + let result = if let Some(t) = &track { + self.seq_trans.reset_offset(); + + let new_context = TrackLocalContext { + id: context.id.clone(), + params: self + .media_engine + .get_rtp_parameters_by_kind(t.kind(), RTCRtpTransceiverDirection::Sendonly), + ssrc: context.ssrc, + write_stream: context.write_stream.clone(), + paused: self.paused.clone(), + }; + + t.bind(&new_context).await + } else { + Err(Error::ErrRTPSenderTrackNil) + }; + + match result { + Err(err) => { + // Re-bind the original track + let track = self.track.lock().await; + if let Some(t) = &*track { + t.bind(&context).await?; + } + + Err(err) + } + Ok(codec) => { + // Codec has changed + if self.payload_type != codec.payload_type { + let mut context = self.context.lock().await; + context.params.codecs = vec![codec]; + } + + { + let mut t = self.track.lock().await; + *t = track; + } + + Ok(()) + } + } + } + + /// send Attempts to set the parameters controlling the sending of media. + pub async fn send(&self, parameters: &RTCRtpSendParameters) -> Result<()> { + if self.has_sent() { + return Err(Error::ErrRTPSenderSendAlreadyCalled); + } + + let write_stream = Arc::new(InterceptorToTrackLocalWriter::new(self.paused.clone())); + let (context, stream_info) = { + let track = self.track.lock().await; + let mut context = TrackLocalContext { + id: self.id.clone(), + params: self.media_engine.get_rtp_parameters_by_kind( + if let Some(t) = &*track { + t.kind() + } else { + RTPCodecType::default() + }, + RTCRtpTransceiverDirection::Sendonly, + ), + ssrc: parameters.encodings[0].ssrc, + write_stream: Some( + Arc::clone(&write_stream) as Arc + ), + paused: self.paused.clone(), + }; + + let codec = if let Some(t) = &*track { + t.bind(&context).await? + } else { + RTCRtpCodecParameters::default() + }; + let payload_type = codec.payload_type; + let capability = codec.capability.clone(); + context.params.codecs = vec![codec]; + let stream_info = create_stream_info( + self.id.clone(), + parameters.encodings[0].ssrc, + payload_type, + capability, + ¶meters.rtp_parameters.header_extensions, + ); + + (context, stream_info) + }; + + let srtp_rtp_writer = Arc::clone(&self.srtp_stream) as Arc; + let rtp_interceptor = self + .interceptor + .bind_local_stream(&stream_info, srtp_rtp_writer) + .await; + { + let mut interceptor_rtp_writer = write_stream.interceptor_rtp_writer.lock().await; + *interceptor_rtp_writer = Some(rtp_interceptor); + } + + { + let mut ctx = self.context.lock().await; + *ctx = context; + } + { + let mut si = self.stream_info.lock().await; + *si = stream_info; + } + + { + let mut send_called_tx = self.send_called_tx.lock(); + send_called_tx.take(); + } + + Ok(()) + } + + /// stop irreversibly stops the RTPSender + pub async fn stop(&self) -> Result<()> { + if self.stop_called_signal.load(Ordering::SeqCst) { + return Ok(()); + } + self.stop_called_signal.store(true, Ordering::SeqCst); + self.stop_called_tx.notify_waiters(); + + if !self.has_sent() { + return Ok(()); + } + + self.replace_track(None).await?; + + { + let stream_info = self.stream_info.lock().await; + self.interceptor.unbind_local_stream(&stream_info).await; + } + + self.srtp_stream.close().await + } + + /// read reads incoming RTCP for this RTPReceiver + pub async fn read( + &self, + b: &mut [u8], + ) -> Result<(Vec>, Attributes)> { + self.internal.read(b).await + } + + /// read_rtcp is a convenience method that wraps Read and unmarshals for you. + pub async fn read_rtcp( + &self, + ) -> Result<(Vec>, Attributes)> { + self.internal.read_rtcp(self.receive_mtu).await + } + + /// Enables overriding outgoing `RTP` packets' `sequence number`s. + /// + /// Must be called once before any data sent or never called at all. + /// + /// # Errors + /// + /// Errors if this [`RTCRtpSender`] has started to send data or sequence + /// transforming has been already enabled. + pub fn enable_seq_transformer(&self) -> Result<()> { + self.seq_trans.enable() + } + + /// has_sent tells if data has been ever sent for this instance + pub(crate) fn has_sent(&self) -> bool { + let send_called_tx = self.send_called_tx.lock(); + send_called_tx.is_none() + } + + /// has_stopped tells if stop has been called + pub(crate) async fn has_stopped(&self) -> bool { + self.stop_called_signal.load(Ordering::SeqCst) + } + + pub(crate) fn initial_track_id(&self) -> Option { + let lock = self.initial_track_id.lock().unwrap(); + + lock.clone() + } + + pub(crate) fn set_initial_track_id(&self, id: String) -> Result<()> { + let mut lock = self.initial_track_id.lock().unwrap(); + + if lock.is_some() { + return Err(Error::ErrSenderInitialTrackIdAlreadySet); + } + + *lock = Some(id); + + Ok(()) + } + + pub(crate) fn associate_media_stream_id(&self, id: String) -> bool { + let mut lock = self.associated_media_stream_ids.lock().unwrap(); + + if lock.contains(&id) { + return false; + } + + lock.push(id); + + true + } + + pub(crate) fn associated_media_stream_ids(&self) -> Vec { + let lock = self.associated_media_stream_ids.lock().unwrap(); + + lock.clone() + } +} diff --git a/rtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs b/rtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs new file mode 100644 index 0000000..d06c6b3 --- /dev/null +++ b/rtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs @@ -0,0 +1,336 @@ +use std::sync::atomic::AtomicU64; + +use bytes::Bytes; +use tokio::time::Duration; +use waitgroup::WaitGroup; + +use super::*; +use crate::api::media_engine::{MIME_TYPE_H264, MIME_TYPE_OPUS, MIME_TYPE_VP8, MIME_TYPE_VP9}; +use crate::api::setting_engine::SettingEngine; +use crate::api::APIBuilder; +use crate::error::Result; +use crate::peer_connection::peer_connection_state::RTCPeerConnectionState; +use crate::peer_connection::peer_connection_test::{ + close_pair_now, create_vnet_pair, new_pair, send_video_until_done, signal_pair, + until_connection_state, +}; +use crate::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; +use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; + +#[tokio::test] +async fn test_rtp_sender_replace_track() -> Result<()> { + let mut s = SettingEngine::default(); + s.disable_srtp_replay_protection(true); + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + + let api = APIBuilder::new() + .with_setting_engine(s) + .with_media_engine(m) + .build(); + + let (mut sender, mut receiver) = new_pair(&api).await?; + + let track_a = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let track_b = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let rtp_sender = sender + .add_track(Arc::clone(&track_a) as Arc) + .await?; + + let (seen_packet_a_tx, seen_packet_a_rx) = mpsc::channel::<()>(1); + let (seen_packet_b_tx, seen_packet_b_rx) = mpsc::channel::<()>(1); + + let seen_packet_a_tx = Arc::new(seen_packet_a_tx); + let seen_packet_b_tx = Arc::new(seen_packet_b_tx); + let on_track_count = Arc::new(AtomicU64::new(0)); + receiver.on_track(Box::new(move |track, _, _| { + assert_eq!(on_track_count.fetch_add(1, Ordering::SeqCst), 0); + let seen_packet_a_tx2 = Arc::clone(&seen_packet_a_tx); + let seen_packet_b_tx2 = Arc::clone(&seen_packet_b_tx); + Box::pin(async move { + let pkt = match track.read_rtp().await { + Ok((pkt, _)) => pkt, + Err(err) => { + //assert!(errors.Is(io.EOF, err)) + log::debug!("{}", err); + return; + } + }; + + let last = pkt.payload[pkt.payload.len() - 1]; + if last == 0xAA { + assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); + let _ = seen_packet_a_tx2.send(()).await; + } else if last == 0xBB { + assert_eq!(track.codec().capability.mime_type, MIME_TYPE_H264); + let _ = seen_packet_b_tx2.send(()).await; + } else { + panic!("Unexpected RTP Data {last:02x}"); + } + }) + })); + + signal_pair(&mut sender, &mut receiver).await?; + + // Block Until packet with 0xAA has been seen + tokio::spawn(async move { + send_video_until_done( + seen_packet_a_rx, + vec![track_a], + Bytes::from_static(&[0xAA]), + None, + ) + .await; + }); + + rtp_sender + .replace_track(Some( + Arc::clone(&track_b) as Arc + )) + .await?; + + // Block Until packet with 0xBB has been seen + tokio::spawn(async move { + send_video_until_done( + seen_packet_b_rx, + vec![track_b], + Bytes::from_static(&[0xBB]), + None, + ) + .await; + }); + + close_pair_now(&sender, &receiver).await; + Ok(()) +} + +#[tokio::test] +async fn test_rtp_sender_get_parameters() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut offerer, mut answerer) = new_pair(&api).await?; + + let rtp_transceiver = offerer + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + signal_pair(&mut offerer, &mut answerer).await?; + + let sender = rtp_transceiver.sender().await; + let parameters = sender.get_parameters().await; + assert_ne!(0, parameters.rtp_parameters.codecs.len()); + assert_eq!(1, parameters.encodings.len()); + assert_eq!(sender.ssrc, parameters.encodings[0].ssrc); + + close_pair_now(&offerer, &answerer).await; + Ok(()) +} + +#[tokio::test] +async fn test_rtp_sender_set_read_deadline() -> Result<()> { + let (mut sender, mut receiver, wan) = create_vnet_pair().await?; + + let track = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let rtp_sender = sender + .add_track(Arc::clone(&track) as Arc) + .await?; + + let peer_connections_connected = WaitGroup::new(); + until_connection_state( + &mut sender, + &peer_connections_connected, + RTCPeerConnectionState::Connected, + ) + .await; + until_connection_state( + &mut receiver, + &peer_connections_connected, + RTCPeerConnectionState::Connected, + ) + .await; + + signal_pair(&mut sender, &mut receiver).await?; + + peer_connections_connected.wait().await; + + let result = tokio::time::timeout(Duration::from_secs(1), rtp_sender.read_rtcp()).await; + assert!(result.is_err()); + + { + let mut w = wan.lock().await; + w.stop().await?; + } + close_pair_now(&sender, &receiver).await; + + Ok(()) +} + +#[tokio::test] +async fn test_rtp_sender_replace_track_invalid_track_kind_change() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut sender, mut receiver) = new_pair(&api).await?; + + let track_a = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let track_b = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_owned(), + ..Default::default() + }, + "audio".to_owned(), + "webrtc-rs".to_owned(), + )); + + let rtp_sender = sender + .add_track(Arc::clone(&track_a) as Arc) + .await?; + + signal_pair(&mut sender, &mut receiver).await?; + + let (seen_packet_tx, seen_packet_rx) = mpsc::channel::<()>(1); + let seen_packet_tx = Arc::new(seen_packet_tx); + receiver.on_track(Box::new(move |_, _, _| { + let seen_packet_tx2 = Arc::clone(&seen_packet_tx); + Box::pin(async move { + let _ = seen_packet_tx2.send(()).await; + }) + })); + + tokio::spawn(async move { + send_video_until_done( + seen_packet_rx, + vec![track_a], + Bytes::from_static(&[0xAA]), + None, + ) + .await; + }); + + if let Err(err) = rtp_sender.replace_track(Some(track_b)).await { + assert_eq!(err, Error::ErrRTPSenderNewTrackHasIncorrectKind); + } else { + panic!(); + } + + close_pair_now(&sender, &receiver).await; + Ok(()) +} + +#[tokio::test] +async fn test_rtp_sender_replace_track_invalid_codec_change() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut sender, mut receiver) = new_pair(&api).await?; + + let track_a = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let track_b = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP9.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let rtp_sender = sender + .add_track(Arc::clone(&track_a) as Arc) + .await?; + + { + let tr = rtp_sender.rtp_transceiver.lock(); + if let Some(t) = &*tr { + if let Some(t) = t.upgrade() { + t.set_codec_preferences(vec![RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + payload_type: 96, + ..Default::default() + }]) + .await?; + } else { + panic!(); + } + } else { + panic!(); + } + } + + signal_pair(&mut sender, &mut receiver).await?; + + let (seen_packet_tx, seen_packet_rx) = mpsc::channel::<()>(1); + let seen_packet_tx = Arc::new(seen_packet_tx); + receiver.on_track(Box::new(move |_, _, _| { + let seen_packet_tx2 = Arc::clone(&seen_packet_tx); + Box::pin(async move { + let _ = seen_packet_tx2.send(()).await; + }) + })); + + tokio::spawn(async move { + send_video_until_done( + seen_packet_rx, + vec![track_a], + Bytes::from_static(&[0xAA]), + None, + ) + .await; + }); + + if let Err(err) = rtp_sender.replace_track(Some(track_b)).await { + assert_eq!(err, Error::ErrUnsupportedCodec); + } else { + panic!(); + } + + close_pair_now(&sender, &receiver).await; + Ok(()) +} diff --git a/rtc/src/rtp_transceiver/rtp_transceiver_direction.rs b/rtc/src/rtp_transceiver/rtp_transceiver_direction.rs new file mode 100644 index 0000000..756731e --- /dev/null +++ b/rtc/src/rtp_transceiver/rtp_transceiver_direction.rs @@ -0,0 +1,210 @@ +use std::fmt; + +/// RTPTransceiverDirection indicates the direction of the RTPTransceiver. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum RTCRtpTransceiverDirection { + Unspecified, + + /// Sendrecv indicates the RTPSender will offer + /// to send RTP and RTPReceiver the will offer to receive RTP. + Sendrecv, + + /// Sendonly indicates the RTPSender will offer to send RTP. + Sendonly, + + /// Recvonly indicates the RTPReceiver the will offer to receive RTP. + Recvonly, + + /// Inactive indicates the RTPSender won't offer + /// to send RTP and RTPReceiver the won't offer to receive RTP. + Inactive, +} + +const RTP_TRANSCEIVER_DIRECTION_SENDRECV_STR: &str = "sendrecv"; +const RTP_TRANSCEIVER_DIRECTION_SENDONLY_STR: &str = "sendonly"; +const RTP_TRANSCEIVER_DIRECTION_RECVONLY_STR: &str = "recvonly"; +const RTP_TRANSCEIVER_DIRECTION_INACTIVE_STR: &str = "inactive"; + +/// defines a procedure for creating a new +/// RTPTransceiverDirection from a raw string naming the transceiver direction. +impl From<&str> for RTCRtpTransceiverDirection { + fn from(raw: &str) -> Self { + match raw { + RTP_TRANSCEIVER_DIRECTION_SENDRECV_STR => RTCRtpTransceiverDirection::Sendrecv, + RTP_TRANSCEIVER_DIRECTION_SENDONLY_STR => RTCRtpTransceiverDirection::Sendonly, + RTP_TRANSCEIVER_DIRECTION_RECVONLY_STR => RTCRtpTransceiverDirection::Recvonly, + RTP_TRANSCEIVER_DIRECTION_INACTIVE_STR => RTCRtpTransceiverDirection::Inactive, + _ => RTCRtpTransceiverDirection::Unspecified, + } + } +} + +impl From for RTCRtpTransceiverDirection { + fn from(v: u8) -> Self { + match v { + 1 => RTCRtpTransceiverDirection::Sendrecv, + 2 => RTCRtpTransceiverDirection::Sendonly, + 3 => RTCRtpTransceiverDirection::Recvonly, + 4 => RTCRtpTransceiverDirection::Inactive, + _ => RTCRtpTransceiverDirection::Unspecified, + } + } +} + +impl fmt::Display for RTCRtpTransceiverDirection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RTCRtpTransceiverDirection::Sendrecv => { + write!(f, "{RTP_TRANSCEIVER_DIRECTION_SENDRECV_STR}") + } + RTCRtpTransceiverDirection::Sendonly => { + write!(f, "{RTP_TRANSCEIVER_DIRECTION_SENDONLY_STR}") + } + RTCRtpTransceiverDirection::Recvonly => { + write!(f, "{RTP_TRANSCEIVER_DIRECTION_RECVONLY_STR}") + } + RTCRtpTransceiverDirection::Inactive => { + write!(f, "{RTP_TRANSCEIVER_DIRECTION_INACTIVE_STR}") + } + _ => write!(f, "{}", crate::UNSPECIFIED_STR), + } + } +} + +impl RTCRtpTransceiverDirection { + /// reverse indicate the opposite direction + pub fn reverse(&self) -> RTCRtpTransceiverDirection { + match *self { + RTCRtpTransceiverDirection::Sendonly => RTCRtpTransceiverDirection::Recvonly, + RTCRtpTransceiverDirection::Recvonly => RTCRtpTransceiverDirection::Sendonly, + _ => *self, + } + } + + pub fn intersect(&self, other: RTCRtpTransceiverDirection) -> RTCRtpTransceiverDirection { + Self::from_send_recv( + self.has_send() && other.has_send(), + self.has_recv() && other.has_recv(), + ) + } + + pub fn from_send_recv(send: bool, recv: bool) -> RTCRtpTransceiverDirection { + match (send, recv) { + (true, true) => Self::Sendrecv, + (true, false) => Self::Sendonly, + (false, true) => Self::Recvonly, + (false, false) => Self::Inactive, + } + } + + pub fn has_send(&self) -> bool { + matches!(self, Self::Sendrecv | Self::Sendonly) + } + + pub fn has_recv(&self) -> bool { + matches!(self, Self::Sendrecv | Self::Recvonly) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_rtp_transceiver_direction() { + let tests = vec![ + ("Unspecified", RTCRtpTransceiverDirection::Unspecified), + ("sendrecv", RTCRtpTransceiverDirection::Sendrecv), + ("sendonly", RTCRtpTransceiverDirection::Sendonly), + ("recvonly", RTCRtpTransceiverDirection::Recvonly), + ("inactive", RTCRtpTransceiverDirection::Inactive), + ]; + + for (ct_str, expected_type) in tests { + assert_eq!(RTCRtpTransceiverDirection::from(ct_str), expected_type); + } + } + + #[test] + fn test_rtp_transceiver_direction_string() { + let tests = vec![ + (RTCRtpTransceiverDirection::Unspecified, "Unspecified"), + (RTCRtpTransceiverDirection::Sendrecv, "sendrecv"), + (RTCRtpTransceiverDirection::Sendonly, "sendonly"), + (RTCRtpTransceiverDirection::Recvonly, "recvonly"), + (RTCRtpTransceiverDirection::Inactive, "inactive"), + ]; + + for (d, expected_string) in tests { + assert_eq!(d.to_string(), expected_string); + } + } + + #[test] + fn test_rtp_transceiver_has_send() { + let tests = vec![ + (RTCRtpTransceiverDirection::Unspecified, false), + (RTCRtpTransceiverDirection::Sendrecv, true), + (RTCRtpTransceiverDirection::Sendonly, true), + (RTCRtpTransceiverDirection::Recvonly, false), + (RTCRtpTransceiverDirection::Inactive, false), + ]; + + for (d, expected_value) in tests { + assert_eq!(d.has_send(), expected_value); + } + } + + #[test] + fn test_rtp_transceiver_has_recv() { + let tests = vec![ + (RTCRtpTransceiverDirection::Unspecified, false), + (RTCRtpTransceiverDirection::Sendrecv, true), + (RTCRtpTransceiverDirection::Sendonly, false), + (RTCRtpTransceiverDirection::Recvonly, true), + (RTCRtpTransceiverDirection::Inactive, false), + ]; + + for (d, expected_value) in tests { + assert_eq!(d.has_recv(), expected_value); + } + } + + #[test] + fn test_rtp_transceiver_from_send_recv() { + let tests = vec![ + (RTCRtpTransceiverDirection::Sendrecv, (true, true)), + (RTCRtpTransceiverDirection::Sendonly, (true, false)), + (RTCRtpTransceiverDirection::Recvonly, (false, true)), + (RTCRtpTransceiverDirection::Inactive, (false, false)), + ]; + + for (expected_value, (send, recv)) in tests { + assert_eq!( + RTCRtpTransceiverDirection::from_send_recv(send, recv), + expected_value + ); + } + } + + #[test] + fn test_rtp_transceiver_intersect() { + use RTCRtpTransceiverDirection::*; + + let tests = vec![ + ((Sendrecv, Recvonly), Recvonly), + ((Sendrecv, Sendonly), Sendonly), + ((Sendrecv, Inactive), Inactive), + ((Sendonly, Inactive), Inactive), + ((Recvonly, Inactive), Inactive), + ((Recvonly, Sendrecv), Recvonly), + ((Sendonly, Sendrecv), Sendonly), + ((Sendonly, Recvonly), Inactive), + ((Recvonly, Recvonly), Recvonly), + ]; + + for ((a, b), expected_direction) in tests { + assert_eq!(a.intersect(b), expected_direction); + } + } +} diff --git a/rtc/src/rtp_transceiver/rtp_transceiver_test.rs b/rtc/src/rtp_transceiver/rtp_transceiver_test.rs new file mode 100644 index 0000000..3cecc54 --- /dev/null +++ b/rtc/src/rtp_transceiver/rtp_transceiver_test.rs @@ -0,0 +1,356 @@ +use std::sync::atomic::AtomicUsize; + +use super::*; +use crate::api::media_engine::{MIME_TYPE_OPUS, MIME_TYPE_VP8, MIME_TYPE_VP9}; +use crate::api::APIBuilder; +use crate::dtls_transport::RTCDtlsTransport; +use crate::peer_connection::configuration::RTCConfiguration; +use crate::peer_connection::peer_connection_test::{close_pair_now, create_vnet_pair}; + +#[tokio::test] +async fn test_rtp_transceiver_set_codec_preferences() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + m.push_codecs(m.video_codecs.clone(), RTPCodecType::Video) + .await; + m.push_codecs(m.audio_codecs.clone(), RTPCodecType::Audio) + .await; + + let media_video_codecs = m.video_codecs.clone(); + + let api = APIBuilder::new().with_media_engine(m).build(); + let interceptor = api.interceptor_registry.build("")?; + let transport = Arc::new(RTCDtlsTransport::default()); + let receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Video, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + + let sender = Arc::new( + api.new_rtp_sender(None, Arc::clone(&transport), Arc::clone(&interceptor)) + .await, + ); + + let tr = RTCRtpTransceiver::new( + receiver, + sender, + RTCRtpTransceiverDirection::Unspecified, + RTPCodecType::Video, + media_video_codecs.clone(), + Arc::clone(&api.media_engine), + None, + ) + .await; + + assert_eq!(&tr.get_codecs().await, &media_video_codecs); + + let fail_test_cases = vec![ + vec![RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_string(), + clock_rate: 48000, + channels: 2, + sdp_fmtp_line: "minptime=10;useinbandfec=1".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }], + vec![ + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_string(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_string(), + clock_rate: 48000, + channels: 2, + sdp_fmtp_line: "minptime=10;useinbandfec=1".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 111, + ..Default::default() + }, + ], + ]; + + for test_case in fail_test_cases { + if let Err(err) = tr.set_codec_preferences(test_case).await { + assert_eq!(err, Error::ErrRTPTransceiverCodecUnsupported); + } else { + panic!(); + } + } + + let success_test_cases = vec![ + vec![RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_string(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }], + vec![ + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_string(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP9.to_string(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "profile-id=0".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 98, + ..Default::default() + }, + ], + ]; + + for test_case in success_test_cases { + tr.set_codec_preferences(test_case).await?; + } + + tr.set_codec_preferences(vec![]).await?; + assert_ne!(0, tr.get_codecs().await.len()); + + Ok(()) +} + +// Assert that SetCodecPreferences properly filters codecs and PayloadTypes are respected +#[tokio::test] +async fn test_rtp_transceiver_set_codec_preferences_payload_type() -> Result<()> { + let test_codec = RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: "video/test_codec".to_string(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 50, + ..Default::default() + }; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + let offer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + m.register_codec(test_codec.clone(), RTPCodecType::Video)?; + let api = APIBuilder::new().with_media_engine(m).build(); + let answer_pc = api.new_peer_connection(RTCConfiguration::default()).await?; + + let _ = offer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let answer_transceiver = answer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + answer_transceiver + .set_codec_preferences(vec![ + test_codec, + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_string(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 51, + ..Default::default() + }, + ]) + .await?; + + let offer = offer_pc.create_offer(None).await?; + + offer_pc.set_local_description(offer.clone()).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + + // VP8 with proper PayloadType + assert!( + answer.sdp.contains("a=rtpmap:51 VP8/90000"), + "{}", + answer.sdp + ); + + // test_codec is ignored since offerer doesn't support + assert!(!answer.sdp.contains("test_codec")); + + close_pair_now(&offer_pc, &answer_pc).await; + + Ok(()) +} + +#[tokio::test] +async fn test_rtp_transceiver_direction_change() -> Result<()> { + let (offer_pc, answer_pc, _) = create_vnet_pair().await?; + + let offer_transceiver = offer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let _ = answer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let offer = offer_pc.create_offer(None).await?; + + offer_pc.set_local_description(offer.clone()).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + assert!(answer.sdp.contains("a=sendrecv"),); + answer_pc.set_local_description(answer.clone()).await?; + offer_pc.set_remote_description(answer).await?; + + offer_transceiver + .set_direction(RTCRtpTransceiverDirection::Inactive) + .await; + + let offer = offer_pc.create_offer(None).await?; + assert!(offer.sdp.contains("a=inactive"),); + + offer_pc.set_local_description(offer.clone()).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + assert!(answer.sdp.contains("a=inactive"),); + offer_pc.set_remote_description(answer).await?; + + close_pair_now(&offer_pc, &answer_pc).await; + + Ok(()) +} + +#[tokio::test] +async fn test_rtp_transceiver_set_direction_causing_negotiation() -> Result<()> { + let (offer_pc, answer_pc, _) = create_vnet_pair().await?; + + let count = Arc::new(AtomicUsize::new(0)); + + { + let count = count.clone(); + offer_pc.on_negotiation_needed(Box::new(move || { + let count = count.clone(); + Box::pin(async move { + count.fetch_add(1, Ordering::SeqCst); + }) + })); + } + + let offer_transceiver = offer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let _ = answer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let offer = offer_pc.create_offer(None).await?; + offer_pc.set_local_description(offer.clone()).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + answer_pc.set_local_description(answer.clone()).await?; + offer_pc.set_remote_description(answer).await?; + + assert_eq!(count.load(Ordering::SeqCst), 0); + + let offer = offer_pc.create_offer(None).await?; + offer_pc.set_local_description(offer.clone()).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + answer_pc.set_local_description(answer.clone()).await?; + offer_pc.set_remote_description(answer).await?; + + assert_eq!(count.load(Ordering::SeqCst), 0); + + offer_transceiver + .set_direction(RTCRtpTransceiverDirection::Inactive) + .await; + + // wait for negotiation ops queue to finish. + offer_pc.internal.ops.done().await; + + assert_eq!(count.load(Ordering::SeqCst), 1); + + close_pair_now(&offer_pc, &answer_pc).await; + + Ok(()) +} + +#[ignore] +#[tokio::test] +async fn test_rtp_transceiver_stopping() -> Result<()> { + let (offer_pc, answer_pc, _) = create_vnet_pair().await?; + + let offer_transceiver = offer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let _ = answer_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let offer = offer_pc.create_offer(None).await?; + + offer_pc.set_local_description(offer.clone()).await?; + answer_pc.set_remote_description(offer).await?; + + let answer = answer_pc.create_answer(None).await?; + assert!(answer.sdp.contains("a=sendrecv"),); + answer_pc.set_local_description(answer.clone()).await?; + offer_pc.set_remote_description(answer).await?; + + assert!( + offer_transceiver.mid().is_some(), + "A mid should have been associated with the transceiver when applying the answer" + ); + // Stop the transceiver + offer_transceiver.stop().await?; + + let offer = offer_pc.create_offer(None).await?; + assert!(offer.sdp.contains("a=inactive"),); + let parsed = offer.parsed.unwrap(); + let m = &parsed.media_descriptions[0]; + assert_eq!( + m.media_name.port.value, 0, + "After stopping a transceiver it should be rejected in offers" + ); + + close_pair_now(&offer_pc, &answer_pc).await; + + Ok(()) +} diff --git a/rtc/src/rtp_transceiver/srtp_writer_future.rs b/rtc/src/rtp_transceiver/srtp_writer_future.rs new file mode 100644 index 0000000..be4ac64 --- /dev/null +++ b/rtc/src/rtp_transceiver/srtp_writer_future.rs @@ -0,0 +1,289 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Weak}; + +use async_trait::async_trait; +use bytes::Bytes; +use interceptor::{Attributes, RTCPReader, RTPWriter}; +use srtp::session::Session; +use srtp::stream::Stream; +use tokio::sync::Mutex; +use util; + +use crate::dtls_transport::RTCDtlsTransport; +use crate::error::{Error, Result}; +use crate::rtp_transceiver::rtp_sender::RTPSenderInternal; +use crate::rtp_transceiver::SSRC; + +/// `RTP` packet sequence number manager. +/// +/// Used to override outgoing `RTP` packets' sequence numbers. On creating it is +/// unabled and can be enabled before sending data beginning. Once data sending +/// began it can not be enabled any more. +pub(crate) struct SequenceTransformer(util::sync::Mutex); + +/// [`SequenceTransformer`] inner. +struct SequenceTransformerInner { + offset: u16, + last_sq: u16, + reset_needed: bool, + enabled: bool, + data_sent: bool, +} + +impl SequenceTransformer { + /// Creates a new [`SequenceTransformer`]. + pub(crate) fn new() -> Self { + Self(util::sync::Mutex::new(SequenceTransformerInner { + offset: 0, + last_sq: rand::random(), + reset_needed: false, + enabled: false, + data_sent: false, + })) + } + + /// Enables this [`SequenceTransformer`]. + /// + /// # Errors + /// + /// With [`Error::ErrRTPSenderSeqTransEnabled`] on trying to enable already + /// enabled [`SequenceTransformer`]. + /// + /// With [`Error::ErrRTPSenderSeqTransEnabled`] on trying to enable + /// [`SequenceTransformer`] after data sending began. + pub(crate) fn enable(&self) -> Result<()> { + let mut guard = self.0.lock(); + + if guard.enabled { + return Err(Error::ErrRTPSenderSeqTransEnabled); + } + + (!guard.data_sent) + .then(|| { + guard.enabled = true; + }) + .ok_or(Error::ErrRTPSenderDataSent) + } + + /// Indicates [`SequenceTransformer`] about necessity of recalculating + /// `offset`. + pub(crate) fn reset_offset(&self) { + self.0.lock().reset_needed = true; + } + + /// Gets [`Some`] consistent `sequence number` if this [`SequenceTransformer`] is + /// enabled or [`None`] if it is not. + /// + /// Once this method is called, considers data sending began. + fn seq_number(&self, raw_sn: u16) -> Option { + let mut guard = self.0.lock(); + guard.data_sent = true; + + if !guard.enabled { + return None; + } + + let offset = guard + .reset_needed + .then(|| { + guard.reset_needed = false; + let offset = guard.last_sq.overflowing_sub(raw_sn.overflowing_sub(1).0).0; + guard.offset = offset; + offset + }) + .unwrap_or(guard.offset); + let next = raw_sn.overflowing_add(offset).0; + guard.last_sq = next; + + Some(next) + } +} + +/// SrtpWriterFuture blocks Read/Write calls until +/// the SRTP Session is available +pub(crate) struct SrtpWriterFuture { + pub(crate) closed: AtomicBool, + pub(crate) ssrc: SSRC, + pub(crate) rtp_sender: Weak, + pub(crate) rtp_transport: Arc, + pub(crate) rtcp_read_stream: Mutex>>, // atomic.Value // * + pub(crate) rtp_write_session: Mutex>>, // atomic.Value // * + pub(crate) seq_trans: Arc, +} + +impl SrtpWriterFuture { + async fn init(&self, return_when_no_srtp: bool) -> Result<()> { + if return_when_no_srtp { + { + if let Some(rtp_sender) = self.rtp_sender.upgrade() { + if rtp_sender.stop_called_signal.load(Ordering::SeqCst) { + return Err(Error::ErrClosedPipe); + } + } else { + return Err(Error::ErrClosedPipe); + } + } + + if !self.rtp_transport.srtp_ready_signal.load(Ordering::SeqCst) { + return Ok(()); + } + } else { + let mut rx = self.rtp_transport.srtp_ready_rx.lock().await; + if let Some(srtp_ready_rx) = &mut *rx { + if let Some(rtp_sender) = self.rtp_sender.upgrade() { + tokio::select! { + _ = rtp_sender.stop_called_rx.notified()=> return Err(Error::ErrClosedPipe), + _ = srtp_ready_rx.recv() =>{} + } + } else { + return Err(Error::ErrClosedPipe); + } + } + } + + if self.closed.load(Ordering::SeqCst) { + return Err(Error::ErrClosedPipe); + } + + if let Some(srtcp_session) = self.rtp_transport.get_srtcp_session().await { + let rtcp_read_stream = srtcp_session.open(self.ssrc).await; + let mut stream = self.rtcp_read_stream.lock().await; + *stream = Some(rtcp_read_stream); + } + + { + let srtp_session = self.rtp_transport.get_srtp_session().await; + let mut session = self.rtp_write_session.lock().await; + *session = srtp_session; + } + + Ok(()) + } + + pub async fn close(&self) -> Result<()> { + if self.closed.load(Ordering::SeqCst) { + return Ok(()); + } + self.closed.store(true, Ordering::SeqCst); + + let stream = { + let mut stream = self.rtcp_read_stream.lock().await; + stream.take() + }; + if let Some(rtcp_read_stream) = stream { + Ok(rtcp_read_stream.close().await?) + } else { + Ok(()) + } + } + + pub async fn read(&self, b: &mut [u8]) -> Result { + { + let stream = { + let stream = self.rtcp_read_stream.lock().await; + stream.clone() + }; + if let Some(rtcp_read_stream) = stream { + return Ok(rtcp_read_stream.read(b).await?); + } + } + + self.init(false).await?; + + { + let stream = { + let stream = self.rtcp_read_stream.lock().await; + stream.clone() + }; + if let Some(rtcp_read_stream) = stream { + return Ok(rtcp_read_stream.read(b).await?); + } + } + + Ok(0) + } + + pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result { + { + let session = { + let session = self.rtp_write_session.lock().await; + session.clone() + }; + if let Some(rtp_write_session) = session { + return Ok(rtp_write_session.write_rtp(pkt).await?); + } + } + + self.init(true).await?; + + { + let session = { + let session = self.rtp_write_session.lock().await; + session.clone() + }; + if let Some(rtp_write_session) = session { + return Ok(rtp_write_session.write_rtp(pkt).await?); + } + } + + Ok(0) + } + + pub async fn write(&self, b: &Bytes) -> Result { + { + let session = { + let session = self.rtp_write_session.lock().await; + session.clone() + }; + if let Some(rtp_write_session) = session { + return Ok(rtp_write_session.write(b, true).await?); + } + } + + self.init(true).await?; + + { + let session = { + let session = self.rtp_write_session.lock().await; + session.clone() + }; + if let Some(rtp_write_session) = session { + return Ok(rtp_write_session.write(b, true).await?); + } + } + + Ok(0) + } +} + +type IResult = std::result::Result; + +#[async_trait] +impl RTCPReader for SrtpWriterFuture { + async fn read( + &self, + buf: &mut [u8], + a: &Attributes, + ) -> IResult<(Vec>, Attributes)> { + let read = self.read(buf).await?; + let pkt = rtcp::packet::unmarshal(&mut &buf[..read])?; + + Ok((pkt, a.clone())) + } +} + +#[async_trait] +impl RTPWriter for SrtpWriterFuture { + async fn write(&self, pkt: &rtp::packet::Packet, _a: &Attributes) -> IResult { + Ok( + match self.seq_trans.seq_number(pkt.header.sequence_number) { + Some(seq_num) => { + let mut new_pkt = pkt.clone(); + new_pkt.header.sequence_number = seq_num; + self.write_rtp(&new_pkt).await? + } + None => self.write_rtp(pkt).await?, + }, + ) + } +} diff --git a/rtc/src/sctp_transport/mod.rs b/rtc/src/sctp_transport/mod.rs new file mode 100644 index 0000000..8927070 --- /dev/null +++ b/rtc/src/sctp_transport/mod.rs @@ -0,0 +1,441 @@ +#[cfg(test)] +mod sctp_transport_test; + +pub mod sctp_transport_capabilities; +pub mod sctp_transport_state; + +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering}; +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use data::data_channel::DataChannel; +use data::message::message_channel_open::ChannelType; +use sctp::association::Association; +use sctp_transport_state::RTCSctpTransportState; +use tokio::sync::{Mutex, Notify}; +use util::Conn; + +use crate::api::setting_engine::SettingEngine; +use crate::data_channel::data_channel_parameters::DataChannelParameters; +use crate::data_channel::data_channel_state::RTCDataChannelState; +use crate::data_channel::RTCDataChannel; +use crate::dtls_transport::dtls_role::DTLSRole; +use crate::dtls_transport::*; +use crate::error::*; +use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; +use crate::stats::stats_collector::StatsCollector; +use crate::stats::StatsReportType::{PeerConnection, SCTPTransport}; +use crate::stats::{ICETransportStats, PeerConnectionStats}; + +const SCTP_MAX_CHANNELS: u16 = u16::MAX; + +pub type OnDataChannelHdlrFn = Box< + dyn (FnMut(Arc) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +pub type OnDataChannelOpenedHdlrFn = Box< + dyn (FnMut(Arc) -> Pin + Send + 'static>>) + + Send + + Sync, +>; + +struct AcceptDataChannelParams { + notify_rx: Arc, + sctp_association: Arc, + data_channels: Arc>>>, + on_error_handler: Arc>>, + on_data_channel_handler: Arc>>, + on_data_channel_opened_handler: Arc>>, + data_channels_opened: Arc, + data_channels_accepted: Arc, + setting_engine: Arc, +} + +/// SCTPTransport provides details about the SCTP transport. +#[derive(Default)] +pub struct RTCSctpTransport { + pub(crate) dtls_transport: Arc, + + // State represents the current state of the SCTP transport. + state: AtomicU8, // RTCSctpTransportState + + // SCTPTransportState doesn't have an enum to distinguish between New/Connecting + // so we need a dedicated field + is_started: AtomicBool, + + // max_message_size represents the maximum size of data that can be passed to + // DataChannel's send() method. + max_message_size: usize, + + // max_channels represents the maximum amount of DataChannel's that can + // be used simultaneously. + max_channels: u16, + + sctp_association: Mutex>>, + + on_error_handler: Arc>>, + on_data_channel_handler: Arc>>, + on_data_channel_opened_handler: Arc>>, + + // DataChannels + pub(crate) data_channels: Arc>>>, + pub(crate) data_channels_opened: Arc, + pub(crate) data_channels_requested: Arc, + data_channels_accepted: Arc, + + notify_tx: Arc, + + setting_engine: Arc, +} + +impl RTCSctpTransport { + pub(crate) fn new( + dtls_transport: Arc, + setting_engine: Arc, + ) -> Self { + RTCSctpTransport { + dtls_transport, + state: AtomicU8::new(RTCSctpTransportState::Connecting as u8), + is_started: AtomicBool::new(false), + max_message_size: RTCSctpTransport::calc_message_size(65536, 65536), + max_channels: SCTP_MAX_CHANNELS, + sctp_association: Mutex::new(None), + on_error_handler: Arc::new(ArcSwapOption::empty()), + on_data_channel_handler: Arc::new(ArcSwapOption::empty()), + on_data_channel_opened_handler: Arc::new(ArcSwapOption::empty()), + + data_channels: Arc::new(Mutex::new(vec![])), + data_channels_opened: Arc::new(AtomicU32::new(0)), + data_channels_requested: Arc::new(AtomicU32::new(0)), + data_channels_accepted: Arc::new(AtomicU32::new(0)), + + notify_tx: Arc::new(Notify::new()), + + setting_engine, + } + } + + /// transport returns the DTLSTransport instance the SCTPTransport is sending over. + pub fn transport(&self) -> Arc { + Arc::clone(&self.dtls_transport) + } + + /// get_capabilities returns the SCTPCapabilities of the SCTPTransport. + pub fn get_capabilities(&self) -> SCTPTransportCapabilities { + SCTPTransportCapabilities { + max_message_size: 0, + } + } + + /// Start the SCTPTransport. Since both local and remote parties must mutually + /// create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish + /// a connection over SCTP. + pub async fn start(&self, _remote_caps: SCTPTransportCapabilities) -> Result<()> { + if self.is_started.load(Ordering::SeqCst) { + return Ok(()); + } + self.is_started.store(true, Ordering::SeqCst); + + let dtls_transport = self.transport(); + if let Some(net_conn) = &dtls_transport.conn().await { + let sctp_association = loop { + tokio::select! { + _ = self.notify_tx.notified() => { + // It seems like notify_tx is only notified on Stop so perhaps this check + // is redundant. + // TODO: Consider renaming notify_tx to shutdown_tx. + if self.state.load(Ordering::SeqCst) == RTCSctpTransportState::Closed as u8 { + return Err(Error::ErrSCTPTransportDTLS); + } + }, + association = sctp::association::Association::client(sctp::association::Config { + net_conn: Arc::clone(net_conn) as Arc, + max_receive_buffer_size: 0, + max_message_size: 0, + name: String::new(), + }) => { + break Arc::new(association?); + } + }; + }; + + { + let mut sa = self.sctp_association.lock().await; + *sa = Some(Arc::clone(&sctp_association)); + } + self.state + .store(RTCSctpTransportState::Connected as u8, Ordering::SeqCst); + + let param = AcceptDataChannelParams { + notify_rx: self.notify_tx.clone(), + sctp_association, + data_channels: Arc::clone(&self.data_channels), + on_error_handler: Arc::clone(&self.on_error_handler), + on_data_channel_handler: Arc::clone(&self.on_data_channel_handler), + on_data_channel_opened_handler: Arc::clone(&self.on_data_channel_opened_handler), + data_channels_opened: Arc::clone(&self.data_channels_opened), + data_channels_accepted: Arc::clone(&self.data_channels_accepted), + setting_engine: Arc::clone(&self.setting_engine), + }; + tokio::spawn(async move { + RTCSctpTransport::accept_data_channels(param).await; + }); + + Ok(()) + } else { + Err(Error::ErrSCTPTransportDTLS) + } + } + + /// Stop stops the SCTPTransport + pub async fn stop(&self) -> Result<()> { + { + let mut sctp_association = self.sctp_association.lock().await; + if let Some(sa) = sctp_association.take() { + sa.close().await?; + } + } + + self.state + .store(RTCSctpTransportState::Closed as u8, Ordering::SeqCst); + + self.notify_tx.notify_waiters(); + + Ok(()) + } + + async fn accept_data_channels(param: AcceptDataChannelParams) { + let dcs = param.data_channels.lock().await; + let mut existing_data_channels = Vec::new(); + for dc in dcs.iter() { + if let Some(dc) = dc.data_channel.lock().await.clone() { + existing_data_channels.push(dc); + } + } + drop(dcs); + + loop { + let dc = tokio::select! { + _ = param.notify_rx.notified() => break, + result = DataChannel::accept( + ¶m.sctp_association, + data::data_channel::Config::default(), + &existing_data_channels, + ) => { + match result { + Ok(dc) => dc, + Err(err) => { + if data::Error::ErrStreamClosed == err { + log::error!("Failed to accept data channel: {}", err); + if let Some(handler) = &*param.on_error_handler.load() { + let mut f = handler.lock().await; + f(err.into()).await; + } + } + break; + } + } + } + }; + + let mut max_retransmits = 0; + let mut max_packet_lifetime = 0; + let val = dc.config.reliability_parameter as u16; + let ordered; + + match dc.config.channel_type { + ChannelType::Reliable => { + ordered = true; + } + ChannelType::ReliableUnordered => { + ordered = false; + } + ChannelType::PartialReliableRexmit => { + ordered = true; + max_retransmits = val; + } + ChannelType::PartialReliableRexmitUnordered => { + ordered = false; + max_retransmits = val; + } + ChannelType::PartialReliableTimed => { + ordered = true; + max_packet_lifetime = val; + } + ChannelType::PartialReliableTimedUnordered => { + ordered = false; + max_packet_lifetime = val; + } + }; + + let negotiated = if dc.config.negotiated { + Some(dc.stream_identifier()) + } else { + None + }; + let rtc_dc = Arc::new(RTCDataChannel::new( + DataChannelParameters { + label: dc.config.label.clone(), + protocol: dc.config.protocol.clone(), + negotiated, + ordered, + max_packet_life_time: max_packet_lifetime, + max_retransmits, + }, + Arc::clone(¶m.setting_engine), + )); + + if let Some(handler) = &*param.on_data_channel_handler.load() { + let mut f = handler.lock().await; + f(Arc::clone(&rtc_dc)).await; + + param.data_channels_accepted.fetch_add(1, Ordering::SeqCst); + + let mut dcs = param.data_channels.lock().await; + dcs.push(Arc::clone(&rtc_dc)); + } + + rtc_dc.handle_open(Arc::new(dc)).await; + + if let Some(handler) = &*param.on_data_channel_opened_handler.load() { + let mut f = handler.lock().await; + f(rtc_dc).await; + param.data_channels_opened.fetch_add(1, Ordering::SeqCst); + } + } + } + + /// on_error sets an event handler which is invoked when + /// the SCTP connection error occurs. + pub fn on_error(&self, f: OnErrorHdlrFn) { + self.on_error_handler.store(Some(Arc::new(Mutex::new(f)))); + } + + /// on_data_channel sets an event handler which is invoked when a data + /// channel message arrives from a remote peer. + pub fn on_data_channel(&self, f: OnDataChannelHdlrFn) { + self.on_data_channel_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + /// on_data_channel_opened sets an event handler which is invoked when a data + /// channel is opened + pub fn on_data_channel_opened(&self, f: OnDataChannelOpenedHdlrFn) { + self.on_data_channel_opened_handler + .store(Some(Arc::new(Mutex::new(f)))); + } + + fn calc_message_size(remote_max_message_size: usize, can_send_size: usize) -> usize { + if remote_max_message_size == 0 && can_send_size == 0 { + usize::MAX + } else if remote_max_message_size == 0 { + can_send_size + } else if can_send_size == 0 || can_send_size > remote_max_message_size { + remote_max_message_size + } else { + can_send_size + } + } + + /// max_channels is the maximum number of RTCDataChannels that can be open simultaneously. + pub fn max_channels(&self) -> u16 { + if self.max_channels == 0 { + SCTP_MAX_CHANNELS + } else { + self.max_channels + } + } + + /// state returns the current state of the SCTPTransport + pub fn state(&self) -> RTCSctpTransportState { + self.state.load(Ordering::SeqCst).into() + } + + pub(crate) async fn collect_stats( + &self, + collector: &StatsCollector, + peer_connection_id: String, + ) { + let dtls_transport = self.transport(); + + // TODO: should this be collected? + dtls_transport.collect_stats(collector).await; + + // data channels + let mut data_channels_closed = 0; + let data_channels = self.data_channels.lock().await; + for data_channel in &*data_channels { + match data_channel.ready_state() { + RTCDataChannelState::Connecting => (), + RTCDataChannelState::Open => (), + _ => data_channels_closed += 1, + } + data_channel.collect_stats(collector).await; + } + + let mut reports = HashMap::new(); + let peer_connection_stats = + PeerConnectionStats::new(self, peer_connection_id.clone(), data_channels_closed); + reports.insert(peer_connection_id, PeerConnection(peer_connection_stats)); + + // conn + if let Some(agent) = dtls_transport.ice_transport.gatherer.get_agent().await { + let stats = ICETransportStats::new("sctp_transport".to_owned(), agent); + reports.insert(stats.id.clone(), SCTPTransport(stats)); + } + + collector.merge(reports); + } + + pub(crate) async fn generate_and_set_data_channel_id( + &self, + dtls_role: DTLSRole, + ) -> Result { + let mut id = 0u16; + if dtls_role != DTLSRole::Client { + id += 1; + } + + // Create map of ids so we can compare without double-looping each time. + let mut ids_map = HashSet::new(); + { + let data_channels = self.data_channels.lock().await; + for dc in &*data_channels { + ids_map.insert(dc.id()); + } + } + + let max = self.max_channels(); + while id < max - 1 { + if ids_map.contains(&id) { + id += 2; + } else { + return Ok(id); + } + } + + Err(Error::ErrMaxDataChannelID) + } + + pub(crate) async fn association(&self) -> Option> { + let sctp_association = self.sctp_association.lock().await; + sctp_association.clone() + } + + pub(crate) fn data_channels_accepted(&self) -> u32 { + self.data_channels_accepted.load(Ordering::SeqCst) + } + + pub(crate) fn data_channels_opened(&self) -> u32 { + self.data_channels_opened.load(Ordering::SeqCst) + } + + pub(crate) fn data_channels_requested(&self) -> u32 { + self.data_channels_requested.load(Ordering::SeqCst) + } +} diff --git a/rtc/src/sctp_transport/sctp_transport_capabilities.rs b/rtc/src/sctp_transport/sctp_transport_capabilities.rs new file mode 100644 index 0000000..ee4b2a7 --- /dev/null +++ b/rtc/src/sctp_transport/sctp_transport_capabilities.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; + +/// SCTPTransportCapabilities indicates the capabilities of the SCTPTransport. +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct SCTPTransportCapabilities { + pub max_message_size: u32, +} diff --git a/rtc/src/sctp_transport/sctp_transport_state.rs b/rtc/src/sctp_transport/sctp_transport_state.rs new file mode 100644 index 0000000..310b814 --- /dev/null +++ b/rtc/src/sctp_transport/sctp_transport_state.rs @@ -0,0 +1,99 @@ +use std::fmt; + +/// SCTPTransportState indicates the state of the SCTP transport. +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum RTCSctpTransportState { + #[default] + Unspecified, + + /// SCTPTransportStateConnecting indicates the SCTPTransport is in the + /// process of negotiating an association. This is the initial state of the + /// SCTPTransportState when an SCTPTransport is created. + Connecting, + + /// SCTPTransportStateConnected indicates the negotiation of an + /// association is completed. + Connected, + + /// SCTPTransportStateClosed indicates a SHUTDOWN or ABORT chunk is + /// received or when the SCTP association has been closed intentionally, + /// such as by closing the peer connection or applying a remote description + /// that rejects data or changes the SCTP port. + Closed, +} + +const SCTP_TRANSPORT_STATE_CONNECTING_STR: &str = "connecting"; +const SCTP_TRANSPORT_STATE_CONNECTED_STR: &str = "connected"; +const SCTP_TRANSPORT_STATE_CLOSED_STR: &str = "closed"; + +impl From<&str> for RTCSctpTransportState { + fn from(raw: &str) -> Self { + match raw { + SCTP_TRANSPORT_STATE_CONNECTING_STR => RTCSctpTransportState::Connecting, + SCTP_TRANSPORT_STATE_CONNECTED_STR => RTCSctpTransportState::Connected, + SCTP_TRANSPORT_STATE_CLOSED_STR => RTCSctpTransportState::Closed, + _ => RTCSctpTransportState::Unspecified, + } + } +} + +impl From for RTCSctpTransportState { + fn from(v: u8) -> Self { + match v { + 1 => RTCSctpTransportState::Connecting, + 2 => RTCSctpTransportState::Connected, + 3 => RTCSctpTransportState::Closed, + _ => RTCSctpTransportState::Unspecified, + } + } +} + +impl fmt::Display for RTCSctpTransportState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + RTCSctpTransportState::Connecting => SCTP_TRANSPORT_STATE_CONNECTING_STR, + RTCSctpTransportState::Connected => SCTP_TRANSPORT_STATE_CONNECTED_STR, + RTCSctpTransportState::Closed => SCTP_TRANSPORT_STATE_CLOSED_STR, + RTCSctpTransportState::Unspecified => crate::UNSPECIFIED_STR, + }; + write!(f, "{s}") + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_new_sctp_transport_state() { + let tests = vec![ + (crate::UNSPECIFIED_STR, RTCSctpTransportState::Unspecified), + ("connecting", RTCSctpTransportState::Connecting), + ("connected", RTCSctpTransportState::Connected), + ("closed", RTCSctpTransportState::Closed), + ]; + + for (state_string, expected_state) in tests { + assert_eq!( + RTCSctpTransportState::from(state_string), + expected_state, + "testCase: {expected_state}", + ); + } + } + + #[test] + fn test_sctp_transport_state_string() { + let tests = vec![ + (RTCSctpTransportState::Unspecified, crate::UNSPECIFIED_STR), + (RTCSctpTransportState::Connecting, "connecting"), + (RTCSctpTransportState::Connected, "connected"), + (RTCSctpTransportState::Closed, "closed"), + ]; + + for (state, expected_string) in tests { + assert_eq!(state.to_string(), expected_string) + } + } +} diff --git a/rtc/src/sctp_transport/sctp_transport_test.rs b/rtc/src/sctp_transport/sctp_transport_test.rs new file mode 100644 index 0000000..e128f5d --- /dev/null +++ b/rtc/src/sctp_transport/sctp_transport_test.rs @@ -0,0 +1,43 @@ +use std::sync::atomic::AtomicU16; + +use super::*; + +#[tokio::test] +async fn test_generate_data_channel_id() -> Result<()> { + let sctp_transport_with_channels = |ids: &[u16]| -> RTCSctpTransport { + let mut data_channels = vec![]; + for id in ids { + data_channels.push(Arc::new(RTCDataChannel { + id: AtomicU16::new(*id), + ..Default::default() + })); + } + + RTCSctpTransport { + data_channels: Arc::new(Mutex::new(data_channels)), + ..Default::default() + } + }; + + let tests = vec![ + (DTLSRole::Client, sctp_transport_with_channels(&[]), 0), + (DTLSRole::Client, sctp_transport_with_channels(&[1]), 0), + (DTLSRole::Client, sctp_transport_with_channels(&[0]), 2), + (DTLSRole::Client, sctp_transport_with_channels(&[0, 2]), 4), + (DTLSRole::Client, sctp_transport_with_channels(&[0, 4]), 2), + (DTLSRole::Server, sctp_transport_with_channels(&[]), 1), + (DTLSRole::Server, sctp_transport_with_channels(&[0]), 1), + (DTLSRole::Server, sctp_transport_with_channels(&[1]), 3), + (DTLSRole::Server, sctp_transport_with_channels(&[1, 3]), 5), + (DTLSRole::Server, sctp_transport_with_channels(&[1, 5]), 3), + ]; + + for (role, s, expected) in tests { + match s.generate_and_set_data_channel_id(role).await { + Ok(actual) => assert_eq!(actual, expected), + Err(err) => panic!("failed to generate id: {err}"), + }; + } + + Ok(()) +} diff --git a/rtc/src/stats/mod.rs b/rtc/src/stats/mod.rs new file mode 100644 index 0000000..75da730 --- /dev/null +++ b/rtc/src/stats/mod.rs @@ -0,0 +1,586 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::SystemTime; + +use ice::agent::agent_stats::{CandidatePairStats, CandidateStats}; +use ice::agent::Agent; +use ice::candidate::{CandidatePairState, CandidateType}; +use ice::network_type::NetworkType; +use serde::{Serialize, Serializer}; +use smol_str::SmolStr; +use stats_collector::StatsCollector; +use tokio::time::Instant; + +use crate::data_channel::data_channel_state::RTCDataChannelState; +use crate::data_channel::RTCDataChannel; +use crate::dtls_transport::dtls_fingerprint::RTCDtlsFingerprint; +use crate::peer_connection::certificate::RTCCertificate; +use crate::rtp_transceiver::rtp_codec::RTCRtpCodecParameters; +use crate::rtp_transceiver::{PayloadType, SSRC}; +use crate::sctp_transport::RTCSctpTransport; + +mod serialize; +pub mod stats_collector; + +#[derive(Debug, Serialize)] +pub enum RTCStatsType { + #[serde(rename = "candidate-pair")] + CandidatePair, + #[serde(rename = "certificate")] + Certificate, + #[serde(rename = "codec")] + Codec, + #[serde(rename = "csrc")] + CSRC, + #[serde(rename = "data-channel")] + DataChannel, + #[serde(rename = "inbound-rtp")] + InboundRTP, + #[serde(rename = "local-candidate")] + LocalCandidate, + #[serde(rename = "outbound-rtp")] + OutboundRTP, + #[serde(rename = "peer-connection")] + PeerConnection, + #[serde(rename = "receiver")] + Receiver, + #[serde(rename = "remote-candidate")] + RemoteCandidate, + #[serde(rename = "remote-inbound-rtp")] + RemoteInboundRTP, + #[serde(rename = "remote-outbound-rtp")] + RemoteOutboundRTP, + #[serde(rename = "sender")] + Sender, + #[serde(rename = "transport")] + Transport, +} + +pub enum SourceStatsType { + LocalCandidate(CandidateStats), + RemoteCandidate(CandidateStats), +} + +#[derive(Debug)] +pub enum StatsReportType { + CandidatePair(ICECandidatePairStats), + CertificateStats(CertificateStats), + Codec(CodecStats), + DataChannel(DataChannelStats), + LocalCandidate(ICECandidateStats), + PeerConnection(PeerConnectionStats), + RemoteCandidate(ICECandidateStats), + SCTPTransport(ICETransportStats), + Transport(ICETransportStats), + InboundRTP(InboundRTPStats), + OutboundRTP(OutboundRTPStats), + RemoteInboundRTP(RemoteInboundRTPStats), + RemoteOutboundRTP(RemoteOutboundRTPStats), +} + +impl From for StatsReportType { + fn from(stats: SourceStatsType) -> Self { + match stats { + SourceStatsType::LocalCandidate(stats) => StatsReportType::LocalCandidate( + ICECandidateStats::new(stats, RTCStatsType::LocalCandidate), + ), + SourceStatsType::RemoteCandidate(stats) => StatsReportType::RemoteCandidate( + ICECandidateStats::new(stats, RTCStatsType::RemoteCandidate), + ), + } + } +} + +impl Serialize for StatsReportType { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + StatsReportType::CandidatePair(stats) => stats.serialize(serializer), + StatsReportType::CertificateStats(stats) => stats.serialize(serializer), + StatsReportType::Codec(stats) => stats.serialize(serializer), + StatsReportType::DataChannel(stats) => stats.serialize(serializer), + StatsReportType::LocalCandidate(stats) => stats.serialize(serializer), + StatsReportType::PeerConnection(stats) => stats.serialize(serializer), + StatsReportType::RemoteCandidate(stats) => stats.serialize(serializer), + StatsReportType::SCTPTransport(stats) => stats.serialize(serializer), + StatsReportType::Transport(stats) => stats.serialize(serializer), + StatsReportType::InboundRTP(stats) => stats.serialize(serializer), + StatsReportType::OutboundRTP(stats) => stats.serialize(serializer), + StatsReportType::RemoteInboundRTP(stats) => stats.serialize(serializer), + StatsReportType::RemoteOutboundRTP(stats) => stats.serialize(serializer), + } + } +} + +#[derive(Debug)] +pub struct StatsReport { + pub reports: HashMap, +} + +impl From for StatsReport { + fn from(collector: StatsCollector) -> Self { + StatsReport { + reports: collector.into_reports(), + } + } +} + +impl Serialize for StatsReport { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.reports.serialize(serializer) + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ICECandidatePairStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCIceCandidatePairStats + // TODO: Add `transportId` + pub local_candidate_id: String, + pub remote_candidate_id: String, + pub state: CandidatePairState, + pub nominated: bool, + pub packets_sent: u32, + pub packets_received: u32, + pub bytes_sent: u64, + pub bytes_received: u64, + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub last_packet_sent_timestamp: Instant, + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub last_packet_received_timestamp: Instant, + pub total_round_trip_time: f64, + pub current_round_trip_time: f64, + pub available_outgoing_bitrate: f64, + pub available_incoming_bitrate: f64, + pub requests_received: u64, + pub requests_sent: u64, + pub responses_received: u64, + pub responses_sent: u64, + pub consent_requests_sent: u64, + // TODO: Add `packetsDiscardedOnSend` + // TODO: Add `bytesDiscardedOnSend` + + // Non-canon + pub circuit_breaker_trigger_count: u32, + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub consent_expired_timestamp: Instant, + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub first_request_timestamp: Instant, + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub last_request_timestamp: Instant, + pub retransmissions_sent: u64, +} + +impl From for ICECandidatePairStats { + fn from(stats: CandidatePairStats) -> Self { + ICECandidatePairStats { + available_incoming_bitrate: stats.available_incoming_bitrate, + available_outgoing_bitrate: stats.available_outgoing_bitrate, + bytes_received: stats.bytes_received, + bytes_sent: stats.bytes_sent, + circuit_breaker_trigger_count: stats.circuit_breaker_trigger_count, + consent_expired_timestamp: stats.consent_expired_timestamp, + consent_requests_sent: stats.consent_requests_sent, + current_round_trip_time: stats.current_round_trip_time, + first_request_timestamp: stats.first_request_timestamp, + id: format!("{}-{}", stats.local_candidate_id, stats.remote_candidate_id), + last_packet_received_timestamp: stats.last_packet_received_timestamp, + last_packet_sent_timestamp: stats.last_packet_sent_timestamp, + last_request_timestamp: stats.last_request_timestamp, + local_candidate_id: stats.local_candidate_id, + nominated: stats.nominated, + packets_received: stats.packets_received, + packets_sent: stats.packets_sent, + remote_candidate_id: stats.remote_candidate_id, + requests_received: stats.requests_received, + requests_sent: stats.requests_sent, + responses_received: stats.responses_received, + responses_sent: stats.responses_sent, + retransmissions_sent: stats.retransmissions_sent, + state: stats.state, + stats_type: RTCStatsType::CandidatePair, + timestamp: stats.timestamp, + total_round_trip_time: stats.total_round_trip_time, + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ICECandidateStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCIceCandidateStats + pub candidate_type: CandidateType, + pub deleted: bool, + pub ip: String, + pub network_type: NetworkType, + pub port: u16, + pub priority: u32, + pub relay_protocol: String, + pub url: String, +} + +impl ICECandidateStats { + fn new(stats: CandidateStats, stats_type: RTCStatsType) -> Self { + ICECandidateStats { + candidate_type: stats.candidate_type, + deleted: stats.deleted, + id: stats.id, + ip: stats.ip, + network_type: stats.network_type, + port: stats.port, + priority: stats.priority, + relay_protocol: stats.relay_protocol, + stats_type, + timestamp: stats.timestamp, + url: stats.url, + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ICETransportStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // Non-canon + pub bytes_received: usize, + pub bytes_sent: usize, +} + +impl ICETransportStats { + pub(crate) fn new(id: String, agent: Arc) -> Self { + ICETransportStats { + id, + bytes_received: agent.get_bytes_received(), + bytes_sent: agent.get_bytes_sent(), + stats_type: RTCStatsType::Transport, + timestamp: Instant::now(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CertificateStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCCertificateStats + pub fingerprint: String, + pub fingerprint_algorithm: String, + // TODO: Add `base64Certificate` and `issuerCertificateId`. +} + +impl CertificateStats { + pub(crate) fn new(cert: &RTCCertificate, fingerprint: RTCDtlsFingerprint) -> Self { + CertificateStats { + // TODO: base64_certificate + fingerprint: fingerprint.value, + fingerprint_algorithm: fingerprint.algorithm, + id: cert.stats_id.clone(), + // TODO: issuer_certificate_id + stats_type: RTCStatsType::Certificate, + timestamp: Instant::now(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CodecStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCCodecStats + pub payload_type: PayloadType, + pub mime_type: String, + pub channels: u16, + pub clock_rate: u32, + pub sdp_fmtp_line: String, + // TODO: Add `transportId` +} + +impl From<&RTCRtpCodecParameters> for CodecStats { + fn from(codec: &RTCRtpCodecParameters) -> Self { + CodecStats { + channels: codec.capability.channels, + clock_rate: codec.capability.clock_rate, + id: codec.stats_id.clone(), + mime_type: codec.capability.mime_type.clone(), + payload_type: codec.payload_type, + sdp_fmtp_line: codec.capability.sdp_fmtp_line.clone(), + stats_type: RTCStatsType::Codec, + timestamp: Instant::now(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct DataChannelStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCDataChannelStats + pub bytes_received: usize, + pub bytes_sent: usize, + pub data_channel_identifier: u16, + pub label: String, + pub messages_received: usize, + pub messages_sent: usize, + pub protocol: String, + pub state: RTCDataChannelState, +} + +impl DataChannelStats { + pub(crate) async fn from(data_channel: &RTCDataChannel) -> Self { + let state = data_channel.ready_state(); + + let mut bytes_received = 0; + let mut bytes_sent = 0; + let mut messages_received = 0; + let mut messages_sent = 0; + + let lock = data_channel.data_channel.lock().await; + + if let Some(internal) = &*lock { + bytes_received = internal.bytes_received(); + bytes_sent = internal.bytes_sent(); + messages_received = internal.messages_received(); + messages_sent = internal.messages_sent(); + } + + Self { + bytes_received, + bytes_sent, + data_channel_identifier: data_channel.id(), // TODO: "The value is initially null" + id: data_channel.stats_id.clone(), + label: data_channel.label.clone(), + messages_received, + messages_sent, + protocol: data_channel.protocol.clone(), + state, + stats_type: RTCStatsType::DataChannel, + timestamp: Instant::now(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PeerConnectionStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCPeerConnectionStats + pub data_channels_closed: u32, + pub data_channels_opened: u32, + + // Non-canon + pub data_channels_accepted: u32, + pub data_channels_requested: u32, +} + +impl PeerConnectionStats { + pub fn new(transport: &RTCSctpTransport, stats_id: String, data_channels_closed: u32) -> Self { + PeerConnectionStats { + data_channels_accepted: transport.data_channels_accepted(), + data_channels_closed, + data_channels_opened: transport.data_channels_opened(), + data_channels_requested: transport.data_channels_requested(), + id: stats_id, + stats_type: RTCStatsType::PeerConnection, + timestamp: Instant::now(), + } + } +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct InboundRTPStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCRtpStreamStats + pub ssrc: SSRC, + pub kind: &'static str, // Either "video" or "audio" + // TODO: Add transportId + // TODO: Add codecId + + // RTCReceivedRtpStreamStats + pub packets_received: u64, + // TODO: packetsLost + // TODO: jitter(maybe, might be uattainable for the same reason as `framesDropped`) + // NB: `framesDropped` can't be produced since we aren't decoding, might be worth introducing a + // way for consumers to control this in the future. + + // RTCInboundRtpStreamStats + pub track_identifier: String, + pub mid: SmolStr, + // TODO: `remoteId` + // NB: `framesDecoded`, `frameWidth`, frameHeight`, `framesPerSecond`, `qpSum`, + // `totalDecodeTime`, `totalInterFrameDelay`, and `totalSquaredInterFrameDelay` are all decoder + // specific values and can't be produced since we aren't decoding. + pub last_packet_received_timestamp: Option, + pub header_bytes_received: u64, + // TODO: `packetsDiscarded`. This value only makes sense if we have jitter buffer, which we + // cannot assume. + // TODO: `fecPacketsReceived`, `fecPacketsDiscarded` + pub bytes_received: u64, + pub nack_count: u64, + pub fir_count: Option, + pub pli_count: Option, + // NB: `totalProcessingDelay`, `estimatedPlayoutTimestamp`, `jitterBufferDelay`, + // `jitterBufferTargetDelay`, `jitterBufferEmittedCount`, `jitterBufferMinimumDelay`, + // `totalSamplesReceived`, `concealedSamples`, `silentConcealedSamples`, `concealmentEvents`, + // `insertedSamplesForDeceleration`, `removedSamplesForAcceleration`, `audioLevel`, + // `totalAudioEneregy`, `totalSampleDuration`, `framesReceived, and `decoderImplementation` are + // all decoder specific and can't be produced since we aren't decoding. +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct OutboundRTPStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCRtpStreamStats + pub ssrc: SSRC, + pub kind: &'static str, // Either "video" or "audio" + // TODO: Add transportId + // TODO: Add codecId + + // RTCSentRtpStreamStats + pub packets_sent: u64, + pub bytes_sent: u64, + + // RTCOutboundRtpStreamStats + // NB: non-canon in browsers this is available via `RTCMediaSourceStats` which we are unlikely to implement + pub track_identifier: String, + pub mid: SmolStr, + // TODO: `mediaSourceId` and `remoteId` + pub rid: Option, + pub header_bytes_sent: u64, + // TODO: `retransmittedPacketsSent` and `retransmittedPacketsSent` + // NB: `targetBitrate`, `totalEncodedBytesTarget`, `frameWidth` `frameHeight`, `framesPerSecond`, `framesSent`, + // `hugeFramesSent`, `framesEncoded`, `keyFramesEncoded`, `qpSum`, and `totalEncodeTime` are + // all encoder specific and can't be produced snce we aren't encoding. + // TODO: `totalPacketSendDelay` time from `TrackLocalWriter::write_rtp` to being written to + // socket. + + // NB: `qualityLimitationReason`, `qualityLimitationDurations`, and `qualityLimitationResolutionChanges` are all + // encoder specific and can't be produced since we aren't encoding. + pub nack_count: u64, + pub fir_count: Option, + pub pli_count: Option, + // NB: `encoderImplementation` is encoder specific and can't be produced since we aren't + // encoding. +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RemoteInboundRTPStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCRtpStreamStats + pub ssrc: SSRC, + pub kind: &'static str, // Either "video" or "audio" + // TODO: Add transportId + // TODO: Add codecId + + // RTCReceivedRtpStreamStats + pub packets_received: u64, + pub packets_lost: i64, + // TODO: jitter(maybe, might be uattainable for the same reason as `framesDropped`) + // NB: `framesDropped` can't be produced since we aren't decoding, might be worth introducing a + // way for consumers to control this in the future. + + // RTCRemoteInboundRtpStreamStats + pub local_id: String, + pub round_trip_time: Option, + pub total_round_trip_time: f64, + pub fraction_lost: f64, + pub round_trip_time_measurements: u64, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct RemoteOutboundRTPStats { + // RTCStats + #[serde(with = "serialize::instant_to_epoch_seconds")] + pub timestamp: Instant, + #[serde(rename = "type")] + pub stats_type: RTCStatsType, + pub id: String, + + // RTCRtpStreamStats + pub ssrc: SSRC, + pub kind: &'static str, // Either "video" or "audio" + // TODO: Add transportId + // TODO: Add codecId + + // RTCSentRtpStreamStats + pub packets_sent: u64, + pub bytes_sent: u64, + + // RTCRemoteOutboundRtpStreamStats + pub local_id: String, + // TODO: `remote_timestamp` + pub round_trip_time: Option, + pub reports_sent: u64, + pub total_round_trip_time: f64, + pub round_trip_time_measurements: u64, +} diff --git a/rtc/src/stats/serialize.rs b/rtc/src/stats/serialize.rs new file mode 100644 index 0000000..2db43c4 --- /dev/null +++ b/rtc/src/stats/serialize.rs @@ -0,0 +1,28 @@ +pub mod instant_to_epoch_seconds { + // Serializes a `tokio::time::Instant` to an approximation of epoch time in the form + // of an `f64` where the integer portion is seconds and the decimal portion is milliseconds. + // For instance, `Monday, May 30, 2022 10:45:26.456 PM UTC` converts to `1653950726.456`. + // + // Note that an `Instant` is not connected to real world time, so this conversion is + // approximate. + use std::time::{SystemTime, UNIX_EPOCH}; + + use serde::{Serialize, Serializer}; + use tokio::time::Instant; + + pub fn serialize(instant: &Instant, serializer: S) -> Result + where + S: Serializer, + { + let system_now = SystemTime::now(); + let instant_now = Instant::now(); + let approx = system_now - (instant_now - *instant); + let epoch = approx + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + + let epoch_ms = epoch.as_millis() as f64 / 1000.0; + + epoch_ms.serialize(serializer) + } +} diff --git a/rtc/src/stats/stats_collector.rs b/rtc/src/stats/stats_collector.rs new file mode 100644 index 0000000..e0228df --- /dev/null +++ b/rtc/src/stats/stats_collector.rs @@ -0,0 +1,32 @@ +use std::collections::HashMap; + +use util::sync::Mutex; + +use super::StatsReportType; + +#[derive(Debug, Default)] +pub struct StatsCollector { + pub(crate) reports: Mutex>, +} + +impl StatsCollector { + pub(crate) fn new() -> Self { + StatsCollector { + ..Default::default() + } + } + + pub(crate) fn insert(&self, id: String, stats: StatsReportType) { + let mut reports = self.reports.lock(); + reports.insert(id, stats); + } + + pub(crate) fn merge(&self, stats: HashMap) { + let mut reports = self.reports.lock(); + reports.extend(stats) + } + + pub(crate) fn into_reports(self) -> HashMap { + self.reports.into_inner() + } +} diff --git a/rtc/src/track/mod.rs b/rtc/src/track/mod.rs new file mode 100644 index 0000000..8c0a2d5 --- /dev/null +++ b/rtc/src/track/mod.rs @@ -0,0 +1,29 @@ +pub mod track_local; +pub mod track_remote; + +use std::sync::Arc; + +use interceptor::stream_info::StreamInfo; +use interceptor::{RTCPReader, RTPReader}; +use track_remote::*; + +pub(crate) const RTP_OUTBOUND_MTU: usize = 1200; +pub(crate) const RTP_PAYLOAD_TYPE_BITMASK: u8 = 0x7F; + +#[derive(Clone)] +pub(crate) struct TrackStream { + pub(crate) stream_info: Option, + pub(crate) rtp_read_stream: Option>, + pub(crate) rtp_interceptor: Option>, + pub(crate) rtcp_read_stream: Option>, + pub(crate) rtcp_interceptor: Option>, +} + +/// TrackStreams maintains a mapping of RTP/RTCP streams to a specific track +/// a RTPReceiver may contain multiple streams if we are dealing with Simulcast +#[derive(Clone)] +pub(crate) struct TrackStreams { + pub(crate) track: Arc, + pub(crate) stream: TrackStream, + pub(crate) repair_stream: TrackStream, +} diff --git a/rtc/src/track/track_local/mod.rs b/rtc/src/track/track_local/mod.rs new file mode 100644 index 0000000..fb70ab6 --- /dev/null +++ b/rtc/src/track/track_local/mod.rs @@ -0,0 +1,163 @@ +#[cfg(test)] +mod track_local_static_test; + +pub mod track_local_static_rtp; +pub mod track_local_static_sample; + +use std::any::Any; +use std::fmt; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use async_trait::async_trait; +use interceptor::{Attributes, RTPWriter}; +use tokio::sync::Mutex; +use util::Unmarshal; + +use crate::error::{Error, Result}; +use crate::rtp_transceiver::rtp_codec::*; +use crate::rtp_transceiver::*; + +/// TrackLocalWriter is the Writer for outbound RTP Packets +#[async_trait] +pub trait TrackLocalWriter: fmt::Debug { + /// write_rtp encrypts a RTP packet and writes to the connection + async fn write_rtp(&self, p: &rtp::packet::Packet) -> Result; + + /// write encrypts and writes a full RTP packet + async fn write(&self, b: &[u8]) -> Result; +} + +/// TrackLocalContext is the Context passed when a TrackLocal has been Binded/Unbinded from a PeerConnection, and used +/// in Interceptors. +#[derive(Default, Debug, Clone)] +pub struct TrackLocalContext { + pub(crate) id: String, + pub(crate) params: RTCRtpParameters, + pub(crate) ssrc: SSRC, + pub(crate) write_stream: Option>, + pub(crate) paused: Arc, +} + +impl TrackLocalContext { + /// codec_parameters returns the negotiated RTPCodecParameters. These are the codecs supported by both + /// PeerConnections and the SSRC/PayloadTypes + pub fn codec_parameters(&self) -> &[RTCRtpCodecParameters] { + &self.params.codecs + } + + /// header_extensions returns the negotiated RTPHeaderExtensionParameters. These are the header extensions supported by + /// both PeerConnections and the SSRC/PayloadTypes + pub fn header_extensions(&self) -> &[RTCRtpHeaderExtensionParameters] { + &self.params.header_extensions + } + + /// ssrc requires the negotiated SSRC of this track + /// This track may have multiple if RTX is enabled + pub fn ssrc(&self) -> SSRC { + self.ssrc + } + + /// write_stream returns the write_stream for this TrackLocal. The implementer writes the outbound + /// media packets to it + pub fn write_stream(&self) -> Option> { + self.write_stream.clone() + } + + /// id is a unique identifier that is used for both bind/unbind + pub fn id(&self) -> String { + self.id.clone() + } +} +/// TrackLocal is an interface that controls how the user can send media +/// The user can provide their own TrackLocal implementations, or use +/// the implementations in pkg/media +#[async_trait] +pub trait TrackLocal { + /// bind should implement the way how the media data flows from the Track to the PeerConnection + /// This will be called internally after signaling is complete and the list of available + /// codecs has been determined + async fn bind(&self, t: &TrackLocalContext) -> Result; + + /// unbind should implement the teardown logic when the track is no longer needed. This happens + /// because a track has been stopped. + async fn unbind(&self, t: &TrackLocalContext) -> Result<()>; + + /// id is the unique identifier for this Track. This should be unique for the + /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' + /// and stream_id would be 'desktop' or 'webcam' + fn id(&self) -> &str; + + /// stream_id is the group this track belongs too. This must be unique + fn stream_id(&self) -> &str; + + /// kind controls if this TrackLocal is audio or video + fn kind(&self) -> RTPCodecType; + + fn as_any(&self) -> &dyn Any; +} + +/// TrackBinding is a single bind for a Track +/// Bind can be called multiple times, this stores the +/// result for a single bind call so that it can be used when writing +#[derive(Default, Debug)] +pub(crate) struct TrackBinding { + id: String, + ssrc: SSRC, + payload_type: PayloadType, + params: RTCRtpParameters, + write_stream: Option>, + sender_paused: Arc, +} + +impl TrackBinding { + pub fn is_sender_paused(&self) -> bool { + self.sender_paused.load(Ordering::SeqCst) + } +} + +pub(crate) struct InterceptorToTrackLocalWriter { + pub(crate) interceptor_rtp_writer: Mutex>>, + sender_paused: Arc, +} + +impl InterceptorToTrackLocalWriter { + pub(crate) fn new(paused: Arc) -> Self { + InterceptorToTrackLocalWriter { + interceptor_rtp_writer: Mutex::new(None), + sender_paused: paused, + } + } + + fn is_sender_paused(&self) -> bool { + self.sender_paused.load(Ordering::SeqCst) + } +} + +impl std::fmt::Debug for InterceptorToTrackLocalWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InterceptorToTrackLocalWriter").finish() + } +} + +#[async_trait] +impl TrackLocalWriter for InterceptorToTrackLocalWriter { + async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result { + if self.is_sender_paused() { + return Ok(0); + } + + let interceptor_rtp_writer = self.interceptor_rtp_writer.lock().await; + if let Some(writer) = &*interceptor_rtp_writer { + let a = Attributes::new(); + Ok(writer.write(pkt, &a).await?) + } else { + Ok(0) + } + } + + async fn write(&self, mut b: &[u8]) -> Result { + let pkt = rtp::packet::Packet::unmarshal(&mut b)?; + self.write_rtp(&pkt).await + } +} diff --git a/rtc/src/track/track_local/track_local_static_rtp.rs b/rtc/src/track/track_local/track_local_static_rtp.rs new file mode 100644 index 0000000..5e951c2 --- /dev/null +++ b/rtc/src/track/track_local/track_local_static_rtp.rs @@ -0,0 +1,238 @@ +use std::collections::HashMap; + +use bytes::BytesMut; +use tokio::sync::Mutex; +use util::{Marshal, MarshalSize}; + +use super::*; +use crate::error::flatten_errs; + +/// TrackLocalStaticRTP is a TrackLocal that has a pre-set codec and accepts RTP Packets. +/// If you wish to send a media.Sample use TrackLocalStaticSample +#[derive(Debug)] +pub struct TrackLocalStaticRTP { + pub(crate) bindings: Mutex>>, + codec: RTCRtpCodecCapability, + id: String, + stream_id: String, +} + +impl TrackLocalStaticRTP { + /// returns a TrackLocalStaticRTP. + pub fn new(codec: RTCRtpCodecCapability, id: String, stream_id: String) -> Self { + TrackLocalStaticRTP { + codec, + bindings: Mutex::new(vec![]), + id, + stream_id, + } + } + + /// codec gets the Codec of the track + pub fn codec(&self) -> RTCRtpCodecCapability { + self.codec.clone() + } + + pub async fn any_binding_paused(&self) -> bool { + let bindings = self.bindings.lock().await; + bindings + .iter() + .any(|b| b.sender_paused.load(Ordering::SeqCst)) + } + + pub async fn all_binding_paused(&self) -> bool { + let bindings = self.bindings.lock().await; + bindings + .iter() + .all(|b| b.sender_paused.load(Ordering::SeqCst)) + } + + /// write_rtp_with_extensions writes a RTP Packet to the TrackLocalStaticRTP + /// If one PeerConnection fails the packets will still be sent to + /// all PeerConnections. The error message will contain the ID of the failed + /// PeerConnections so you can remove them + /// + /// If the RTCRtpSender direction is such that no packets should be sent, any call to this + /// function are blocked internally. Care must be taken to not increase the sequence number + /// while the sender is paused. While the actual _sending_ is blocked, the receiver will + /// miss out when the sequence number "rolls over", which in turn will break SRTP. + /// + /// Extensions that are already configured on the packet are overwritten by extensions in + /// `extensions`. + pub async fn write_rtp_with_extensions( + &self, + p: &rtp::packet::Packet, + extensions: &[rtp::extension::HeaderExtension], + ) -> Result { + let mut n = 0; + let mut write_errs = vec![]; + let mut pkt = p.clone(); + + let bindings = { + let bindings = self.bindings.lock().await; + bindings.clone() + }; + // Prepare the extensions data + let extension_data: HashMap<_, _> = extensions + .iter() + .flat_map(|extension| { + let buf = { + let mut buf = BytesMut::with_capacity(extension.marshal_size()); + buf.resize(extension.marshal_size(), 0); + if let Err(err) = extension.marshal_to(&mut buf) { + write_errs.push(Error::Util(err)); + return None; + } + + buf.freeze() + }; + + Some((extension.uri(), buf)) + }) + .collect(); + + for b in bindings.into_iter() { + if b.is_sender_paused() { + // See caveat in function doc. + continue; + } + pkt.header.ssrc = b.ssrc; + pkt.header.payload_type = b.payload_type; + + for (uri, data) in extension_data.iter() { + if let Some(id) = b + .params + .header_extensions + .iter() + .find(|ext| &ext.uri == uri) + .map(|ext| ext.id) + { + if let Err(err) = pkt.header.set_extension(id as u8, data.clone()) { + write_errs.push(Error::Rtp(err)); + continue; + } + } + } + + if let Some(write_stream) = &b.write_stream { + match write_stream.write_rtp(&pkt).await { + Ok(m) => { + n += m; + } + Err(err) => { + write_errs.push(err); + } + } + } else { + write_errs.push(Error::new("track binding has none write_stream".to_owned())); + } + } + + flatten_errs(write_errs)?; + Ok(n) + } +} + +#[async_trait] +impl TrackLocal for TrackLocalStaticRTP { + /// bind is called by the PeerConnection after negotiation is complete + /// This asserts that the code requested is supported by the remote peer. + /// If so it setups all the state (SSRC and PayloadType) to have a call + async fn bind(&self, t: &TrackLocalContext) -> Result { + let parameters = RTCRtpCodecParameters { + capability: self.codec.clone(), + ..Default::default() + }; + + let (codec, match_type) = codec_parameters_fuzzy_search(¶meters, t.codec_parameters()); + if match_type != CodecMatch::None { + { + let mut bindings = self.bindings.lock().await; + bindings.push(Arc::new(TrackBinding { + ssrc: t.ssrc(), + payload_type: codec.payload_type, + write_stream: t.write_stream(), + params: t.params.clone(), + id: t.id(), + sender_paused: t.paused.clone(), + })); + } + + Ok(codec) + } else { + Err(Error::ErrUnsupportedCodec) + } + } + + /// unbind implements the teardown logic when the track is no longer needed. This happens + /// because a track has been stopped. + async fn unbind(&self, t: &TrackLocalContext) -> Result<()> { + let mut bindings = self.bindings.lock().await; + let mut idx = None; + for (index, binding) in bindings.iter().enumerate() { + if binding.id == t.id() { + idx = Some(index); + break; + } + } + if let Some(index) = idx { + bindings.remove(index); + Ok(()) + } else { + Err(Error::ErrUnbindFailed) + } + } + + /// id is the unique identifier for this Track. This should be unique for the + /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' + /// and StreamID would be 'desktop' or 'webcam' + fn id(&self) -> &str { + self.id.as_str() + } + + /// stream_id is the group this track belongs too. This must be unique + fn stream_id(&self) -> &str { + self.stream_id.as_str() + } + + /// kind controls if this TrackLocal is audio or video + fn kind(&self) -> RTPCodecType { + if self.codec.mime_type.starts_with("audio/") { + RTPCodecType::Audio + } else if self.codec.mime_type.starts_with("video/") { + RTPCodecType::Video + } else { + RTPCodecType::Unspecified + } + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[async_trait] +impl TrackLocalWriter for TrackLocalStaticRTP { + /// write_rtp writes a RTP Packet to the TrackLocalStaticRTP + /// If one PeerConnection fails the packets will still be sent to + /// all PeerConnections. The error message will contain the ID of the failed + /// PeerConnections so you can remove them + /// + /// If the RTCRtpSender direction is such that no packets should be sent, any call to this + /// function are blocked internally. Care must be taken to not increase the sequence number + /// while the sender is paused. While the actual _sending_ is blocked, the receiver will + /// miss out when the sequence number "rolls over", which in turn will break SRTP. + async fn write_rtp(&self, p: &rtp::packet::Packet) -> Result { + self.write_rtp_with_extensions(p, &[]).await + } + + /// write writes a RTP Packet as a buffer to the TrackLocalStaticRTP + /// If one PeerConnection fails the packets will still be sent to + /// all PeerConnections. The error message will contain the ID of the failed + /// PeerConnections so you can remove them + async fn write(&self, mut b: &[u8]) -> Result { + let pkt = rtp::packet::Packet::unmarshal(&mut b)?; + self.write_rtp(&pkt).await?; + Ok(b.len()) + } +} diff --git a/rtc/src/track/track_local/track_local_static_sample.rs b/rtc/src/track/track_local/track_local_static_sample.rs new file mode 100644 index 0000000..1745ce4 --- /dev/null +++ b/rtc/src/track/track_local/track_local_static_sample.rs @@ -0,0 +1,299 @@ +use log::warn; +use media::Sample; +use tokio::sync::Mutex; + +use super::track_local_static_rtp::TrackLocalStaticRTP; +use super::*; +use crate::error::flatten_errs; +use crate::track::RTP_OUTBOUND_MTU; + +#[derive(Debug, Clone)] +struct TrackLocalStaticSampleInternal { + packetizer: Option>, + sequencer: Option>, + clock_rate: f64, + did_warn_about_wonky_pause: bool, +} + +/// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples. +/// If you wish to send a RTP Packet use TrackLocalStaticRTP +#[derive(Debug)] +pub struct TrackLocalStaticSample { + rtp_track: TrackLocalStaticRTP, + internal: Mutex, +} + +impl TrackLocalStaticSample { + /// returns a TrackLocalStaticSample + pub fn new(codec: RTCRtpCodecCapability, id: String, stream_id: String) -> Self { + let rtp_track = TrackLocalStaticRTP::new(codec, id, stream_id); + + TrackLocalStaticSample { + rtp_track, + internal: Mutex::new(TrackLocalStaticSampleInternal { + packetizer: None, + sequencer: None, + clock_rate: 0.0f64, + did_warn_about_wonky_pause: false, + }), + } + } + + /// codec gets the Codec of the track + pub fn codec(&self) -> RTCRtpCodecCapability { + self.rtp_track.codec() + } + + /// write_sample writes a Sample to the TrackLocalStaticSample + /// If one PeerConnection fails the packets will still be sent to + /// all PeerConnections. The error message will contain the ID of the failed + /// PeerConnections so you can remove them + pub async fn write_sample(&self, sample: &Sample) -> Result<()> { + self.write_sample_with_extensions(sample, &[]).await + } + + /// Write a sample with provided RTP extensions. + /// + /// Alternatively to this method [`TrackLocalStaticSample::sample_writer`] can be used instead. + /// + /// See [`TrackLocalStaticSample::write_sample`] for further details. + pub async fn write_sample_with_extensions( + &self, + sample: &Sample, + extensions: &[rtp::extension::HeaderExtension], + ) -> Result<()> { + let mut internal = self.internal.lock().await; + + if internal.packetizer.is_none() || internal.sequencer.is_none() { + return Ok(()); + } + + let (any_paused, all_paused) = ( + self.rtp_track.any_binding_paused().await, + self.rtp_track.all_binding_paused().await, + ); + + if all_paused { + // Abort already here to not increment sequence numbers. + return Ok(()); + } + + if any_paused { + // This is a problem state due to how this impl is structured. The sequencer will allocate + // one sequence number per RTP packet regardless of how many TrackBinding that will send + // the packet. I.e. we get the same sequence number per multiple SSRC, which is not good + // for SRTP, but that's how it works. + // + // SRTP has a further problem with regards to jumps in sequence number. Consider this: + // + // 1. Create track local + // 2. Bind track local to track 1. + // 3. Bind track local to track 2. + // 4. Pause track 1. + // 5. Keep sending... + // + // At this point, the track local will keep incrementing the sequence number, because we have + // one binding that is still active. However SRTP hmac verifying (tag), can only accept a + // relatively small jump in sequence numbers since it uses the ROC (i.e. how many times the + // sequence number has rolled over), which means if this pause state of one binding persists + // for a longer time, the track can never be resumed since the receiver would have missed + // the rollovers. + if !internal.did_warn_about_wonky_pause { + internal.did_warn_about_wonky_pause = true; + warn!("Detected multiple track bindings where only one was paused"); + } + } + + // skip packets by the number of previously dropped packets + if let Some(sequencer) = &internal.sequencer { + for _ in 0..sample.prev_dropped_packets { + sequencer.next_sequence_number(); + } + } + + let clock_rate = internal.clock_rate; + + let packets = if let Some(packetizer) = &mut internal.packetizer { + let samples = (sample.duration.as_secs_f64() * clock_rate) as u32; + if sample.prev_dropped_packets > 0 { + packetizer.skip_samples(samples * sample.prev_dropped_packets as u32); + } + packetizer.packetize(&sample.data, samples)? + } else { + vec![] + }; + + let mut write_errs = vec![]; + for p in packets { + if let Err(err) = self + .rtp_track + .write_rtp_with_extensions(&p, extensions) + .await + { + write_errs.push(err); + } + } + + flatten_errs(write_errs) + } + + /// Create a builder for writing samples with additional data. + /// + /// # Example + /// ```no_run + /// use rtp::extension::audio_level_extension::AudioLevelExtension; + /// use std::time::Duration; + /// use webrtc::api::media_engine::MIME_TYPE_VP8; + /// use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability; + /// use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample; + /// + /// #[tokio::main] + /// async fn main() { + /// let track = TrackLocalStaticSample::new( + /// RTCRtpCodecCapability { + /// mime_type: MIME_TYPE_VP8.to_owned(), + /// ..Default::default() + /// }, + /// "video".to_owned(), + /// "webrtc-rs".to_owned(), + /// ); + /// let result = track + /// .sample_writer() + /// .with_audio_level(AudioLevelExtension { + /// level: 10, + /// voice: true, + /// }) + /// .write_sample(&media::Sample{ + /// data: bytes::Bytes::new(), + /// duration: Duration::from_secs(1), + /// ..Default::default() + /// }) + /// .await; + /// } + /// ``` + pub fn sample_writer(&self) -> SampleWriter<'_> { + SampleWriter::new(self) + } +} + +#[async_trait] +impl TrackLocal for TrackLocalStaticSample { + /// Bind is called by the PeerConnection after negotiation is complete + /// This asserts that the code requested is supported by the remote peer. + /// If so it setups all the state (SSRC and PayloadType) to have a call + async fn bind(&self, t: &TrackLocalContext) -> Result { + let codec = self.rtp_track.bind(t).await?; + + let mut internal = self.internal.lock().await; + + // We only need one packetizer + if internal.packetizer.is_some() { + return Ok(codec); + } + + let payloader = codec.capability.payloader_for_codec()?; + let sequencer: Box = + Box::new(rtp::sequence::new_random_sequencer()); + internal.packetizer = Some(Box::new(rtp::packetizer::new_packetizer( + RTP_OUTBOUND_MTU, + 0, // Value is handled when writing + 0, // Value is handled when writing + payloader, + sequencer.clone(), + codec.capability.clock_rate, + ))); + internal.sequencer = Some(sequencer); + internal.clock_rate = codec.capability.clock_rate as f64; + + Ok(codec) + } + + /// unbind implements the teardown logic when the track is no longer needed. This happens + /// because a track has been stopped. + async fn unbind(&self, t: &TrackLocalContext) -> Result<()> { + self.rtp_track.unbind(t).await + } + + /// id is the unique identifier for this Track. This should be unique for the + /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' + /// and StreamID would be 'desktop' or 'webcam' + fn id(&self) -> &str { + self.rtp_track.id() + } + + /// stream_id is the group this track belongs too. This must be unique + fn stream_id(&self) -> &str { + self.rtp_track.stream_id() + } + + /// kind controls if this TrackLocal is audio or video + fn kind(&self) -> RTPCodecType { + self.rtp_track.kind() + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +mod sample_writer { + use media::Sample; + use rtp::extension::audio_level_extension::AudioLevelExtension; + use rtp::extension::video_orientation_extension::VideoOrientationExtension; + use rtp::extension::HeaderExtension; + + use super::TrackLocalStaticSample; + use crate::error::Result; + + /// Helper for writing Samples via [`TrackLocalStaticSample`] that carry extra RTP data. + /// + /// Created via [`TrackLocalStaticSample::sample_writer`]. + pub struct SampleWriter<'track> { + track: &'track TrackLocalStaticSample, + extensions: Vec, + } + + impl<'track> SampleWriter<'track> { + pub(super) fn new(track: &'track TrackLocalStaticSample) -> Self { + Self { + track, + extensions: vec![], + } + } + + /// Add a RTP audio level extension to all packets written for the sample. + /// + /// This overwrites any previously configured audio level extension. + pub fn with_audio_level(self, ext: AudioLevelExtension) -> Self { + self.with_extension(HeaderExtension::AudioLevel(ext)) + } + + /// Add a RTP video orientation extension to all packets written for the sample. + /// + /// This overwrites any previously configured video orientation extension. + pub fn with_video_orientation(self, ext: VideoOrientationExtension) -> Self { + self.with_extension(HeaderExtension::VideoOrientation(ext)) + } + + /// Add any RTP extension to all packets written for the sample. + pub fn with_extension(mut self, ext: HeaderExtension) -> Self { + self.extensions.retain(|e| !e.is_same(&ext)); + + self.extensions.push(ext); + + self + } + + /// Write the sample to the track. + /// + /// Creates one or more RTP packets with any extensions specified for each packet and sends + /// them. + pub async fn write_sample(self, sample: &Sample) -> Result<()> { + self.track + .write_sample_with_extensions(sample, &self.extensions) + .await + } + } +} + +pub use sample_writer::SampleWriter; diff --git a/rtc/src/track/track_local/track_local_static_test.rs b/rtc/src/track/track_local/track_local_static_test.rs new file mode 100644 index 0000000..b385d98 --- /dev/null +++ b/rtc/src/track/track_local/track_local_static_test.rs @@ -0,0 +1,434 @@ +use std::sync::Arc; + +use bytes::Bytes; +use tokio::sync::{mpsc, Mutex}; + +use super::track_local_static_rtp::*; +use super::track_local_static_sample::*; +use super::*; +use crate::api::media_engine::{MediaEngine, MIME_TYPE_VP8}; +use crate::api::APIBuilder; +use crate::peer_connection::configuration::RTCConfiguration; +use crate::peer_connection::peer_connection_test::*; + +// If a remote doesn't support a Codec used by a `TrackLocalStatic` +// an error should be returned to the user +#[tokio::test] +async fn test_track_local_static_no_codec_intersection() -> Result<()> { + let track: Arc = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: "video/vp8".to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + //"Offerer" + { + let mut pc = api.new_peer_connection(RTCConfiguration::default()).await?; + + let mut no_codec_pc = APIBuilder::new() + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + + pc.add_track(Arc::clone(&track)).await?; + + if let Err(err) = signal_pair(&mut pc, &mut no_codec_pc).await { + assert_eq!(err, Error::ErrUnsupportedCodec); + } else { + panic!(); + } + + close_pair_now(&no_codec_pc, &pc).await; + } + + //"Answerer" + { + let mut pc = api.new_peer_connection(RTCConfiguration::default()).await?; + + let mut m = MediaEngine::default(); + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: "video/VP9".to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 96, + ..Default::default() + }, + RTPCodecType::Video, + )?; + let mut vp9only_pc = APIBuilder::new() + .with_media_engine(m) + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + + vp9only_pc + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + pc.add_track(Arc::clone(&track)).await?; + + if let Err(err) = signal_pair(&mut vp9only_pc, &mut pc).await { + assert_eq!( + err, + Error::ErrUnsupportedCodec, + "expected {}, but got {}", + Error::ErrUnsupportedCodec, + err + ); + } else { + panic!(); + } + + close_pair_now(&vp9only_pc, &pc).await; + } + + //"Local" + { + let (mut offerer, mut answerer) = new_pair(&api).await?; + + let invalid_codec_track = TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: "video/invalid-codec".to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + ); + + offerer.add_track(Arc::new(invalid_codec_track)).await?; + + if let Err(err) = signal_pair(&mut offerer, &mut answerer).await { + assert_eq!(err, Error::ErrUnsupportedCodec); + } else { + panic!(); + } + + close_pair_now(&offerer, &answerer).await; + } + + Ok(()) +} + +// Assert that Bind/Unbind happens when expected +#[tokio::test] +async fn test_track_local_static_closed() -> Result<()> { + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; + + pc_answer + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let vp8writer: Arc = Arc::new(TrackLocalStaticRTP::new( + RTCRtpCodecCapability { + mime_type: "video/vp8".to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + pc_offer.add_track(Arc::clone(&vp8writer)).await?; + + if let Some(v) = vp8writer.as_any().downcast_ref::() { + let bindings = v.bindings.lock().await; + assert_eq!( + bindings.len(), + 0, + "No binding should exist before signaling" + ); + } else { + panic!(); + } + + signal_pair(&mut pc_offer, &mut pc_answer).await?; + + if let Some(v) = vp8writer.as_any().downcast_ref::() { + let bindings = v.bindings.lock().await; + assert_eq!(bindings.len(), 1, "binding should exist after signaling"); + } else { + panic!(); + } + + close_pair_now(&pc_offer, &pc_answer).await; + + if let Some(v) = vp8writer.as_any().downcast_ref::() { + let bindings = v.bindings.lock().await; + assert_eq!(bindings.len(), 0, "No binding should exist after close"); + } else { + panic!(); + } + + Ok(()) +} + +//use log::LevelFilter; +//use std::io::Write; + +#[tokio::test] +async fn test_track_local_static_payload_type() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + let mut media_engine_one = MediaEngine::default(); + media_engine_one.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 100, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + let mut media_engine_two = MediaEngine::default(); + media_engine_two.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "".to_owned(), + rtcp_feedback: vec![], + }, + payload_type: 200, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + let mut offerer = APIBuilder::new() + .with_media_engine(media_engine_one) + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + let mut answerer = APIBuilder::new() + .with_media_engine(media_engine_two) + .build() + .new_peer_connection(RTCConfiguration::default()) + .await?; + + let track = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + offerer + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + answerer + .add_track(Arc::clone(&track) as Arc) + .await?; + + let (on_track_fired_tx, on_track_fired_rx) = mpsc::channel::<()>(1); + let on_track_fired_tx = Arc::new(Mutex::new(Some(on_track_fired_tx))); + offerer.on_track(Box::new(move |track, _, _| { + let on_track_fired_tx2 = Arc::clone(&on_track_fired_tx); + Box::pin(async move { + assert_eq!(track.payload_type(), 100); + assert_eq!(track.codec().capability.mime_type, MIME_TYPE_VP8); + { + log::debug!("onTrackFiredFunc!!!"); + let mut done = on_track_fired_tx2.lock().await; + done.take(); + } + }) + })); + + signal_pair(&mut offerer, &mut answerer).await?; + + send_video_until_done( + on_track_fired_rx, + vec![track], + Bytes::from_static(&[0x00]), + None, + ) + .await; + + close_pair_now(&offerer, &answerer).await; + + Ok(()) +} + +// Assert that writing to a Track doesn't modify the input +// Even though we can pass a pointer we shouldn't modify the incoming value +#[tokio::test] +async fn test_track_local_static_mutate_input() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (mut pc_offer, mut pc_answer) = new_pair(&api).await?; + + let vp8writer: Arc = Arc::new(TrackLocalStaticRTP::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + pc_offer.add_track(Arc::clone(&vp8writer)).await?; + + signal_pair(&mut pc_offer, &mut pc_answer).await?; + + let pkt = rtp::packet::Packet { + header: rtp::header::Header { + ssrc: 1, + payload_type: 1, + ..Default::default() + }, + ..Default::default() + }; + if let Some(v) = vp8writer.as_any().downcast_ref::() { + v.write_rtp(&pkt).await?; + } else { + panic!(); + } + + assert_eq!(pkt.header.ssrc, 1); + assert_eq!(pkt.header.payload_type, 1); + + close_pair_now(&pc_offer, &pc_answer).await; + + Ok(()) +} + +//use std::io::Write; +//use log::LevelFilter; + +// Assert that writing to a Track that has Binded (but not connected) +// does not block +#[tokio::test] +async fn test_track_local_static_binding_non_blocking() -> Result<()> { + /*env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{}:{} [{}] {} - {}", + record.file().unwrap_or("unknown"), + record.line().unwrap_or(0), + record.level(), + chrono::Local::now().format("%H:%M:%S.%6f"), + record.args() + ) + }) + .filter(None, LevelFilter::Trace) + .init();*/ + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + let api = APIBuilder::new().with_media_engine(m).build(); + + let (pc_offer, pc_answer) = new_pair(&api).await?; + + pc_offer + .add_transceiver_from_kind(RTPCodecType::Video, None) + .await?; + + let vp8writer: Arc = Arc::new(TrackLocalStaticRTP::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + pc_answer.add_track(Arc::clone(&vp8writer)).await?; + + let offer = pc_offer.create_offer(None).await?; + pc_answer.set_remote_description(offer).await?; + + let answer = pc_answer.create_answer(None).await?; + pc_answer.set_local_description(answer).await?; + + if let Some(v) = vp8writer.as_any().downcast_ref::() { + v.write(&[0u8; 20]).await?; + } else { + panic!(); + } + + close_pair_now(&pc_offer, &pc_answer).await; + + Ok(()) +} + +/* +//TODO: func BenchmarkTrackLocalWrite(b *testing.B) { + offerPC, answerPC, err := newPair() + defer closePairNow(b, offerPC, answerPC) + if err != nil { + b.Fatalf("Failed to create a PC pair for testing") + } + + track, err := NewTrackLocalStaticRTP(RTPCodecCapability{mime_type: MIME_TYPE_VP8}, "video", "pion") + assert.NoError(b, err) + + _, err = offerPC.AddTrack(track) + assert.NoError(b, err) + + _, err = answerPC.AddTransceiverFromKind(RTPCodecTypeVideo) + assert.NoError(b, err) + + b.SetBytes(1024) + + buf := make([]byte, 1024) + for i := 0; i < b.N; i++ { + _, err := track.Write(buf) + assert.NoError(b, err) + } +} +*/ diff --git a/rtc/src/track/track_remote/mod.rs b/rtc/src/track/track_remote/mod.rs new file mode 100644 index 0000000..b5cd4f5 --- /dev/null +++ b/rtc/src/track/track_remote/mod.rs @@ -0,0 +1,320 @@ +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicU32, AtomicU8, AtomicUsize, Ordering}; +use std::sync::{Arc, Weak}; + +use arc_swap::ArcSwapOption; +use interceptor::{Attributes, Interceptor}; +use smol_str::SmolStr; +use tokio::sync::Mutex; +use util::sync::Mutex as SyncMutex; + +use crate::api::media_engine::MediaEngine; +use crate::error::{Error, Result}; +use crate::rtp_transceiver::rtp_codec::{RTCRtpCodecParameters, RTCRtpParameters, RTPCodecType}; +use crate::rtp_transceiver::rtp_receiver::RTPReceiverInternal; +use crate::rtp_transceiver::{PayloadType, SSRC}; + +lazy_static! { + static ref TRACK_REMOTE_UNIQUE_ID: AtomicUsize = AtomicUsize::new(0); +} +pub type OnMuteHdlrFn = Box< + dyn (FnMut() -> Pin + Send + 'static>>) + Send + Sync + 'static, +>; + +#[derive(Default)] +struct Handlers { + on_mute: ArcSwapOption>, + on_unmute: ArcSwapOption>, +} + +#[derive(Default)] +struct TrackRemoteInternal { + peeked: VecDeque<(rtp::packet::Packet, Attributes)>, +} + +/// TrackRemote represents a single inbound source of media +pub struct TrackRemote { + tid: usize, + + id: SyncMutex, + stream_id: SyncMutex, + + receive_mtu: usize, + payload_type: AtomicU8, //PayloadType, + kind: AtomicU8, //RTPCodecType, + ssrc: AtomicU32, //SSRC, + codec: SyncMutex, + pub(crate) params: SyncMutex, + rid: SmolStr, + + media_engine: Arc, + interceptor: Arc, + + handlers: Arc, + + receiver: Option>, + internal: Mutex, +} + +impl std::fmt::Debug for TrackRemote { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TrackRemote") + .field("id", &self.id) + .field("stream_id", &self.stream_id) + .field("payload_type", &self.payload_type) + .field("kind", &self.kind) + .field("ssrc", &self.ssrc) + .field("codec", &self.codec) + .field("params", &self.params) + .field("rid", &self.rid) + .finish() + } +} + +impl TrackRemote { + pub(crate) fn new( + receive_mtu: usize, + kind: RTPCodecType, + ssrc: SSRC, + rid: SmolStr, + receiver: Weak, + media_engine: Arc, + interceptor: Arc, + ) -> Self { + TrackRemote { + tid: TRACK_REMOTE_UNIQUE_ID.fetch_add(1, Ordering::SeqCst), + id: Default::default(), + stream_id: Default::default(), + receive_mtu, + payload_type: Default::default(), + kind: AtomicU8::new(kind as u8), + ssrc: AtomicU32::new(ssrc), + codec: Default::default(), + params: Default::default(), + rid, + receiver: Some(receiver), + media_engine, + interceptor, + handlers: Default::default(), + + internal: Default::default(), + } + } + + pub fn tid(&self) -> usize { + self.tid + } + + /// id is the unique identifier for this Track. This should be unique for the + /// stream, but doesn't have to globally unique. A common example would be 'audio' or 'video' + /// and StreamID would be 'desktop' or 'webcam' + pub fn id(&self) -> String { + let id = self.id.lock(); + id.clone() + } + + pub fn set_id(&self, s: String) { + let mut id = self.id.lock(); + *id = s; + } + + /// stream_id is the group this track belongs too. This must be unique + pub fn stream_id(&self) -> String { + let stream_id = self.stream_id.lock(); + stream_id.clone() + } + + pub fn set_stream_id(&self, s: String) { + let mut stream_id = self.stream_id.lock(); + *stream_id = s; + } + + /// rid gets the RTP Stream ID of this Track + /// With Simulcast you will have multiple tracks with the same ID, but different RID values. + /// In many cases a TrackRemote will not have an RID, so it is important to assert it is non-zero + pub fn rid(&self) -> &str { + self.rid.as_str() + } + + /// payload_type gets the PayloadType of the track + pub fn payload_type(&self) -> PayloadType { + self.payload_type.load(Ordering::SeqCst) + } + + pub fn set_payload_type(&self, payload_type: PayloadType) { + self.payload_type.store(payload_type, Ordering::SeqCst); + } + + /// kind gets the Kind of the track + pub fn kind(&self) -> RTPCodecType { + self.kind.load(Ordering::SeqCst).into() + } + + pub fn set_kind(&self, kind: RTPCodecType) { + self.kind.store(kind as u8, Ordering::SeqCst); + } + + /// ssrc gets the SSRC of the track + pub fn ssrc(&self) -> SSRC { + self.ssrc.load(Ordering::SeqCst) + } + + pub fn set_ssrc(&self, ssrc: SSRC) { + self.ssrc.store(ssrc, Ordering::SeqCst); + } + + /// msid gets the Msid of the track + pub fn msid(&self) -> String { + format!("{} {}", self.stream_id(), self.id()) + } + + /// codec gets the Codec of the track + pub fn codec(&self) -> RTCRtpCodecParameters { + let codec = self.codec.lock(); + codec.clone() + } + + pub fn set_codec(&self, codec: RTCRtpCodecParameters) { + let mut c = self.codec.lock(); + *c = codec; + } + + pub fn params(&self) -> RTCRtpParameters { + let p = self.params.lock(); + p.clone() + } + + pub fn set_params(&self, params: RTCRtpParameters) { + let mut p = self.params.lock(); + *p = params; + } + + pub fn onmute(&self, handler: F) + where + F: FnMut() -> Pin + Send + 'static>> + Send + 'static + Sync, + { + self.handlers + .on_mute + .store(Some(Arc::new(Mutex::new(Box::new(handler))))); + } + + pub fn onunmute(&self, handler: F) + where + F: FnMut() -> Pin + Send + 'static>> + Send + 'static + Sync, + { + self.handlers + .on_unmute + .store(Some(Arc::new(Mutex::new(Box::new(handler))))); + } + + /// Reads data from the track. + /// + /// **Cancel Safety:** This method is not cancel safe. Dropping the resulting [`Future`] before + /// it returns [`Poll::Ready`] will cause data loss. + pub async fn read(&self, b: &mut [u8]) -> Result<(rtp::packet::Packet, Attributes)> { + { + // Internal lock scope + let mut internal = self.internal.lock().await; + if let Some((pkt, attributes)) = internal.peeked.pop_front() { + self.check_and_update_track(&pkt).await?; + + return Ok((pkt, attributes)); + } + }; + + let receiver = match self.receiver.as_ref().and_then(|r| r.upgrade()) { + Some(r) => r, + None => return Err(Error::ErrRTPReceiverNil), + }; + + let (pkt, attributes) = receiver.read_rtp(b, self.tid).await?; + self.check_and_update_track(&pkt).await?; + Ok((pkt, attributes)) + } + + /// check_and_update_track checks payloadType for every incoming packet + /// once a different payloadType is detected the track will be updated + pub(crate) async fn check_and_update_track(&self, pkt: &rtp::packet::Packet) -> Result<()> { + let payload_type = pkt.header.payload_type; + if payload_type != self.payload_type() { + let p = self + .media_engine + .get_rtp_parameters_by_payload_type(payload_type) + .await?; + + if let Some(receiver) = &self.receiver { + if let Some(receiver) = receiver.upgrade() { + self.kind.store(receiver.kind as u8, Ordering::SeqCst); + } + } + self.payload_type.store(payload_type, Ordering::SeqCst); + { + let mut codec = self.codec.lock(); + *codec = if let Some(codec) = p.codecs.first() { + codec.clone() + } else { + return Err(Error::ErrCodecNotFound); + }; + } + { + let mut params = self.params.lock(); + *params = p; + } + } + + Ok(()) + } + + /// read_rtp is a convenience method that wraps Read and unmarshals for you. + pub async fn read_rtp(&self) -> Result<(rtp::packet::Packet, Attributes)> { + let mut b = vec![0u8; self.receive_mtu]; + let (pkt, attributes) = self.read(&mut b).await?; + + Ok((pkt, attributes)) + } + + /// peek is like Read, but it doesn't discard the packet read + pub(crate) async fn peek(&self, b: &mut [u8]) -> Result<(rtp::packet::Packet, Attributes)> { + let (pkt, a) = self.read(b).await?; + + // this might overwrite data if somebody peeked between the Read + // and us getting the lock. Oh well, we'll just drop a packet in + // that case. + { + let mut internal = self.internal.lock().await; + internal.peeked.push_back((pkt.clone(), a.clone())); + } + Ok((pkt, a)) + } + + /// Set the initially peeked data for this track. + /// + /// This is useful when a track is first created to populate data read from the track in the + /// process of identifying the track as part of simulcast probing. Using this during other + /// parts of the track's lifecycle is probably an error. + pub(crate) async fn prepopulate_peeked_data( + &self, + data: VecDeque<(rtp::packet::Packet, Attributes)>, + ) { + let mut internal = self.internal.lock().await; + internal.peeked = data; + } + + pub(crate) async fn fire_onmute(&self) { + let on_mute = self.handlers.on_mute.load(); + + if let Some(f) = on_mute.as_ref() { + (f.lock().await)().await + }; + } + + pub(crate) async fn fire_onunmute(&self) { + let on_unmute = self.handlers.on_unmute.load(); + + if let Some(f) = on_unmute.as_ref() { + (f.lock().await)().await + }; + } +}