Skip to content

Commit

Permalink
fix client_test.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
yngrtc committed Mar 16, 2024
1 parent d639bfb commit 2e72a8f
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 144 deletions.
17 changes: 17 additions & 0 deletions rtc-shared/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use crate::error::{Error, Result};
use std::net::{SocketAddr, ToSocketAddrs};

// match_range is a MatchFunc that accepts packets with the first byte in [lower..upper]
fn match_range(lower: u8, upper: u8) -> impl Fn(&[u8]) -> bool {
move |buf: &[u8]| -> bool {
Expand Down Expand Up @@ -53,3 +56,17 @@ pub fn match_srtp(buf: &[u8]) -> bool {
pub fn match_srtcp(buf: &[u8]) -> bool {
match_srtp_or_srtcp(buf) && is_rtcp(buf)
}

/// lookup host to SocketAddr
pub fn lookup_host<T>(use_ipv4: bool, host: T) -> Result<SocketAddr>
where
T: ToSocketAddrs,
{
for remote_addr in host.to_socket_addrs()? {
if (use_ipv4 && remote_addr.is_ipv4()) || (!use_ipv4 && remote_addr.is_ipv6()) {
return Ok(remote_addr);
}
}

Err(Error::ErrAddressParseFailed)
}
183 changes: 115 additions & 68 deletions rtc-turn/src/client/client_test.rs
Original file line number Diff line number Diff line change
@@ -1,120 +1,167 @@
use tokio::net::UdpSocket;

use super::*;
use crate::auth::*;
use std::collections::HashSet;
use std::net::UdpSocket;

async fn create_listening_test_client(rto_in_ms: u16) -> Result<Client> {
let conn = UdpSocket::bind("0.0.0.0:0").await?;
fn create_listening_test_client(rto_in_ms: u64) -> Result<(UdpSocket, Client)> {
let udp_socket = UdpSocket::bind("0.0.0.0:0")?;

let c = Client::new(ClientConfig {
let client = Client::new(ClientConfig {
stun_serv_addr: String::new(),
turn_serv_addr: String::new(),
local_addr: udp_socket.local_addr()?,
protocol: Protocol::UDP,
username: String::new(),
password: String::new(),
realm: String::new(),
software: "TEST SOFTWARE".to_owned(),
rto_in_ms,
conn: Arc::new(conn),
vnet: None,
})
.await?;

c.listen().await?;
})?;

Ok(c)
Ok((udp_socket, client))
}

async fn create_listening_test_client_with_stun_serv() -> Result<Client> {
let conn = UdpSocket::bind("0.0.0.0:0").await?;
fn create_listening_test_client_with_stun_serv() -> Result<(UdpSocket, Client)> {
let udp_socket = UdpSocket::bind("0.0.0.0:0")?;

let c = Client::new(ClientConfig {
let client = Client::new(ClientConfig {
stun_serv_addr: "stun1.l.google.com:19302".to_owned(),
turn_serv_addr: String::new(),
local_addr: udp_socket.local_addr()?,
protocol: Protocol::UDP,
username: String::new(),
password: String::new(),
realm: String::new(),
software: "TEST SOFTWARE".to_owned(),
rto_in_ms: 0,
conn: Arc::new(conn),
vnet: None,
})
.await?;

c.listen().await?;
})?;

Ok(c)
Ok((udp_socket, client))
}

#[tokio::test]
async fn test_client_with_stun_send_binding_request() -> Result<()> {
#[test]
fn test_client_with_stun_send_binding_request() -> Result<()> {
//env_logger::init();

let c = create_listening_test_client_with_stun_serv().await?;
let (conn, mut client) = create_listening_test_client_with_stun_serv()?;
let local_addr = conn.local_addr()?;

let tid = client.send_binding_request()?;

while let Some(transmit) = client.poll_transmit() {
conn.send_to(&transmit.message, transmit.transport.peer_addr)?;
}

let resp = c.send_binding_request().await?;
log::debug!("mapped-addr: {}", resp);
{
let ci = c.client_internal.lock().await;
let tm = ci.tr_map.lock().await;
assert_eq!(0, tm.size(), "should be no transaction left");
let mut buffer = vec![0u8; 2048];
let (n, peer_addr) = conn.recv_from(&mut buffer)?;
client.handle_transmit(Transmit {
now: Instant::now(),
transport: TransportContext {
local_addr,
peer_addr,
protocol: Protocol::UDP,
ecn: None,
},
message: BytesMut::from(&buffer[..n]),
})?;

if let Some(event) = client.poll_event() {
match event {
Event::BindingResponse(id, refl_addr) => {
assert_eq!(tid, id);
log::debug!("mapped-addr: {}", refl_addr);
}
_ => assert!(false),
}
} else {
assert!(false);
}

c.close().await?;
assert_eq!(0, client.tr_map.size(), "should be no transaction left");

client.close();

Ok(())
}

#[tokio::test]
async fn test_client_with_stun_send_binding_request_to_parallel() -> Result<()> {
env_logger::init();

let c1 = create_listening_test_client(0).await?;
let c2 = c1.clone();
#[test]
fn test_client_with_stun_send_binding_request_to_parallel() -> Result<()> {
//env_logger::init();

let (stared_tx, mut started_rx) = mpsc::channel::<()>(1);
let (finished_tx, mut finished_rx) = mpsc::channel::<()>(1);
let (conn, mut client) = create_listening_test_client(0)?;
let local_addr = conn.local_addr()?;

let to = lookup_host(true, "stun1.l.google.com:19302").await?;
let to = lookup_host(true, "stun1.l.google.com:19302")?;

tokio::spawn(async move {
drop(stared_tx);
if let Ok(resp) = c2.send_binding_request_to(&to.to_string()).await {
log::debug!("mapped-addr: {}", resp);
}
drop(finished_tx);
});
let tid1 = client.send_binding_request_to(to)?;
let tid2 = client.send_binding_request_to(to)?;
while let Some(transmit) = client.poll_transmit() {
conn.send_to(&transmit.message, transmit.transport.peer_addr)?;
}

let _ = started_rx.recv().await;
let mut buffer = vec![0u8; 2048];
for _ in 0..2 {
let (n, peer_addr) = conn.recv_from(&mut buffer)?;
client.handle_transmit(Transmit {
now: Instant::now(),
transport: TransportContext {
local_addr,
peer_addr,
protocol: Protocol::UDP,
ecn: None,
},
message: BytesMut::from(&buffer[..n]),
})?;
}

let resp = c1.send_binding_request_to(&to.to_string()).await?;
log::debug!("mapped-addr: {}", resp);
let mut tids = HashSet::new();
while let Some(event) = client.poll_event() {
match event {
Event::BindingResponse(tid, refl_addr) => {
tids.insert(tid);
log::debug!("mapped-addr: {}", refl_addr);
}
_ => {}
}
}

let _ = finished_rx.recv().await;
assert_eq!(2, tids.len());
assert!(tids.contains(&tid1));
assert!(tids.contains(&tid2));

c1.close().await?;
client.close();

Ok(())
}

#[tokio::test]
async fn test_client_with_stun_send_binding_request_to_timeout() -> Result<()> {
#[test]
fn test_client_with_stun_send_binding_request_to_timeout() -> Result<()> {
//env_logger::init();

let c = create_listening_test_client(10).await?;
let (conn, mut client) = create_listening_test_client(10)?;

let to = lookup_host(true, "127.0.0.1:9").await?;
let to = lookup_host(true, "127.0.0.1:9")?;

let result = c.send_binding_request_to(&to.to_string()).await;
assert!(result.is_err(), "expected error, but got ok");

c.close().await?;
let tid = client.send_binding_request_to(to)?;
while let Some(transmit) = client.poll_transmit() {
conn.send_to(&transmit.message, transmit.transport.peer_addr)?;
}

Ok(())
}
while let Some(to) = client.poll_timout() {
client.handle_timeout(to);
}

struct TestAuthHandler;
impl AuthHandler for TestAuthHandler {
fn auth_handle(&self, username: &str, realm: &str, _src_addr: SocketAddr) -> Result<Vec<u8>> {
Ok(generate_auth_key(username, realm, "pass"))
if let Some(event) = client.poll_event() {
match event {
Event::TransactionTimeout(id) => {
assert_eq!(tid, id);
}
_ => assert!(false),
}
} else {
assert!(false);
}

client.close();

Ok(())
}
33 changes: 21 additions & 12 deletions rtc-turn/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*TODO:#[cfg(test)]
#[cfg(test)]
mod client_test;
*/

pub mod binding;
pub mod permission;
pub mod relay;
Expand All @@ -9,7 +9,6 @@ pub mod transaction;
use bytes::BytesMut;
use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::str::FromStr;
use std::time::Instant;

use stun::attributes::*;
Expand All @@ -31,6 +30,7 @@ use crate::proto::relayaddr::RelayedAddress;
use crate::proto::reqtrans::RequestedTransport;
use crate::proto::{PROTO_TCP, PROTO_UDP};
use shared::error::{Error, Result};
use shared::util::lookup_host;
use shared::{Protocol, Transmit, TransportContext};
use stun::error_code::ErrorCodeAttribute;
use stun::fingerprint::FINGERPRINT;
Expand Down Expand Up @@ -89,7 +89,7 @@ pub struct ClientConfig {
/// Client is a STUN client
pub struct Client {
stun_serv_addr: Option<SocketAddr>,
turn_serv_addr: SocketAddr,
turn_serv_addr: Option<SocketAddr>,
local_addr: SocketAddr,
protocol: Protocol,
username: Username,
Expand All @@ -112,13 +112,19 @@ impl Client {
let stun_serv_addr = if config.stun_serv_addr.is_empty() {
None
} else {
Some(SocketAddr::from_str(config.stun_serv_addr.as_str())?)
Some(lookup_host(
config.local_addr.is_ipv4(),
config.stun_serv_addr.as_str(),
)?)
};

let turn_serv_addr = if config.turn_serv_addr.is_empty() {
return Err(Error::ErrNilTurnSocket);
None
} else {
SocketAddr::from_str(config.turn_serv_addr.as_str())?
Some(lookup_host(
config.local_addr.is_ipv4(),
config.turn_serv_addr.as_str(),
)?)
};

Ok(Client {
Expand Down Expand Up @@ -500,8 +506,11 @@ impl Client {
])?;

log::debug!("client.Allocate call PerformTransaction 1");
let tid =
self.perform_transaction(&msg, self.turn_serv_addr, TransactionType::AllocateAttempt);
let tid = self.perform_transaction(
&msg,
self.turn_server_addr()?,
TransactionType::AllocateAttempt,
);
Ok(tid)
}

Expand Down Expand Up @@ -558,7 +567,7 @@ impl Client {
log::debug!("client.Allocate call PerformTransaction 2");
self.perform_transaction(
&msg,
self.turn_serv_addr,
self.turn_server_addr()?,
TransactionType::AllocateRequest(nonce),
);
}
Expand Down Expand Up @@ -599,8 +608,8 @@ impl Client {
}

/// turn_server_addr return the TURN server address
fn turn_server_addr(&self) -> SocketAddr {
self.turn_serv_addr
fn turn_server_addr(&self) -> Result<SocketAddr> {
self.turn_serv_addr.ok_or(Error::ErrNilTurnSocket)
}

/// username returns username
Expand Down
Loading

0 comments on commit 2e72a8f

Please sign in to comment.