From d01f8144c08cb4b6cf0f09d4e706272fab33e395 Mon Sep 17 00:00:00 2001 From: Ian Clarke Date: Sat, 13 Apr 2024 15:02:01 -0500 Subject: [PATCH] Create encrypted UDP transport with streaming and firewall hole-punching (#949) Co-authored-by: Ignacio Duart --- .gitignore | 1 + Cargo.lock | 160 +- Cargo.toml | 1 - crates/core/Cargo.toml | 9 +- crates/core/src/bin/freenet.rs | 2 +- crates/core/src/client_events.rs | 18 - crates/core/src/config.rs | 4 +- crates/core/src/lib.rs | 2 + .../src/node/network_bridge/p2p_protoc.rs | 17 +- crates/core/src/node/op_state_manager.rs | 115 +- crates/core/src/node/p2p_impl.rs | 8 - crates/core/src/node/testing_impl.rs | 11 +- crates/core/src/node/testing_impl/network.rs | 4 +- crates/core/src/operations/connect.rs | 37 +- crates/core/src/ring.rs | 12 +- crates/core/src/topology.rs | 4 +- crates/core/src/tracing.rs | 11 +- crates/core/src/transport.rs | 123 ++ .../core/src/transport/connection_handler.rs | 1331 +++++++++++++++++ crates/core/src/transport/crypto.rs | 62 + crates/core/src/transport/packet_data.rs | 349 +++++ crates/core/src/transport/peer_connection.rs | 463 ++++++ .../peer_connection/inbound_stream.rs | 122 ++ .../peer_connection/outbound_stream.rs | 147 ++ crates/core/src/transport/rate_limiter.rs | 185 +++ .../src/transport/received_packet_tracker.rs | 227 +++ .../core/src/transport/sent_packet_tracker.rs | 256 ++++ .../core/src/transport/symmetric_message.rs | 288 ++++ crates/core/src/util.rs | 5 +- crates/core/src/util/time_source.rs | 146 ++ crates/core/src/wasm_runtime/store.rs | 6 +- crates/fdev/Cargo.toml | 1 - crates/fdev/src/main.rs | 2 +- crates/fdev/src/testing/multiple_process.rs | 4 +- crates/fdev/src/testing/network.rs | 4 +- crates/fdev/src/testing/single_process.rs | 2 +- docs/src/SUMMARY.md | 1 + docs/src/architecture/transport.md | 201 +++ 38 files changed, 4120 insertions(+), 221 deletions(-) create mode 100644 crates/core/src/transport.rs create mode 100644 crates/core/src/transport/connection_handler.rs create mode 100644 crates/core/src/transport/crypto.rs create mode 100644 crates/core/src/transport/packet_data.rs create mode 100644 crates/core/src/transport/peer_connection.rs create mode 100644 crates/core/src/transport/peer_connection/inbound_stream.rs create mode 100644 crates/core/src/transport/peer_connection/outbound_stream.rs create mode 100644 crates/core/src/transport/rate_limiter.rs create mode 100644 crates/core/src/transport/received_packet_tracker.rs create mode 100644 crates/core/src/transport/sent_packet_tracker.rs create mode 100644 crates/core/src/transport/symmetric_message.rs create mode 100644 crates/core/src/util/time_source.rs create mode 100644 docs/src/architecture/transport.md diff --git a/.gitignore b/.gitignore index 936bda15a..2a3f4ec7a 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ config.toml .env .rustc* +rustc-ice*.txt diff --git a/Cargo.lock b/Cargo.lock index edb08bc0f..0bce8ea67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -87,9 +87,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "android-tzdata" @@ -156,9 +156,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.81" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" +checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" [[package]] name = "approx" @@ -178,6 +178,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayref" version = "0.3.7" @@ -285,9 +291,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.79" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", @@ -383,7 +389,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sha1", - "sync_wrapper 1.0.0", + "sync_wrapper 1.0.1", "tokio", "tokio-tungstenite 0.21.0", "tower", @@ -550,9 +556,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.4" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytecheck" @@ -619,9 +625,9 @@ checksum = "a2698f953def977c68f935bb0dfa959375ad4638570e969e2f1e9f433cbf1af6" [[package]] name = "cc" -version = "1.0.91" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd97381a8cc6493395a5afc4c691c1084b3768db713b73aa215217aa245d153" +checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" dependencies = [ "jobserver", "libc", @@ -685,7 +691,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -938,9 +944,9 @@ checksum = "393bc73c451830ff8dbb3a07f61843d6cb41a084f9996319917c0b291ed785bb" [[package]] name = "crc" -version = "3.0.1" +version = "3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86ec7a15cbe22e59248fc7eadb1907dab5ba09372595da4d73dd805ed4417dfe" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" dependencies = [ "crc-catalog", ] @@ -1326,18 +1332,18 @@ dependencies = [ [[package]] name = "either" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" dependencies = [ "serde", ] [[package]] name = "encoding_rs" -version = "0.8.33" +version = "0.8.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" dependencies = [ "cfg-if", ] @@ -1515,7 +1521,6 @@ dependencies = [ "clap", "dashmap", "either", - "fastrand 2.0.2", "freenet", "freenet-stdlib", "futures", @@ -1618,8 +1623,10 @@ dependencies = [ name = "freenet" version = "0.0.7" dependencies = [ + "aes-gcm", "anyhow", "arbitrary", + "arc-swap", "async-trait", "asynchronous-codec 0.7.0", "axum", @@ -1640,7 +1647,6 @@ dependencies = [ "delegate", "directories", "either", - "fastrand 2.0.2", "flatbuffers", "freenet-stdlib", "futures", @@ -1659,6 +1665,7 @@ 
dependencies = [ "rand", "reqwest", "rocksdb", + "rsa", "serde", "serde_json", "serde_with", @@ -1910,9 +1917,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06fddc2749e0528d2813f95e050e87e52c8cbbae56223b9babf73b3e53b0cc6" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" dependencies = [ "cfg-if", "js-sys", @@ -2470,9 +2477,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.28" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2" dependencies = [ "libc", ] @@ -2551,7 +2558,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -3455,6 +3462,7 @@ dependencies = [ "num-iter", "num-traits", "rand", + "serde", "smallvec", "zeroize", ] @@ -4164,9 +4172,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.35" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] @@ -4497,6 +4505,7 @@ dependencies = [ "pkcs1", "pkcs8", "rand_core", + "serde", "signature", "spki", "subtle", @@ -5326,9 +5335,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sync_wrapper" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "384595c11a4e2969895cad5a8c4029115f5ab956a9e5ef4de79d11a426e5f20c" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "synstructure" @@ -5452,9 +5461,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.34" +version = "0.3.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" dependencies = [ "deranged", "itoa", @@ -5473,9 +5482,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" dependencies = [ "num-conv", "time-core", @@ -6184,9 +6193,9 @@ dependencies = [ [[package]] name = "wasmer" -version = "4.2.6" +version = "4.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c15724dc25d1ee57962334aea8e41ade2675e5ea2ac6b8d42da6051b0face66" +checksum = "4014573f108a246858299eb230031e268316fd57207bd2e8afc79b20fc7ce983" dependencies = [ "bytes", "cfg-if", @@ -6213,9 +6222,9 @@ dependencies = [ [[package]] name = "wasmer-compiler" -version = "4.2.6" +version = "4.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"55a7f3b3a96f8d844c25e2c032af9572306dd63fa93dc17bcca4c5458ac569bd" +checksum = "3a77bfe259f08e8ec9e77f8f772ebfb4149f799d1f637231c5a5a6a90c447256" dependencies = [ "backtrace", "bytes", @@ -6240,9 +6249,9 @@ dependencies = [ [[package]] name = "wasmer-compiler-cranelift" -version = "4.2.6" +version = "4.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "102e2c5bacac69495c4025767e2fa26797ffb27f242dccb7cf57d9cefd944386" +checksum = "9280c47ebc754f95357745a38a995dd766f149e16b26e1b7e35741eb23c03d12" dependencies = [ "cranelift-codegen", "cranelift-entity", @@ -6259,9 +6268,9 @@ dependencies = [ [[package]] name = "wasmer-derive" -version = "4.2.6" +version = "4.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea737fa08f95d6abc4459f42a70a9833e8974b814e74971d77ef473814f4d4c" +checksum = "e9352877c4f07fc59146d21b56ae6dc469caf342587f49c81b4fbeafead31972" dependencies = [ "proc-macro-error", "proc-macro2", @@ -6271,9 +6280,9 @@ dependencies = [ [[package]] name = "wasmer-types" -version = "4.2.6" +version = "4.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0689110e291b0f07fc665f2824e5ff81df120848e8a9acfbf1a9bf7990773f9" +checksum = "749214b6170f2b2fbbfe5b7e7f8d381e64930ac4122f3abceb33cde0292d45d2" dependencies = [ "bytecheck", "enum-iterator", @@ -6287,9 +6296,9 @@ dependencies = [ [[package]] name = "wasmer-vm" -version = "4.2.6" +version = "4.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd41f822a1ac4242d478754e8ceba2806a00ea5072803622e1fe91e8e28b2a1" +checksum = "300215479de0deeb453e95aeb1b9c8ffd9bc7d9bd27c5f9e8a184e54db4d31a9" dependencies = [ "backtrace", "cc", @@ -6474,7 +6483,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -6505,7 +6514,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -6525,17 +6534,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" dependencies = [ - "windows_aarch64_gnullvm 0.52.4", - "windows_aarch64_msvc 0.52.4", - "windows_i686_gnu 0.52.4", - "windows_i686_msvc 0.52.4", - "windows_x86_64_gnu 0.52.4", - "windows_x86_64_gnullvm 0.52.4", - "windows_x86_64_msvc 0.52.4", + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", ] [[package]] @@ -6546,9 +6556,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" [[package]] name = 
"windows_aarch64_msvc" @@ -6564,9 +6574,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" [[package]] name = "windows_i686_gnu" @@ -6582,9 +6592,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.4" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" [[package]] name = "windows_i686_msvc" @@ -6600,9 +6616,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" [[package]] name = "windows_x86_64_gnu" @@ -6618,9 +6634,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" [[package]] name = "windows_x86_64_gnullvm" @@ -6630,9 +6646,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" [[package]] name = "windows_x86_64_msvc" @@ -6648,15 +6664,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "winnow" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8" +checksum = "f0c976aaaa0e1f90dbb21e9587cdaf1d9679a1cde8875c0d6bd83ab96a208352" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 00631ff12..089a10d3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ crossbeam = "0.8.2" ctrlc = { version = "3.4" } dashmap = "^5.5" either = "1.8" -fastrand = "2" futures = "0.3" rand = { version = "0.8" } semver = { version = "1.0.14", features = ["serde"] } diff --git a/crates/core/Cargo.toml 
b/crates/core/Cargo.toml index b130d0b3b..41f77c156 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -14,10 +14,12 @@ path = "src/bin/freenet.rs" [dependencies] anyhow = "1" +arc-swap = "1.6" asynchronous-codec = "0.7" +aes-gcm = "0.10.3" async-trait = "0.1" axum = { default-features = false, features = ["http1", "matched-path", "query", "tower-log", "ws", "json"], workspace = true } -bincode = "1" +bincode = "1.3.3" blake3 = { workspace = true } bs58 = "0.5" byteorder = "1" @@ -34,7 +36,6 @@ dashmap = { workspace = true } delegate = "0.12" directories = "5" either = { features = ["serde"], workspace = true } -fastrand = { workspace = true } flatbuffers = "23.5.26" futures = "0.3.21" headers = "0.4" @@ -54,6 +55,7 @@ serde_with = { workspace = true } sqlx = { features = ["runtime-tokio-rustls", "sqlite"], optional = true, version = "0.7" } stretto = { features = ["async", "sync"], version = "0.8" } tar = { version = "0.4.38" } +time = "0.3.30" thiserror = "1" tokio = { features = ["fs", "macros", "rt-multi-thread", "sync", "process"], version = "1" } tokio-tungstenite = "0.21" @@ -63,7 +65,7 @@ unsigned-varint = { version = "0.8", features = ["codec", "asynchronous_codec"] wasmer = { features = ["sys"], workspace = true } xz2 = { version = "0.1" } reqwest = { version = "0.11.23", features = ["json"] } -# enum-iterator = "1.4.1" +rsa = { version = "0.9.6", features = ["serde"] } # Tracing deps opentelemetry = "0.21.0" @@ -74,7 +76,6 @@ tracing-subscriber = { optional = true, version = "0.3.16" } # internal deps freenet-stdlib = { features = ["net"], workspace = true } -time = "0.3.30" [dev-dependencies] arbitrary = { features = ["derive"], version = "1" } diff --git a/crates/core/src/bin/freenet.rs b/crates/core/src/bin/freenet.rs index 0b7416bbb..0ab04cee3 100644 --- a/crates/core/src/bin/freenet.rs +++ b/crates/core/src/bin/freenet.rs @@ -21,7 +21,7 @@ async fn run_local(config: PeerCliConfig) -> Result<(), DynError> { } fn main() -> Result<(), DynError> { - freenet::config::set_logger(); + freenet::config::set_logger(None); let config = PeerCliConfig::parse(); let rt = tokio::runtime::Builder::new_multi_thread() .worker_threads(4) diff --git a/crates/core/src/client_events.rs b/crates/core/src/client_events.rs index 6e1ed9074..02fea0a2b 100644 --- a/crates/core/src/client_events.rs +++ b/crates/core/src/client_events.rs @@ -556,24 +556,6 @@ pub(crate) mod test { } } - impl RandomEventGenerator for fastrand::Rng { - fn gen_u8(&mut self) -> u8 { - self.u8(..u8::MAX) - } - - fn gen_range(&mut self, range: std::ops::Range) -> usize { - self.choice(range).expect("non empty") - } - - fn choose<'a, T>(&mut self, vec: &'a [T]) -> Option<&'a T> { - self.choice(0..vec.len()).and_then(|choice| vec.get(choice)) - } - - fn seed_from_u64(seed: u64) -> Self { - Self::with_seed(seed) - } - } - #[test] fn test_gen_event() { const NUM_PEERS: usize = 20; diff --git a/crates/core/src/config.rs b/crates/core/src/config.rs index 17c45a1fc..12185ea31 100644 --- a/crates/core/src/config.rs +++ b/crates/core/src/config.rs @@ -342,7 +342,7 @@ impl libp2p::swarm::Executor for GlobalExecutor { } } -pub fn set_logger() { +pub fn set_logger(level: Option) { #[cfg(feature = "trace")] { static LOGGER_SET: AtomicBool = AtomicBool::new(false); @@ -358,6 +358,6 @@ pub fn set_logger() { return; } - crate::tracing::tracer::init_tracer().expect("failed tracing initialization") + crate::tracing::tracer::init_tracer(level).expect("failed tracing initialization") } } diff --git a/crates/core/src/lib.rs 
b/crates/core/src/lib.rs index 664752564..6a38a4acb 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -23,6 +23,8 @@ pub mod server; mod topology; /// Tracing and loging infrastructure. Includes our custom event log register. Tracing collectors, etc. mod tracing; +/// Code for communicating with other peers over UDP, handles hole-punching, error handling, etc. +mod transport; pub mod util; /// WASM code execution runtime, tailored for the contract and delegate APIs. mod wasm_runtime; diff --git a/crates/core/src/node/network_bridge/p2p_protoc.rs b/crates/core/src/node/network_bridge/p2p_protoc.rs index 292c2c733..970af203a 100644 --- a/crates/core/src/node/network_bridge/p2p_protoc.rs +++ b/crates/core/src/node/network_bridge/p2p_protoc.rs @@ -20,7 +20,6 @@ use libp2p::{ identify, identity::Keypair, multiaddr::Protocol, - ping, swarm::{ self, dial_opts::DialOpts, @@ -81,8 +80,6 @@ fn config_behaviour( ) .with_agent_version(CURRENT_AGENT_VER.to_string()); - let ping = ping::Behaviour::default(); - let peer_id = private_key.public().to_peer_id(); let auto_nat = { let config = autonat::Config { @@ -97,7 +94,6 @@ fn config_behaviour( }; NetBehaviour { - ping, identify: identify::Behaviour::new(ident_config), auto_nat, freenet: FreenetBehaviour { @@ -180,7 +176,8 @@ impl NetworkBridge for P2pBridge { async fn send(&self, target: &FreenetPeerId, msg: NetMessage) -> super::ConnResult<()> { self.log_register - .register_events(NetEventLog::from_outbound_msg(&msg, &self.op_manager.ring)); + .register_events(NetEventLog::from_outbound_msg(&msg, &self.op_manager.ring)) + .await; self.op_manager.sending_transaction(target, &msg); self.ev_listener_tx .send(Left((*target, Box::new(msg)))) @@ -488,7 +485,7 @@ impl P2pConnManager { Ok(Right(ConnectionClosed { peer: peer_id })) | Ok(Right(NodeAction(NodeEvent::DropConnection(peer_id)))) => { self.bridge.active_net_connections.remove(&peer_id); - op_manager.ring.prune_connection(peer_id); + op_manager.ring.prune_connection(peer_id).await; // todo: notify the handler, read `disconnect_peer_id` doc let _ = self.swarm.disconnect_peer_id(peer_id.0); tracing::info!("Dropped connection with peer {}", peer_id); @@ -1191,7 +1188,6 @@ fn decode_msg(buf: BytesMut) -> Result { #[behaviour(to_swarm = "NetEvent")] pub(in crate::node) struct NetBehaviour { identify: identify::Behaviour, - ping: ping::Behaviour, freenet: FreenetBehaviour, auto_nat: autonat::Behaviour, } @@ -1200,7 +1196,6 @@ pub(in crate::node) struct NetBehaviour { pub(in crate::node) enum NetEvent { Freenet(Box), Identify(Box), - Ping(ping::Event), Autonat(autonat::Event), } @@ -1216,12 +1211,6 @@ impl From for NetEvent { } } -impl From for NetEvent { - fn from(event: ping::Event) -> NetEvent { - Self::Ping(event) - } -} - impl From for NetEvent { fn from(event: NetMessage) -> NetEvent { Self::Freenet(Box::new(event)) diff --git a/crates/core/src/node/op_state_manager.rs b/crates/core/src/node/op_state_manager.rs index cb39a28ce..01abf3fe8 100644 --- a/crates/core/src/node/op_state_manager.rs +++ b/crates/core/src/node/op_state_manager.rs @@ -223,65 +223,6 @@ async fn garbage_cleanup_task( let mut ttl_set = BTreeSet::new(); - let mut remove_old = move |ttl_set: &mut BTreeSet>, - delayed: &mut Vec| { - let mut old_missing = std::mem::replace(delayed, Vec::with_capacity(200)); - for tx in old_missing.drain(..) 
{ - if let Some(tx) = ops.completed.remove(&tx) { - if cfg!(feature = "trace-ot") { - event_register.notify_of_time_out(tx); - } else { - _ = tx; - } - continue; - } - let still_waiting = match tx.transaction_type() { - TransactionType::Connect => ops.connect.remove(&tx).is_none(), - TransactionType::Put => ops.put.remove(&tx).is_none(), - TransactionType::Get => ops.get.remove(&tx).is_none(), - TransactionType::Subscribe => ops.subscribe.remove(&tx).is_none(), - TransactionType::Update => ops.update.remove(&tx).is_none(), - }; - let timed_out = tx.timed_out(); - if still_waiting && !timed_out { - delayed.push(tx); - } else { - if still_waiting && timed_out { - ops.under_progress.remove(&tx); - ops.completed.remove(&tx); - } - live_tx_tracker.remove_finished_transaction(tx); - } - } - - // notice the use of reverse so the older transactions are removed instead of the newer ones - let older_than: Reverse = Reverse(Transaction::ttl_transaction()); - for Reverse(tx) in ttl_set.split_off(&older_than).into_iter() { - if ops.under_progress.contains(&tx) { - delayed.push(tx); - continue; - } - if let Some(tx) = ops.completed.remove(&tx) { - if cfg!(feature = "trace-ot") { - event_register.notify_of_time_out(tx); - } else { - _ = tx; - } - continue; - } - let removed = match tx.transaction_type() { - TransactionType::Connect => ops.connect.remove(&tx).is_some(), - TransactionType::Put => ops.put.remove(&tx).is_some(), - TransactionType::Get => ops.get.remove(&tx).is_some(), - TransactionType::Subscribe => ops.subscribe.remove(&tx).is_some(), - TransactionType::Update => ops.update.remove(&tx).is_some(), - }; - if removed { - live_tx_tracker.remove_finished_transaction(tx); - } - } - }; - let mut delayed = vec![]; loop { tokio::select! { @@ -291,7 +232,61 @@ async fn garbage_cleanup_task( } } _ = tick.tick() => { - remove_old(&mut ttl_set, &mut delayed); + let mut old_missing = std::mem::replace(&mut delayed, Vec::with_capacity(200)); + for tx in old_missing.drain(..) 
{ + if let Some(tx) = ops.completed.remove(&tx) { + if cfg!(feature = "trace-ot") { + event_register.notify_of_time_out(tx).await; + } else { + _ = tx; + } + continue; + } + let still_waiting = match tx.transaction_type() { + TransactionType::Connect => ops.connect.remove(&tx).is_none(), + TransactionType::Put => ops.put.remove(&tx).is_none(), + TransactionType::Get => ops.get.remove(&tx).is_none(), + TransactionType::Subscribe => ops.subscribe.remove(&tx).is_none(), + TransactionType::Update => ops.update.remove(&tx).is_none(), + }; + let timed_out = tx.timed_out(); + if still_waiting && !timed_out { + delayed.push(tx); + } else { + if still_waiting && timed_out { + ops.under_progress.remove(&tx); + ops.completed.remove(&tx); + } + live_tx_tracker.remove_finished_transaction(tx); + } + } + + // notice the use of reverse so the older transactions are removed instead of the newer ones + let older_than: Reverse = Reverse(Transaction::ttl_transaction()); + for Reverse(tx) in ttl_set.split_off(&older_than).into_iter() { + if ops.under_progress.contains(&tx) { + delayed.push(tx); + continue; + } + if let Some(tx) = ops.completed.remove(&tx) { + if cfg!(feature = "trace-ot") { + event_register.notify_of_time_out(tx).await; + } else { + _ = tx; + } + continue; + } + let removed = match tx.transaction_type() { + TransactionType::Connect => ops.connect.remove(&tx).is_some(), + TransactionType::Put => ops.put.remove(&tx).is_some(), + TransactionType::Get => ops.get.remove(&tx).is_some(), + TransactionType::Subscribe => ops.subscribe.remove(&tx).is_some(), + TransactionType::Update => ops.update.remove(&tx).is_some(), + }; + if removed { + live_tx_tracker.remove_finished_transaction(tx); + } + } } } } diff --git a/crates/core/src/node/p2p_impl.rs b/crates/core/src/node/p2p_impl.rs index 06b247348..4c6ddefdb 100644 --- a/crates/core/src/node/p2p_impl.rs +++ b/crates/core/src/node/p2p_impl.rs @@ -167,7 +167,6 @@ impl NodeP2P { mod test { use std::{net::Ipv4Addr, time::Duration}; - use super::super::network_bridge::p2p_protoc::NetEvent; use super::*; use crate::{ client_events::test::MemoryEventsGen, @@ -178,7 +177,6 @@ mod test { }; use futures::StreamExt; - use libp2p::swarm::SwarmEvent; use tokio::sync::watch::channel; /// Ping test event loop @@ -189,12 +187,6 @@ mod test { peer.conn_manager.swarm.select_next_some(), ); match ev.await { - Ok(SwarmEvent::Behaviour(NetEvent::Ping(ping))) => { - if ping.result.is_ok() { - tracing::info!("ping done @ {}", peer.peer_key); - return Ok(()); - } - } Ok(other) => { tracing::debug!("{:?}", other) } diff --git a/crates/core/src/node/testing_impl.rs b/crates/core/src/node/testing_impl.rs index 7bc05cd91..42b104afb 100644 --- a/crates/core/src/node/testing_impl.rs +++ b/crates/core/src/node/testing_impl.rs @@ -1099,10 +1099,13 @@ where NodeEvent::ShutdownNode => break Ok(()), NodeEvent::DropConnection(peer) => { tracing::info!("Dropping connection to {peer}"); - event_register.register_events(Either::Left( - crate::tracing::NetEventLog::disconnected(&op_manager.ring, &peer), - )); - op_manager.ring.prune_connection(peer); + event_register + .register_events(Either::Left(crate::tracing::NetEventLog::disconnected( + &op_manager.ring, + &peer, + ))) + .await; + op_manager.ring.prune_connection(peer).await; continue; } NodeEvent::Disconnect { cause: Some(cause) } => { diff --git a/crates/core/src/node/testing_impl/network.rs b/crates/core/src/node/testing_impl/network.rs index b44266432..3619263a7 100644 --- a/crates/core/src/node/testing_impl/network.rs +++ 
b/crates/core/src/node/testing_impl/network.rs @@ -72,8 +72,8 @@ impl NetworkPeer { let event_register = { #[cfg(feature = "trace-ot")] { - use crate::tracing::{CombinedRegister, OTEventRegister}; - CombinedRegister::new([ + use crate::tracing::OTEventRegister; + crate::tracing::CombinedRegister::new([ Box::new(EventRegister::new( crate::config::Config::conf().event_log(), )), diff --git a/crates/core/src/operations/connect.rs b/crates/core/src/operations/connect.rs index 9f3ed2e87..c3bd9f8a0 100644 --- a/crates/core/src/operations/connect.rs +++ b/crates/core/src/operations/connect.rs @@ -625,10 +625,13 @@ impl Operation for ConnectOp { } network_bridge.add_connection(sender.peer).await?; - op_manager.ring.add_connection( - sender.location.ok_or(ConnectionError::LocationUnknown)?, - sender.peer, - ); + op_manager + .ring + .add_connection( + sender.location.ok_or(ConnectionError::LocationUnknown)?, + sender.peer, + ) + .await; tracing::debug!(tx = %id, from = %by_peer.peer, "Opened connection with peer"); if target != gateway { new_state = None; @@ -655,10 +658,13 @@ impl Operation for ConnectOp { "Successfully completed connection", ); network_bridge.add_connection(sender.peer).await?; - op_manager.ring.add_connection( - sender.location.ok_or(ConnectionError::LocationUnknown)?, - sender.peer, - ); + op_manager + .ring + .add_connection( + sender.location.ok_or(ConnectionError::LocationUnknown)?, + sender.peer, + ) + .await; new_state = None; } _ => return Err(OpError::UnexpectedOpState), @@ -735,12 +741,15 @@ async fn propagate_oc_to_responding_peers( ) { tracing::info!(tx = %id, from = %sender.peer, to = %other_peer.peer, "Established connection"); network_bridge.add_connection(other_peer.peer).await?; - op_manager.ring.add_connection( - other_peer - .location - .ok_or(ConnectionError::LocationUnknown)?, - other_peer.peer, - ); + op_manager + .ring + .add_connection( + other_peer + .location + .ok_or(ConnectionError::LocationUnknown)?, + other_peer.peer, + ) + .await; if other_peer.peer != sender.peer { // notify all the additional peers which accepted a request; // the gateway will be notified in the last message diff --git a/crates/core/src/ring.rs b/crates/core/src/ring.rs index b3c38003c..a8470ceb7 100644 --- a/crates/core/src/ring.rs +++ b/crates/core/src/ring.rs @@ -521,10 +521,11 @@ impl Ring { .record_request(recipient, target, request_type); } - pub fn add_connection(&self, loc: Location, peer: PeerId) { - let mut cbl = self.connections_by_location.write(); + pub async fn add_connection(&self, loc: Location, peer: PeerId) { self.event_register - .register_events(Either::Left(NetEventLog::connected(self, peer, loc))); + .register_events(Either::Left(NetEventLog::connected(self, peer, loc))) + .await; + let mut cbl = self.connections_by_location.write(); cbl.entry(loc).or_default().push(Connection { location: PeerKeyLocation { peer, @@ -654,7 +655,7 @@ impl Ring { self.connections_by_location.read().len() } - pub fn prune_connection(&self, peer: PeerId) { + pub async fn prune_connection(&self, peer: PeerId) { #[cfg(debug_assertions)] { tracing::info!(%peer, "Removing connection"); @@ -679,7 +680,8 @@ impl Ring { }); } self.event_register - .register_events(Either::Left(NetEventLog::disconnected(self, &peer))); + .register_events(Either::Left(NetEventLog::disconnected(self, &peer))) + .await; self.open_connections .fetch_sub(1, std::sync::atomic::Ordering::SeqCst); } diff --git a/crates/core/src/topology.rs b/crates/core/src/topology.rs index 20f6f0025..cdd76f458 100644 --- 
a/crates/core/src/topology.rs +++ b/crates/core/src/topology.rs @@ -546,7 +546,7 @@ mod tests { // Simulate a bunch of random requests clustered around 0.35 for _ in 0..NUM_REQUESTS { let requested_location = random_location(&random_location(&this_peer_location)); - // FIXME: Is PeerKeyLocation unimportant for this test? + // todo: Is PeerKeyLocation unimportant for this test? topology_manager.record_request( PeerKeyLocation::random(), requested_location, @@ -867,7 +867,7 @@ mod tests { } fn find_worst_peer( - peers: &Vec, + peers: &[PeerKeyLocation], bw_usage_by_peer: &[usize], requests_per_peer: &[usize], ) -> usize { diff --git a/crates/core/src/tracing.rs b/crates/core/src/tracing.rs index 84e3cb558..d56498089 100644 --- a/crates/core/src/tracing.rs +++ b/crates/core/src/tracing.rs @@ -38,6 +38,7 @@ pub(crate) use test::TestEventListener; use crate::node::OpManager; #[derive(Debug, Clone, Copy)] +#[allow(dead_code)] struct ListenerLogId(usize); /// A type that reacts to incoming messages from the network and records information about them. @@ -82,7 +83,7 @@ impl NetEventRegister for CombinedRegister { fn notify_of_time_out(&mut self, tx: Transaction) -> BoxFuture<()> { async move { for reg in &mut self.0 { - reg.notify_of_time_out(tx); + reg.notify_of_time_out(tx).await; } } .boxed() @@ -1148,16 +1149,18 @@ enum PutEvent { #[cfg(feature = "trace")] pub(crate) mod tracer { + use tracing::level_filters::LevelFilter; use tracing_subscriber::{Layer, Registry}; use crate::DynError; - pub fn init_tracer() -> Result<(), DynError> { + pub fn init_tracer(level: Option) -> Result<(), DynError> { let default_filter = if cfg!(any(test, debug_assertions)) { - tracing_subscriber::filter::LevelFilter::DEBUG + LevelFilter::DEBUG } else { - tracing_subscriber::filter::LevelFilter::INFO + LevelFilter::INFO }; + let default_filter = level.unwrap_or(default_filter); let filter_layer = tracing_subscriber::EnvFilter::builder() .with_default_directive(default_filter.into()) .from_env_lossy() diff --git a/crates/core/src/transport.rs b/crates/core/src/transport.rs new file mode 100644 index 000000000..81cfb5d88 --- /dev/null +++ b/crates/core/src/transport.rs @@ -0,0 +1,123 @@ +#![allow(dead_code)] // TODO: Remove before integration +//! Freenet Transport protocol implementation. +//! +//! Please see `docs/architecture/transport.md` for more information. +//! 
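+//! A rough usage sketch of the pieces added here (illustrative only, not compiled as a
+//! doc test): it assumes this peer's `TransportKeypair` and the remote's
+//! `TransportPublicKey` and address were obtained elsewhere, the port is a placeholder,
+//! the concrete `Socket` type (e.g. `tokio::net::UdpSocket`) is elided, and the
+//! `send`/`recv` calls on `PeerConnection` stand in for whatever message API the
+//! connection exposes.
+//!
+//! ```ignore
+//! // Bind a handler on a local UDP port; `false` = this peer is not a gateway.
+//! let mut handler = ConnectionHandler::new(our_keypair, 31337, false).await?;
+//!
+//! // Outbound: `connect` registers the attempt and returns a future that resolves
+//! // once hole-punching and the symmetric key exchange with the remote complete.
+//! let mut conn = handler.connect(remote_public_key, remote_addr, false).await.await?;
+//! conn.send(b"hello".to_vec()).await?;
+//!
+//! // Inbound (gateway side): connections accepted by the listener surface here.
+//! while let Some(mut inbound) = handler.next_connection().await {
+//!     let _msg = inbound.recv().await?;
+//! }
+//! ```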
+use std::{borrow::Cow, io, net::SocketAddr}; + +use futures::Future; +use tokio::net::UdpSocket; + +mod connection_handler; +mod crypto; +mod packet_data; +mod peer_connection; +mod rate_limiter; +// todo: optimize trackers +mod received_packet_tracker; +mod sent_packet_tracker; +mod symmetric_message; + +type MessagePayload = Vec; + +type PacketId = u32; + +use self::{packet_data::PacketData, peer_connection::StreamId}; + +#[derive(Debug, thiserror::Error)] +pub(crate) enum TransportError { + #[error("transport handler channel closed, socket likely closed")] + ChannelClosed, + #[error("connection to remote closed")] + ConnectionClosed, + #[error("failed while establishing connection, reason: {cause}")] + ConnectionEstablishmentFailure { cause: Cow<'static, str> }, + #[error("incomplete inbound stream: {0}")] + IncompleteInboundStream(StreamId), + #[error(transparent)] + IO(#[from] std::io::Error), + #[error(transparent)] + Other(#[from] anyhow::Error), + #[error("{0}")] + PrivateKeyDecryptionError(aes_gcm::aead::Error), + #[error(transparent)] + PubKeyDecryptionError(#[from] rsa::errors::Error), + #[error(transparent)] + Serialization(#[from] bincode::Error), + #[error("received unexpected message from remote: {0}")] + UnexpectedMessage(Cow<'static, str>), +} + +/// Make connection handler more testable +trait Socket: Sized + Send + Sync + 'static { + fn bind(addr: SocketAddr) -> impl Future> + Send; + fn recv_from( + &self, + buf: &mut [u8], + ) -> impl Future> + Send; + fn send_to( + &self, + buf: &[u8], + target: SocketAddr, + ) -> impl Future> + Send; +} + +impl Socket for UdpSocket { + async fn bind(addr: SocketAddr) -> io::Result { + Self::bind(addr).await + } + + async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + self.recv_from(buf).await + } + + async fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result { + self.send_to(buf, target).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transport::received_packet_tracker::ReportResult; + use crate::transport::sent_packet_tracker::{ResendAction, MESSAGE_CONFIRMATION_TIMEOUT}; + + #[test] + fn test_packet_send_receive_acknowledge_flow() { + let mut sent_tracker = sent_packet_tracker::tests::mock_sent_packet_tracker(); + let mut received_tracker = received_packet_tracker::tests::mock_received_packet_tracker(); + + // Simulate sending packets + for id in 1..=5 { + sent_tracker.report_sent_packet(id, vec![id as u8].into()); + } + + // Simulate receiving some packets + for id in [1u32, 3, 5] { + assert_eq!( + received_tracker.report_received_packet(id), + ReportResult::Ok + ); + } + + // Get receipts and simulate acknowledging them + let receipts = received_tracker.get_receipts(); + assert_eq!(receipts, vec![1u32, 3, 5]); + sent_tracker.report_received_receipts(&receipts); + + // Check resend action for lost packets + sent_tracker + .time_source + .advance_time(MESSAGE_CONFIRMATION_TIMEOUT); + for id in [2, 4] { + match sent_tracker.get_resend() { + ResendAction::Resend(packet_id, packet) => { + assert_eq!(packet_id, id); + // Simulate resending packet + sent_tracker.report_sent_packet(id, packet); + } + _ => panic!("Expected resend action for packet {}", id), + } + } + } +} diff --git a/crates/core/src/transport/connection_handler.rs b/crates/core/src/transport/connection_handler.rs new file mode 100644 index 000000000..fbd34380c --- /dev/null +++ b/crates/core/src/transport/connection_handler.rs @@ -0,0 +1,1331 @@ +use std::collections::BTreeMap; +use std::net::{Ipv4Addr, 
SocketAddr}; +use std::pin::Pin; +use std::sync::atomic::AtomicU32; +use std::sync::Arc; +use std::time::Duration; +use std::vec::Vec; + +use crate::transport::crypto::TransportSecretKey; +use crate::transport::packet_data::{AssymetricRSA, UnknownEncryption}; +use crate::transport::symmetric_message::OutboundConnection; +use aes_gcm::{Aes128Gcm, KeyInit}; +use futures::{ + stream::{FuturesUnordered, StreamExt}, + Future, +}; +use futures::{FutureExt, TryFutureExt}; +use tokio::net::UdpSocket; +use tokio::sync::{mpsc, oneshot}; +use tokio::task; + +use super::{ + crypto::{TransportKeypair, TransportPublicKey}, + packet_data::MAX_PACKET_SIZE, + peer_connection::{PeerConnection, RemoteConnection}, + sent_packet_tracker::SentPacketTracker, + symmetric_message::{SymmetricMessage, SymmetricMessagePayload}, + PacketData, Socket, TransportError, +}; + +const PROTOC_VERSION: [u8; 2] = 1u16.to_le_bytes(); + +// Constants for interval increase +const INITIAL_INTERVAL: Duration = Duration::from_millis(200); +const INTERVAL_INCREASE_FACTOR: u64 = 2; +const MAX_INTERVAL: Duration = Duration::from_millis(5000); // Maximum interval limit + +const DEFAULT_BW_TRACKER_WINDOW_SIZE: Duration = Duration::from_secs(10); +const BANDWITH_LIMIT: usize = 1024 * 1024 * 10; // 10 MB/s + +type ConnectionHandlerMessage = (SocketAddr, Vec); +pub type SerializedMessage = Vec; +type PeerChannel = ( + mpsc::Sender, + mpsc::Receiver, +); + +struct OutboundMessage { + remote_addr: SocketAddr, + msg: SerializedMessage, + recv: mpsc::Receiver, +} + +pub(crate) struct ConnectionHandler { + send_queue: mpsc::Sender<(SocketAddr, ConnectionEvent)>, + new_connection_notifier: mpsc::Receiver, +} + +impl ConnectionHandler { + pub async fn new( + keypair: TransportKeypair, + listen_port: u16, + is_gateway: bool, + ) -> Result { + // Bind the UDP socket to the specified port + let socket = Arc::new(S::bind((Ipv4Addr::UNSPECIFIED, listen_port).into()).await?); + Self::config_listener( + socket, + keypair, + is_gateway, + #[cfg(test)] + (Ipv4Addr::UNSPECIFIED, listen_port).into(), + ) + } + + fn config_listener( + socket: Arc, + keypair: TransportKeypair, + is_gateway: bool, + #[cfg(test)] socket_addr: SocketAddr, + ) -> Result { + let (conn_handler_sender, conn_handler_receiver) = mpsc::channel(100); + let (new_connection_sender, new_connection_notifier) = mpsc::channel(100); + + // Channel buffer is one so senders will await until the receiver is ready, important for bandwidth limiting + let (outbound_sender, outbound_recv) = mpsc::channel(1); + let transport = UdpPacketsListener { + is_gateway, + socket_listener: socket.clone(), + this_peer_keypair: keypair, + remote_connections: BTreeMap::new(), + connection_handler: conn_handler_receiver, + new_connection_notifier: new_connection_sender, + outbound_packets: outbound_sender, + #[cfg(test)] + this_addr: socket_addr, + }; + let bw_tracker = super::rate_limiter::PacketRateLimiter::new( + DEFAULT_BW_TRACKER_WINDOW_SIZE, + outbound_recv, + ); + let connection_handler = ConnectionHandler { + send_queue: conn_handler_sender, + new_connection_notifier, + }; + + task::spawn(bw_tracker.rate_limiter(BANDWITH_LIMIT, socket)); + task::spawn(transport.listen()); + + Ok(connection_handler) + } + + #[cfg(test)] + fn test_set_up( + socket_addr: SocketAddr, + socket: Arc, + keypair: TransportKeypair, + is_gateway: bool, + ) -> Result { + Self::config_listener(socket, keypair, is_gateway, socket_addr) + } + + pub async fn connect( + &mut self, + remote_public_key: TransportPublicKey, + remote_addr: 
SocketAddr, + remote_is_gateway: bool, + ) -> Pin> + Send>> { + let (open_connection, recv_connection) = oneshot::channel(); + if self + .send_queue + .send(( + remote_addr, + ConnectionEvent::ConnectionStart { + remote_public_key, + remote_is_gateway, + open_connection, + }, + )) + .await + .is_err() + { + return async { Err(TransportError::ChannelClosed) }.boxed(); + } + recv_connection + .map(|res| match res { + Ok(Ok(remote_conn)) => Ok(PeerConnection::new(remote_conn)), + Ok(Err(e)) => Err(e), + Err(_) => Err(TransportError::ConnectionEstablishmentFailure { + cause: "Failed to establish connection".into(), + }), + }) + .boxed() + } + + pub async fn next_connection(&mut self) -> Option { + self.new_connection_notifier.recv().await + } +} + +pub enum Message { + Short(Vec), + Streamed(Vec, mpsc::Receiver), +} + +pub struct StreamFragment { + pub fragment_number: u32, + pub fragment: Vec, +} + +/// Handles UDP transport internally. +struct UdpPacketsListener { + socket_listener: Arc, + remote_connections: BTreeMap, + connection_handler: mpsc::Receiver<(SocketAddr, ConnectionEvent)>, + this_peer_keypair: TransportKeypair, + is_gateway: bool, + new_connection_notifier: mpsc::Sender, + outbound_packets: mpsc::Sender<(SocketAddr, Arc<[u8]>)>, + #[cfg(test)] + this_addr: SocketAddr, +} + +type OngoingConnection = ( + mpsc::Sender>, + oneshot::Sender>, +); + +type OngoingConnectionResult = Option< + Result< + Result<(RemoteConnection, InboundRemoteConnection), (TransportError, SocketAddr)>, + tokio::task::JoinError, + >, +>; + +#[cfg(test)] +impl Drop for UdpPacketsListener { + fn drop(&mut self) { + tracing::info!(%self.this_addr, "Dropping UdpPacketsListener"); + } +} + +impl UdpPacketsListener { + async fn listen(mut self) -> Result<(), TransportError> { + let mut buf = [0u8; MAX_PACKET_SIZE]; + let mut ongoing_connections: BTreeMap = BTreeMap::new(); + let mut connection_tasks = FuturesUnordered::new(); + loop { + tokio::select! 
{ + // Handling of inbound packets + recv_result = self.socket_listener.recv_from(&mut buf) => { + match recv_result { + Ok((size, remote_addr)) => { + let packet_data = PacketData::from_buf(&buf[..size]); + if let Some(remote_conn) = self.remote_connections.remove(&remote_addr){ + let _ = remote_conn.inbound_packet_sender.send(packet_data).await; + self.remote_connections.insert(remote_addr, remote_conn); + continue; + } + + if let Some((packets_sender, open_connection)) = ongoing_connections.remove(&remote_addr) { + if packets_sender.send(packet_data).await.is_err() { + // it can happen that the connection is established but the channel is closed because the task completed + // but we still ahven't polled the result future + tracing::debug!(%remote_addr, "failed to send packet to remote"); + } + ongoing_connections.insert(remote_addr, (packets_sender, open_connection)); + continue; + } + + if !self.is_gateway { + tracing::debug!(%remote_addr, "unexpected packet from remote"); + continue; + } + let packet_data = PacketData::from_buf(&buf[..size]); + // FIXME: also parallelize this like we do with nat_traversal future + if let Err(error) = self.gateway_connection(packet_data, remote_addr).await { + tracing::error!(%error, ?remote_addr, "Failed to establish connection"); + } + } + Err(e) => { + // TODO: this should panic and be propagate to the main task or retry and eventually fail + tracing::error!("Failed to receive UDP packet: {:?}", e); + return Err(e.into()); + } + } + }, + connection_handshake = connection_tasks.next(), if !connection_tasks.is_empty() => { + let Some(res): OngoingConnectionResult = connection_handshake else { + unreachable!(); + }; + match res.expect("task shouldn't panic") { + Ok((outbound_remote_conn, inbound_remote_connection)) => { + if let Some((_, result_sender)) = ongoing_connections.remove(&outbound_remote_conn.remote_addr) { + tracing::debug!(%outbound_remote_conn.remote_addr, "connection established"); + self.remote_connections.insert(outbound_remote_conn.remote_addr, inbound_remote_connection); + let _ = result_sender.send(Ok(outbound_remote_conn)); + } else { + tracing::error!(%outbound_remote_conn.remote_addr, "connection established but no ongoing connection found"); + } + } + Err((error, remote_addr)) => { + tracing::error!(%error, ?remote_addr, "Failed to establish connection"); + if let Some((_, result_sender)) = ongoing_connections.remove(&remote_addr) { + let _ = result_sender.send(Err(error)); + } + } + } + } + // Handling of connection events + connection_event = self.connection_handler.recv() => { + let Some((remote_addr, event)) = connection_event else { return Ok(()); }; + let ConnectionEvent::ConnectionStart { remote_public_key, remote_is_gateway, open_connection } = event; + tracing::debug!(%remote_addr, "attempting to establish connection"); + let (ongoing_connection, packets_sender) = self.traverse_nat( + remote_addr, remote_public_key, remote_is_gateway + ); + let task = tokio::spawn(ongoing_connection.map_err(move |error| { + (error, remote_addr) + })); + connection_tasks.push(task); + ongoing_connections.insert(remote_addr, (packets_sender, open_connection)); + }, + } + } + } + + async fn gateway_connection( + &mut self, + remote_intro_packet: PacketData, + remote_addr: SocketAddr, + ) -> Result<(), TransportError> { + tracing::debug!(%remote_addr, "new connection to gateway"); + let Ok(decrypted_intro_packet) = self + .this_peer_keypair + .secret + .decrypt(remote_intro_packet.data()) + else { + tracing::debug!(%remote_addr, "failed to 
decrypt packet with private key"); + return Ok(()); + }; + let protoc = &decrypted_intro_packet[..PROTOC_VERSION.len()]; + let outbound_key_bytes = + &decrypted_intro_packet[PROTOC_VERSION.len()..PROTOC_VERSION.len() + 16]; + let outbound_key = Aes128Gcm::new_from_slice(outbound_key_bytes).map_err(|_| { + TransportError::ConnectionEstablishmentFailure { + cause: "invalid symmetric key".into(), + } + })?; + if protoc != PROTOC_VERSION { + let packet = SymmetricMessage::ack_error(&outbound_key)?; + self.outbound_packets + .send((remote_addr, packet.prepared_send())) + .await + .map_err(|_| TransportError::ChannelClosed)?; + return Err(TransportError::ConnectionEstablishmentFailure { + cause: format!( + "remote is using a different protocol version: {:?}", + String::from_utf8_lossy(protoc) + ) + .into(), + }); + } + + let inbound_key_bytes = rand::random::<[u8; 16]>(); + let inbound_key = Aes128Gcm::new(&inbound_key_bytes.into()); + let outbound_ack_packet = + SymmetricMessage::ack_ok(&outbound_key, inbound_key_bytes, remote_addr)?; + + let mut buf = [0u8; MAX_PACKET_SIZE]; + let mut waiting_time = INITIAL_INTERVAL; + let mut attempts = 0; + const MAX_ATTEMPTS: usize = 20; + while attempts < MAX_ATTEMPTS { + self.outbound_packets + .send((remote_addr, outbound_ack_packet.clone().prepared_send())) + .await + .map_err(|_| TransportError::ChannelClosed)?; + + // wait until the remote sends the ack packet + let timeout = + tokio::time::timeout(waiting_time, self.socket_listener.recv_from(&mut buf)); + match timeout.await { + Ok(Ok((size, remote))) => { + let packet = PacketData::from_buf(&buf[..size]); + if remote != remote_addr { + if let Some(remote) = self.remote_connections.remove(&remote_addr) { + let _ = remote.inbound_packet_sender.send(packet).await; + self.remote_connections.insert(remote_addr, remote); + continue; + } + } + let _ = packet.try_decrypt_sym(&inbound_key).map_err(|_| { + tracing::debug!(%remote_addr, "Failed to decrypt packet with inbound key"); + TransportError::ConnectionEstablishmentFailure { + cause: "invalid symmetric key".into(), + } + })?; + } + Ok(Err(_)) => { + return Err(TransportError::ChannelClosed); + } + Err(_) => { + attempts += 1; + waiting_time = std::cmp::min( + Duration::from_millis( + waiting_time.as_millis() as u64 * INTERVAL_INCREASE_FACTOR, + ), + MAX_INTERVAL, + ); + continue; + } + } + // we know the inbound is successfully connected now and can proceed + // ignoring this will force them to resend the packet but that is fine and simpler + break; + } + + let sent_tracker = Arc::new(parking_lot::Mutex::new(SentPacketTracker::new())); + let peer_connection = PeerConnection::new(RemoteConnection { + outbound_packets: self.outbound_packets.clone(), + outbound_symmetric_key: outbound_key, + remote_addr, + sent_tracker: sent_tracker.clone(), + last_packet_id: Arc::new(AtomicU32::new(0)), + inbound_packet_recv: mpsc::channel(100).1, + inbound_symmetric_key: inbound_key, + inbound_symmetric_key_bytes: inbound_key_bytes, + my_address: None, + }); + + self.new_connection_notifier + .send(peer_connection) + .await + .map_err(|_| TransportError::ChannelClosed)?; + + sent_tracker.lock().report_sent_packet( + SymmetricMessage::FIRST_PACKET_ID, + outbound_ack_packet.prepared_send(), + ); + + Ok(()) + } + + // TODO: this value should be set given exponential backoff and max timeout + #[cfg(not(test))] + const NAT_TRAVERSAL_MAX_ATTEMPTS: usize = 20; + + #[cfg(test)] + const NAT_TRAVERSAL_MAX_ATTEMPTS: usize = 10; + + fn traverse_nat( + &mut self, + remote_addr: 
SocketAddr, + remote_public_key: TransportPublicKey, + remote_is_gateway: bool, + ) -> ( + impl Future> + + Send + + 'static, + mpsc::Sender>, + ) { + // Constants for exponential backoff + const INITIAL_TIMEOUT: Duration = Duration::from_secs(5); + const TIMEOUT_MULTIPLIER: f64 = 1.1; + #[cfg(not(test))] + const MAX_TIMEOUT: Duration = Duration::from_secs(60); // Maximum timeout limit + #[cfg(test)] + const MAX_TIMEOUT: Duration = Duration::from_secs(10); // Maximum timeout limit + + #[allow(clippy::large_enum_variant)] + enum ConnectionState { + /// Initial state of the joiner + StartOutbound, + /// Initial state of the joinee, at this point NAT has been already traversed + RemoteInbound { + /// Encrypted intro packet for comparison + intro_packet: PacketData, + }, + } + + fn decrypt_asym( + remote_addr: SocketAddr, + packet: &PacketData, + transport_secret_key: &TransportSecretKey, + outbound_sym_key: &mut Option, + state: &mut ConnectionState, + ) -> Result<(), ()> { + // probably the first packet to punch through the NAT + if let Ok(decrypted_intro_packet) = packet.try_decrypt_asym(transport_secret_key) { + tracing::debug!(%remote_addr, "received intro packet"); + let protoc = &decrypted_intro_packet.data()[..PROTOC_VERSION.len()]; + if protoc != PROTOC_VERSION { + todo!("return error"); + } + let outbound_key_bytes = + &decrypted_intro_packet.data()[PROTOC_VERSION.len()..PROTOC_VERSION.len() + 16]; + let outbound_key = + Aes128Gcm::new_from_slice(outbound_key_bytes).expect("correct length"); + *outbound_sym_key = Some(outbound_key.clone()); + *state = ConnectionState::RemoteInbound { + intro_packet: packet.assert_assymetric(), + }; + return Ok(()); + } + tracing::debug!(%remote_addr, "failed to decrypt packet"); + Err(()) + } + + let outbound_packets = self.outbound_packets.clone(); + let transport_secret_key = self.this_peer_keypair.secret.clone(); + let (inbound_from_remote, mut next_inbound) = + mpsc::channel::>(1); + let f = async move { + let mut state = ConnectionState::StartOutbound {}; + // Initialize timeout and interval + let mut timeout = INITIAL_TIMEOUT; + let mut interval_duration = INITIAL_INTERVAL; + let mut tick = tokio::time::interval(interval_duration); + + let mut failures = 0; + + let inbound_sym_key_bytes = rand::random::<[u8; 16]>(); + let inbound_sym_key = Aes128Gcm::new(&inbound_sym_key_bytes.into()); + + let mut outbound_sym_key: Option = None; + let outbound_intro_packet = { + let mut data = [0u8; { 16 + PROTOC_VERSION.len() }]; + data[..PROTOC_VERSION.len()].copy_from_slice(&PROTOC_VERSION); + data[PROTOC_VERSION.len()..].copy_from_slice(&inbound_sym_key_bytes); + PacketData::<_, MAX_PACKET_SIZE>::encrypt_with_pubkey(&data, &remote_public_key) + }; + + let mut sent_tracker = SentPacketTracker::new(); + + while failures < Self::NAT_TRAVERSAL_MAX_ATTEMPTS { + match state { + ConnectionState::StartOutbound { .. } => { + tracing::debug!(%remote_addr, "sending protocol version and inbound key"); + outbound_packets + .send((remote_addr, outbound_intro_packet.data().into())) + .await + .map_err(|_| TransportError::ChannelClosed)?; + } + ConnectionState::RemoteInbound { .. 
} => { + tracing::debug!(%remote_addr, "sending back protocol version and inbound key to remote"); + let our_inbound = SymmetricMessage::ack_ok( + outbound_sym_key.as_ref().expect("should be set"), + inbound_sym_key_bytes, + remote_addr, + )?; + outbound_packets + .send((remote_addr, our_inbound.data().into())) + .await + .map_err(|_| TransportError::ChannelClosed)?; + sent_tracker.report_sent_packet( + SymmetricMessage::FIRST_PACKET_ID, + our_inbound.data().into(), + ); + } + } + let next_inbound = tokio::time::timeout(timeout, next_inbound.recv()); + match next_inbound.await { + Ok(Some(packet)) => { + match state { + ConnectionState::StartOutbound { .. } => { + // at this point it's either the remote sending us an intro packet or a symmetric packet + // cause is the first packet that passes through the NAT + if let Ok(decrypted_packet) = + packet.try_decrypt_sym(&inbound_sym_key) + { + // the remote got our inbound key, so we know that they are at least at the RemoteInbound state + let symmetric_message = + SymmetricMessage::deser(decrypted_packet.data())?; + + #[cfg(test)] + { + tracing::debug!(%remote_addr, ?symmetric_message.payload, "received symmetric packet"); + } + + match symmetric_message.payload { + SymmetricMessagePayload::AckConnection { + result: + Ok(OutboundConnection { + key, + remote_addr: my_address, + }), + } => { + let outbound_sym_key = Aes128Gcm::new_from_slice(&key) + .map_err(|_| { + TransportError::ConnectionEstablishmentFailure { + cause: "invalid symmetric key".into(), + } + })?; + outbound_packets + .send(( + remote_addr, + SymmetricMessage::ack_ok( + &outbound_sym_key, + inbound_sym_key_bytes, + remote_addr, + )? + .data() + .into(), + )) + .await + .map_err(|_| TransportError::ChannelClosed)?; + let (inbound_sender, inbound_recv) = mpsc::channel(100); + return Ok(( + RemoteConnection { + outbound_packets: outbound_packets.clone(), + outbound_symmetric_key: outbound_sym_key, + remote_addr, + sent_tracker: Arc::new( + parking_lot::Mutex::new(sent_tracker), + ), + last_packet_id: Arc::new(AtomicU32::new(0)), + inbound_packet_recv: inbound_recv, + inbound_symmetric_key: inbound_sym_key, + inbound_symmetric_key_bytes: + inbound_sym_key_bytes, + my_address: Some(my_address), + }, + InboundRemoteConnection { + inbound_packet_sender: inbound_sender, + inbound_intro_packet: None, + inbound_checked_times: 0, + }, + )); + } + SymmetricMessagePayload::AckConnection { + result: Err(err), + } => { + return Err( + TransportError::ConnectionEstablishmentFailure { + cause: err, + }, + ); + } + _ => { + tracing::debug!(%remote_addr, "unexpected packet from remote"); + failures += 1; + continue; + } + } + } + + // probably the first packet to punch through the NAT + if decrypt_asym( + remote_addr, + &packet, + &transport_secret_key, + &mut outbound_sym_key, + &mut state, + ) + .is_ok() + { + continue; + } + + failures += 1; + tracing::debug!("Failed to decrypt packet"); + continue; + } + ConnectionState::RemoteInbound { + // this is the packet encrypted with out RSA pub key + ref intro_packet, + .. 
+ } => { + // next packet should be an acknowledgement packet, but might also be a repeated + // intro packet so we need to handle that + if packet.is_intro_packet(intro_packet) { + // we add to the number of failures so we are not stuck in a loop retrying + failures += 1; + continue; + } + // if is not an intro packet, the connection is successful and we can proceed + let (inbound_sender, inbound_recv) = mpsc::channel(1); + return Ok(( + RemoteConnection { + outbound_packets: outbound_packets.clone(), + outbound_symmetric_key: outbound_sym_key + .expect("should be set at this stage"), + remote_addr, + sent_tracker: Arc::new(parking_lot::Mutex::new( + SentPacketTracker::new(), + )), + last_packet_id: Arc::new(AtomicU32::new(0)), + inbound_packet_recv: inbound_recv, + inbound_symmetric_key: inbound_sym_key, + inbound_symmetric_key_bytes: inbound_sym_key_bytes, + my_address: None, + }, + InboundRemoteConnection { + inbound_packet_sender: inbound_sender, + inbound_intro_packet: Some(intro_packet.clone()), + inbound_checked_times: 0, + }, + )); + } + } + } + Ok(None) => { + tracing::debug!("debug: connection closed"); + return Err(TransportError::ConnectionClosed); + } + Err(_) => { + failures += 1; + tracing::debug!("Failed to receive UDP response, time out"); + } + } + + // We have retried for a while, so return an error + if timeout >= MAX_TIMEOUT { + break; + } + + // Update timeout using exponential backoff, capped at MAX_TIMEOUT + timeout = std::cmp::min( + Duration::from_secs((timeout.as_secs() as f64 * TIMEOUT_MULTIPLIER) as u64), + MAX_TIMEOUT, + ); + + // Update interval, capped at MAX_INTERVAL + if interval_duration < MAX_INTERVAL { + interval_duration = std::cmp::min( + Duration::from_millis( + interval_duration.as_millis() as u64 * INTERVAL_INCREASE_FACTOR, + ), + MAX_INTERVAL, + ); + tick = tokio::time::interval(interval_duration); + } + + tick.tick().await; + } + + Err(TransportError::ConnectionEstablishmentFailure { + cause: "max connection attempts reached".into(), + }) + }; + (f, inbound_from_remote) + } +} + +enum ConnectionEvent { + ConnectionStart { + remote_public_key: TransportPublicKey, + remote_is_gateway: bool, + open_connection: oneshot::Sender>, + }, +} + +struct InboundRemoteConnection { + inbound_packet_sender: mpsc::Sender>, + inbound_intro_packet: Option>, + inbound_checked_times: usize, +} + +impl InboundRemoteConnection { + fn check_inbound_packet(&mut self, packet: &PacketData) -> bool { + let mut inbound = false; + if let Some(inbound_intro_packet) = self.inbound_intro_packet.as_ref() { + if packet.is_intro_packet(inbound_intro_packet) { + inbound = true; + } + } + if self.inbound_checked_times >= UdpPacketsListener::::NAT_TRAVERSAL_MAX_ATTEMPTS + { + // no point in checking more than the max attemps since they won't be sending + // the intro packet more than this amount of times + self.inbound_intro_packet = None; + } else { + self.inbound_checked_times += 1; + } + inbound + } +} + +#[cfg(test)] +mod test { + use std::{ + collections::HashMap, + net::Ipv4Addr, + ops::Range, + sync::{ + atomic::{AtomicU16, AtomicU64, AtomicUsize, Ordering}, + OnceLock, + }, + }; + + use futures::{stream::FuturesOrdered, TryStreamExt}; + use rand::{Rng, SeedableRng}; + use serde::{de::DeserializeOwned, Serialize}; + use tokio::sync::{Mutex, RwLock}; + use tracing::info; + + use super::*; + use crate::DynError; + + #[allow(clippy::type_complexity)] + static CHANNELS: OnceLock< + Arc)>>>>, + > = OnceLock::new(); + + #[derive(Default, Clone)] + enum PacketDropPolicy { + /// 
Receive all packets without dropping + #[default] + ReceiveAll, + /// Drop the packets randomly based on the factor + Factor(f64), + /// Drop packets fall in the given range + Range(Range), + } + + struct MockSocket { + inbound: Mutex)>>, + this: SocketAddr, + packet_drop_policy: PacketDropPolicy, + num_packets_sent: AtomicUsize, + rng: Mutex, + } + + impl MockSocket { + async fn test_config(packet_drop_policy: PacketDropPolicy, addr: SocketAddr) -> Self { + let channels = CHANNELS + .get_or_init(|| Arc::new(RwLock::new(HashMap::new()))) + .clone(); + let (outbound, inbound) = mpsc::unbounded_channel(); + channels.write().await.insert(addr, outbound); + static SEED: AtomicU64 = AtomicU64::new(0xfeedbeef); + MockSocket { + inbound: Mutex::new(inbound), + this: addr, + packet_drop_policy, + num_packets_sent: AtomicUsize::new(0), + rng: Mutex::new(rand::rngs::SmallRng::seed_from_u64( + SEED.fetch_add(1, std::sync::atomic::Ordering::SeqCst), + )), + } + } + } + + impl Socket for MockSocket { + async fn bind(_addr: SocketAddr) -> Result { + unimplemented!() + } + + async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + // tracing::trace!(this = %self.this, "waiting for packet"); + let Some((remote, packet)) = self.inbound.try_lock().unwrap().recv().await else { + tracing::error!(this = %self.this, "connection closed"); + return Err(std::io::ErrorKind::ConnectionAborted.into()); + }; + // tracing::trace!(?remote, this = %self.this, "receiving packet from remote"); + buf[..packet.len()].copy_from_slice(&packet[..]); + Ok((packet.len(), remote)) + } + + async fn send_to(&self, buf: &[u8], target: SocketAddr) -> std::io::Result { + let packet_idx = self.num_packets_sent.fetch_add(1, Ordering::Release); + match &self.packet_drop_policy { + PacketDropPolicy::ReceiveAll => {} + PacketDropPolicy::Factor(factor) => { + if *factor > self.rng.try_lock().unwrap().gen::() { + tracing::trace!(id=%packet_idx, %self.this, "drop packet"); + return Ok(buf.len()); + } + } + PacketDropPolicy::Range(r) => { + if r.contains(&packet_idx) { + tracing::trace!(id=%packet_idx, %self.this, "drop packet"); + return Ok(buf.len()); + } + } + } + + assert!(self.this != target, "cannot send to self"); + let channels = CHANNELS + .get_or_init(|| Arc::new(RwLock::new(HashMap::new()))) + .clone(); + let channels = channels.read().await; + let Some(sender) = channels.get(&target) else { + return Ok(0); + }; + // tracing::trace!(?target, ?self.this, "sending packet to remote"); + sender + .send((self.this, buf.to_vec())) + .map_err(|_| std::io::ErrorKind::ConnectionAborted)?; + // tracing::trace!(?target, ?self.this, "packet sent to remote"); + Ok(buf.len()) + } + } + + impl Drop for MockSocket { + fn drop(&mut self) { + let channels = CHANNELS + .get_or_init(|| Arc::new(RwLock::new(HashMap::new()))) + .clone(); + loop { + if let Ok(mut channels) = channels.try_write() { + channels.remove(&self.this); + break; + } + // unorthodox blocking here but shouldn't be a problem for testing + std::thread::sleep(Duration::from_millis(1)); + } + } + } + + async fn set_peer_connection( + packet_drop_policy: PacketDropPolicy, + ) -> Result<(TransportPublicKey, ConnectionHandler, SocketAddr), DynError> { + set_peer_connection_in(packet_drop_policy, false).await + } + + async fn set_gateway_connection( + packet_drop_policy: PacketDropPolicy, + ) -> Result<(TransportPublicKey, ConnectionHandler, SocketAddr), DynError> { + set_peer_connection_in(packet_drop_policy, true).await + } + + async fn set_peer_connection_in( 
+ packet_drop_policy: PacketDropPolicy, + gateway: bool, + ) -> Result<(TransportPublicKey, ConnectionHandler, SocketAddr), DynError> { + static PORT: AtomicU16 = AtomicU16::new(25000); + + let peer_keypair = TransportKeypair::new(); + let peer_pub = peer_keypair.public.clone(); + let port = PORT.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let socket = Arc::new( + MockSocket::test_config(packet_drop_policy, (Ipv4Addr::LOCALHOST, port).into()).await, + ); + let peer_conn = ConnectionHandler::test_set_up( + (Ipv4Addr::LOCALHOST, port).into(), + socket, + peer_keypair, + gateway, + ) + .expect("failed to create peer"); + Ok((peer_pub, peer_conn, (Ipv4Addr::LOCALHOST, port).into())) + } + + trait TestFixture: Clone + Send + Sync + 'static { + type Message: DeserializeOwned + Serialize + Send + 'static; + fn expected_iterations(&self) -> usize; + fn gen_msg(&mut self) -> Self::Message; + fn assert_message_ok(&self, peer_idx: usize, msg: Self::Message) -> bool; + } + + struct TestConfig { + packet_drop_policy: PacketDropPolicy, + peers: usize, + wait_time: Duration, + } + + impl Default for TestConfig { + fn default() -> Self { + Self { + packet_drop_policy: PacketDropPolicy::ReceiveAll, + peers: 2, + wait_time: Duration::from_secs(2), + } + } + } + + async fn run_test( + config: TestConfig, + generators: Vec, + ) -> Result<(), DynError> { + assert_eq!(generators.len(), config.peers); + let mut peer_keys_and_addr = vec![]; + let mut peer_conns = vec![]; + for _ in 0..config.peers { + let (peer_pub, peer, peer_addr) = + set_peer_connection(config.packet_drop_policy.clone()).await?; + peer_keys_and_addr.push((peer_pub, peer_addr)); + peer_conns.push(peer); + } + + let mut tasks = vec![]; + let barrier = Arc::new(tokio::sync::Barrier::new(config.peers)); + for (i, (mut peer, test_generator)) in peer_conns.into_iter().zip(generators).enumerate() { + let mut peer_keys_and_addr = peer_keys_and_addr.clone(); + peer_keys_and_addr.remove(i); + let barrier_cp = barrier.clone(); + let peer = tokio::spawn(async move { + let mut conns = FuturesOrdered::new(); + let mut establish_conns = Vec::new(); + barrier_cp.wait().await; + for (peer_pub, peer_addr) in &peer_keys_and_addr { + let peer_conn = tokio::time::timeout( + Duration::from_secs(10), + peer.connect(peer_pub.clone(), *peer_addr, false).await, + ); + establish_conns.push(peer_conn); + } + let connections = futures::future::try_join_all(establish_conns) + .await? + .into_iter() + .collect::, _>>()?; + // additional wait time so we can clear up any additional messages that may need to be sent + let extra_wait = if config.wait_time.as_secs() > 10 { + Duration::from_secs(3) + } else { + Duration::from_secs(1) + }; + for ((_, peer_addr), mut peer_conn) in + peer_keys_and_addr.into_iter().zip(connections) + { + let mut test_gen_cp = test_generator.clone(); + conns.push_back(async move { + let mut messages = vec![]; + let mut to = tokio::time::interval(config.wait_time); + to.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + to.tick().await; + let start = std::time::Instant::now(); + while messages.len() < test_gen_cp.expected_iterations() { + peer_conn.send(test_gen_cp.gen_msg()).await?; + let msg = tokio::select! 
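+ // Race the per-iteration deadline against the next inbound message;
+ // if `wait_time` elapses first, this iteration fails with a timeout
+ // error instead of hanging the whole test.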
{ + _ = to.tick() => { + return Err::<_, DynError>( + format!("timeout waiting for messages, total time: {:.2}", start.elapsed().as_secs_f64()).into() + ); + } + msg = peer_conn.recv() => { + msg + } + }; + match msg { + Ok(msg) => { + let output_as_str: G::Message = bincode::deserialize(&msg)?; + messages.push(output_as_str); + info!("{peer_addr:?} received {} messages", messages.len()); + } + Err(error) => { + tracing::error!(%error, "error receiving message"); + return Err(error.into()); + } + } + } + + tracing::info!(%peer_addr, "finished"); + let _ = tokio::time::timeout(extra_wait, peer_conn.recv()).await; + Ok(messages) + }); + } + let results = conns.try_collect::>().await?; + Ok::<_, DynError>((results, test_generator)) + }); + tasks.push(peer); + } + + let all_results = futures::future::try_join_all(tasks) + .await? + .into_iter() + .collect::, _>>()?; + for (peer_results, test_gen) in all_results { + for (idx, result) in peer_results.into_iter().enumerate() { + assert_eq!(result.len(), test_gen.expected_iterations()); + for msg in result { + assert!(test_gen.assert_message_ok(idx, msg)); + } + } + } + + Ok(()) + } + + #[tokio::test] + async fn simulate_nat_traversal() -> Result<(), DynError> { + // crate::config::set_logger(); + let (peer_a_pub, mut peer_a, peer_a_addr) = set_peer_connection(Default::default()).await?; + let (peer_b_pub, mut peer_b, peer_b_addr) = set_peer_connection(Default::default()).await?; + + let peer_b = tokio::spawn(async move { + let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??; + Ok::<_, DynError>(()) + }); + + let peer_a = tokio::spawn(async move { + let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??; + Ok::<_, DynError>(()) + }); + + let (a, b) = tokio::try_join!(peer_a, peer_b)?; + a?; + b?; + Ok(()) + } + + #[tokio::test] + async fn simulate_nat_traversal_drop_first_packets_for_all() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE)); + let (peer_a_pub, mut peer_a, peer_a_addr) = + set_peer_connection(PacketDropPolicy::Range(0..1)).await?; + let (peer_b_pub, mut peer_b, peer_b_addr) = + set_peer_connection(PacketDropPolicy::Range(0..1)).await?; + + let peer_b = tokio::spawn(async move { + let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??; + Ok::<_, DynError>(()) + }); + + let peer_a = tokio::spawn(async move { + let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??; + Ok::<_, DynError>(()) + }); + + let (a, b) = tokio::try_join!(peer_a, peer_b)?; + a?; + b?; + Ok(()) + } + + #[tokio::test] + async fn simulate_nat_traversal_drop_first_packets_of_peerb() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE)); + let (peer_a_pub, mut peer_a, peer_a_addr) = set_peer_connection(Default::default()).await?; + let (peer_b_pub, mut peer_b, peer_b_addr) = + set_peer_connection(PacketDropPolicy::Range(0..1)).await?; + + let peer_b = tokio::spawn(async move { + let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr, false).await; + let mut conn = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??; + let _ = 
tokio::time::timeout(Duration::from_secs(3), conn.recv()).await; + Ok::<_, DynError>(()) + }); + + let peer_a = tokio::spawn(async move { + let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr, false).await; + let mut conn = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??; + let _ = tokio::time::timeout(Duration::from_secs(3), conn.recv()).await; + Ok::<_, DynError>(()) + }); + + let (a, b) = tokio::try_join!(peer_a, peer_b)?; + a?; + b?; + Ok(()) + } + + #[tokio::test] + async fn simulate_gateway_connection() -> Result<(), DynError> { + let (peer_a_pub, mut peer_a, peer_a_addr) = set_peer_connection(Default::default()).await?; + let (peer_b_pub, mut peer_b, peer_b_addr) = + set_gateway_connection(Default::default()).await?; + + let peer_b = tokio::spawn(async move { + let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??; + Ok::<_, DynError>(()) + }); + + let peer_a = tokio::spawn(async move { + let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??; + Ok::<_, DynError>(()) + }); + + let (a, b) = tokio::try_join!(peer_a, peer_b)?; + a?; + b?; + Ok(()) + } + + #[tokio::test] + async fn simulate_gateway_connection_drop_first_packets_of_gateway() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE)); + let (peer_a_pub, mut peer_a, peer_a_addr) = set_peer_connection(Default::default()).await?; + let (peer_b_pub, mut peer_b, peer_b_addr) = + set_gateway_connection(PacketDropPolicy::Range(0..1)).await?; + + let peer_b = tokio::spawn(async move { + let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??; + Ok::<_, DynError>(()) + }); + + let peer_a = tokio::spawn(async move { + let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??; + Ok::<_, DynError>(()) + }); + + let (a, b) = tokio::try_join!(peer_a, peer_b)?; + a?; + b?; + Ok(()) + } + + #[tokio::test] + async fn simulate_gateway_connection_drop_first_packets_for_all() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE)); + let (peer_a_pub, mut peer_a, peer_a_addr) = + set_peer_connection(PacketDropPolicy::Range(0..1)).await?; + let (peer_b_pub, mut peer_b, peer_b_addr) = + set_gateway_connection(PacketDropPolicy::Range(0..1)).await?; + + let peer_b = tokio::spawn(async move { + let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??; + Ok::<_, DynError>(()) + }); + + let peer_a = tokio::spawn(async move { + let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??; + Ok::<_, DynError>(()) + }); + + let (a, b) = tokio::try_join!(peer_a, peer_b)?; + a?; + b?; + Ok(()) + } + + #[tokio::test] + async fn simulate_gateway_connection_drop_first_packets_of_peer() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE)); + let (peer_a_pub, mut peer_a, peer_a_addr) = + set_peer_connection(PacketDropPolicy::Range(0..1)).await?; + let (peer_b_pub, mut peer_b, peer_b_addr) = + 
set_gateway_connection(Default::default()).await?; + + let peer_b = tokio::spawn(async move { + let peer_a_conn = peer_b.connect(peer_a_pub, peer_a_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_a_conn).await??; + Ok::<_, DynError>(()) + }); + + let peer_a = tokio::spawn(async move { + let peer_b_conn = peer_a.connect(peer_b_pub, peer_b_addr, false).await; + let _ = tokio::time::timeout(Duration::from_secs(500), peer_b_conn).await??; + Ok::<_, DynError>(()) + }); + + let (a, b) = tokio::try_join!(peer_a, peer_b)?; + a?; + b?; + Ok(()) + } + + #[tokio::test] + async fn simulate_send_short_message() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE)); + #[derive(Clone, Copy)] + struct TestData(&'static str, usize); + + impl TestFixture for TestData { + type Message = String; + fn expected_iterations(&self) -> usize { + 10 + } + + fn gen_msg(&mut self) -> Self::Message { + self.0.to_string() + } + + fn assert_message_ok(&self, _peer_idx: usize, msg: Self::Message) -> bool { + msg == "foo" + } + } + + run_test( + TestConfig { + peers: 10, + ..Default::default() + }, + Vec::from_iter((0..10).map(|i| TestData("foo", i))), + ) + .await + } + + #[tokio::test] + async fn simulate_send_streamed_message() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE)); + #[derive(Clone, Copy)] + struct TestData(&'static str); + + impl TestFixture for TestData { + type Message = String; + fn expected_iterations(&self) -> usize { + 10 + } + + fn gen_msg(&mut self) -> Self::Message { + self.0.repeat(3000) + } + + fn assert_message_ok(&self, _: usize, msg: Self::Message) -> bool { + if self.0 == "foo" { + msg.contains("bar") && msg.len() == "bar".len() * 3000 + } else { + msg.contains("foo") && msg.len() == "foo".len() * 3000 + } + } + } + + run_test( + TestConfig::default(), + vec![TestData("foo"), TestData("bar")], + ) + .await + } + + // #[ignore] + #[tokio::test(flavor = "multi_thread", worker_threads = 5)] + // #[tokio::test] + async fn simulate_packet_dropping() -> Result<(), DynError> { + // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::INFO)); + #[derive(Clone, Copy)] + struct TestData(&'static str); + + impl TestFixture for TestData { + type Message = String; + fn expected_iterations(&self) -> usize { + 10 + } + + fn gen_msg(&mut self) -> Self::Message { + self.0.repeat(1000) + } + + fn assert_message_ok(&self, _: usize, msg: Self::Message) -> bool { + if self.0 == "foo" { + msg.contains("bar") && msg.len() == "bar".len() * 1000 + } else { + msg.contains("foo") && msg.len() == "foo".len() * 1000 + } + } + } + + let mut tests = FuturesOrdered::new(); + let mut rng = rand::rngs::StdRng::seed_from_u64(3); + for factor in std::iter::repeat(()) + .map(|_| rng.gen::()) + .filter(|x| *x > 0.05 && *x < 0.25) + .take(5) + { + let wait_time = Duration::from_secs((((factor * 5.0 + 1.0) * 15.0) + 10.0) as u64); + tracing::info!( + "packet loss factor: {factor} (wait time: {wait_time})", + wait_time = wait_time.as_secs() + ); + tests.push_back(tokio::spawn(run_test( + TestConfig { + packet_drop_policy: PacketDropPolicy::Factor(factor), + wait_time, + ..Default::default() + }, + vec![TestData("foo"), TestData("bar")], + ))); + } + let mut test_no = 0; + while let Some(result) = tests.next().await { + result?.map_err(|e| { + tracing::error!(%test_no, "error in test"); + e + })?; + test_no += 1; + } + Ok(()) + } +} diff --git 
a/crates/core/src/transport/crypto.rs b/crates/core/src/transport/crypto.rs new file mode 100644 index 000000000..0b5c73f61 --- /dev/null +++ b/crates/core/src/transport/crypto.rs @@ -0,0 +1,62 @@ +use rand::rngs::OsRng; +use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey}; +use serde::{Deserialize, Serialize}; + +pub(super) struct TransportKeypair { + pub public: TransportPublicKey, + pub secret: TransportSecretKey, +} + +impl TransportKeypair { + pub fn new() -> Self { + let mut rng = OsRng; + // Key size, can be adjusted + const BITS: usize = 2048; + let priv_key = RsaPrivateKey::new(&mut rng, BITS).expect("failed to generate a key"); + let pub_key = RsaPublicKey::from(&priv_key); + + TransportKeypair { + public: TransportPublicKey(pub_key), + secret: TransportSecretKey(priv_key), + } + } +} + +#[derive(Serialize, Deserialize, Clone)] +pub(super) struct TransportPublicKey(RsaPublicKey); + +impl TransportPublicKey { + pub fn encrypt(&self, data: &[u8]) -> Vec { + let mut rng = OsRng; + let padding = Pkcs1v15Encrypt; + self.0 + .encrypt(&mut rng, padding, data) + .expect("failed to encrypt") + } +} + +#[derive(Clone)] +pub(super) struct TransportSecretKey(RsaPrivateKey); + +impl TransportSecretKey { + pub fn decrypt(&self, data: &[u8]) -> rsa::Result> { + let padding = Pkcs1v15Encrypt; + self.0.decrypt(padding, data) + } +} + +#[cfg(test)] +#[test] +fn key_sizes_and_decryption() { + let pair = TransportKeypair::new(); + let sym_key_bytes = rand::random::<[u8; 16]>(); + // use aes_gcm::KeyInit; + // let _sym_key = aes_gcm::aes::Aes128::new(&sym_key_bytes.into()); + let encrypted: Vec = pair.public.encrypt(&sym_key_bytes); + assert!( + encrypted.len() <= super::packet_data::MAX_PACKET_SIZE, + "packet size is too big" + ); + let bytes = pair.secret.decrypt(&encrypted).unwrap(); + assert_eq!(bytes, sym_key_bytes.as_slice()); +} diff --git a/crates/core/src/transport/packet_data.rs b/crates/core/src/transport/packet_data.rs new file mode 100644 index 000000000..195467575 --- /dev/null +++ b/crates/core/src/transport/packet_data.rs @@ -0,0 +1,349 @@ +use std::marker::PhantomData; +use std::{cell::RefCell, sync::Arc}; + +use aes_gcm::{ + aead::{generic_array::GenericArray, rand_core::SeedableRng, AeadInPlace}, + Aes128Gcm, +}; +use rand::{prelude::SmallRng, thread_rng, Rng}; + +use crate::transport::crypto::TransportPublicKey; + +use super::crypto::TransportSecretKey; +use super::TransportError; + +/// The maximum size of a received UDP packet, MTU typically is 1500 +pub(in crate::transport) const MAX_PACKET_SIZE: usize = 1500 - UDP_HEADER_SIZE; + +// These are the same as the AES-GCM 128 constants, but extracting them from Aes128Gcm +// as consts was awkward. +const NONCE_SIZE: usize = 12; +const TAG_SIZE: usize = 16; + +pub(super) const MAX_DATA_SIZE: usize = MAX_PACKET_SIZE - NONCE_SIZE - TAG_SIZE; +const UDP_HEADER_SIZE: usize = 8; + +thread_local! { + // This must be very fast, but doesn't need to be cryptographically secure. + static RNG: RefCell = RefCell::new( + SmallRng::from_rng(thread_rng()).expect("failed to create RNG") + ); +} + +struct AssertSize; + +impl AssertSize { + const OK: () = assert!(N <= MAX_PACKET_SIZE); +} + +// trying to bypass limitations with const generic checks on where clauses +const fn _check_valid_size() { + #[allow(clippy::let_unit_value)] + let () = AssertSize::::OK; +} + +#[derive(Clone)] +pub(super) struct PacketData { + data: [u8; N], + pub size: usize, + data_type: PhantomData
, +} + +pub(super) trait Encryption: Clone {} + +/// Decrypted packet +#[derive(Clone, Copy)] +pub(super) struct Plaintext; + +/// Packet is encrypted using symmetric crypto (most packets if not an intro packet) +#[derive(Clone, Copy)] +pub(super) struct SymmetricAES; + +/// Packet is encrypted using assympetric crypto (typically an intro packet) +#[derive(Clone, Copy)] +pub(super) struct AssymetricRSA; + +/// This is used when we don't know the encryption type of the packet, perhaps because we +/// haven't yet determined whether it is an intro packet. +#[derive(Clone, Copy)] +pub(super) struct UnknownEncryption; + +impl Encryption for Plaintext {} +impl Encryption for SymmetricAES {} +impl Encryption for AssymetricRSA {} +impl Encryption for UnknownEncryption {} + +pub(super) const fn packet_size() -> usize { + DATA_SIZE + NONCE_SIZE + TAG_SIZE +} + +fn internal_sym_decryption( + data: &[u8], + size: usize, + inbound_sym_key: &Aes128Gcm, +) -> Result<([u8; N], usize), aes_gcm::Error> { + debug_assert!(data.len() >= NONCE_SIZE + TAG_SIZE); + + let nonce = GenericArray::from_slice(&data[..NONCE_SIZE]); + // Adjusted to extract the tag from the end of the encrypted data + let tag = GenericArray::from_slice(&data[size - TAG_SIZE..size]); + let encrypted_data = &data[NONCE_SIZE..size - TAG_SIZE]; + let mut buffer = [0u8; N]; + let buffer_len = encrypted_data.len(); + buffer[..buffer_len].copy_from_slice(encrypted_data); + + inbound_sym_key.decrypt_in_place_detached(nonce, &[], &mut buffer[..buffer_len], tag)?; + Ok((buffer, buffer_len)) +} + +impl PacketData { + pub(super) fn data(&self) -> &[u8] { + &self.data[..self.size] + } +} + +impl PacketData { + #[cfg(test)] + pub(super) fn decrypt( + &self, + inbound_sym_key: &Aes128Gcm, + ) -> Result, aes_gcm::Error> { + let (buffer, buffer_len) = + internal_sym_decryption::(&self.data[..], self.size, inbound_sym_key)?; + + Ok(Self { + data: buffer, + size: buffer_len, + data_type: PhantomData, + }) + } + + pub fn prepared_send(self) -> Arc<[u8]> { + self.data[..self.size].into() + } +} + +impl PacketData { + pub(super) fn encrypt_with_pubkey(data: &[u8], remote_key: &TransportPublicKey) -> Self { + _check_valid_size::(); + let encrypted_data: Vec = remote_key.encrypt(data); + debug_assert!(encrypted_data.len() <= MAX_PACKET_SIZE); + let mut data = [0; N]; + data[..encrypted_data.len()].copy_from_slice(&encrypted_data[..]); + Self { + data, + size: encrypted_data.len(), + data_type: PhantomData, + } + } + + pub fn preparef_send(self) -> Arc<[u8]> { + self.data[..self.size].into() + } +} + +impl PacketData { + pub fn from_buf_plain(buf: impl AsRef<[u8]>) -> Self { + let mut data = [0; N]; + let buf = buf.as_ref(); + let size = buf.len(); + data[..size].copy_from_slice(buf); + Self { + size, + data, + data_type: PhantomData, + } + } + + pub(super) fn encrypt_symmetric(&self, cipher: &Aes128Gcm) -> PacketData { + _check_valid_size::(); + debug_assert!(self.size <= MAX_DATA_SIZE); + + let nonce: [u8; NONCE_SIZE] = RNG.with(|rng| rng.borrow_mut().gen()); + + let mut buffer = [0u8; N]; + buffer[..NONCE_SIZE].copy_from_slice(&nonce); + + // Encrypt the data in place + let payload_length = self.size; + buffer[NONCE_SIZE..NONCE_SIZE + payload_length].copy_from_slice(self.data()); + let tag = cipher + .encrypt_in_place_detached( + &nonce.into(), + &[], + &mut buffer[NONCE_SIZE..NONCE_SIZE + payload_length], + ) + .unwrap(); + + // Append the tag to the buffer + buffer[NONCE_SIZE + payload_length..NONCE_SIZE + payload_length + TAG_SIZE] + 
.copy_from_slice(tag.as_slice()); + + PacketData { + data: buffer, + size: NONCE_SIZE + payload_length + TAG_SIZE, + data_type: PhantomData, + } + } +} + +impl PacketData { + pub fn from_buf(buf: impl AsRef<[u8]>) -> Self { + let mut data = [0; N]; + let buf = buf.as_ref(); + let size = buf.len(); + data[..size].copy_from_slice(buf); + Self { + size, + data, + data_type: PhantomData, + } + } + + pub(super) fn is_intro_packet( + &self, + actual_intro_packet: &PacketData, + ) -> bool { + self.size == actual_intro_packet.size + && self.data[..self.size] == actual_intro_packet.data[..actual_intro_packet.size] + } + + pub(super) fn try_decrypt_sym( + &self, + inbound_sym_key: &Aes128Gcm, + ) -> Result, aes_gcm::Error> { + let (buffer, buffer_len) = + internal_sym_decryption::(&self.data[..], self.size, inbound_sym_key)?; + + Ok(PacketData { + data: buffer, + size: buffer_len, + data_type: PhantomData, + }) + } + + pub fn try_decrypt_asym( + &self, + key: &TransportSecretKey, + ) -> Result, TransportError> { + let r = key.decrypt(self.data()).map(|decrypted| { + let mut data = [0; N]; + data[..decrypted.len()].copy_from_slice(&decrypted[..]); + PacketData { + size: data.len(), + data, + data_type: PhantomData, + } + })?; + Ok(r) + } + + pub fn assert_assymetric(&self) -> PacketData { + PacketData { + data: self.data, + size: self.size, + data_type: PhantomData, + } + } +} + +impl Eq for PacketData {} + +impl PartialEq for PacketData { + fn eq(&self, other: &Self) -> bool { + self.size == other.size && self.data[..self.size] == other.data[..other.size] + } +} + +#[cfg(test)] +mod tests { + use super::*; + use aes_gcm::aead::rand_core::RngCore; + use aes_gcm::{Aes128Gcm, KeyInit}; + use rand::rngs::OsRng; + + #[test] + fn test_encryption_decryption() { + // Generate a random 128-bit (16 bytes) key + let mut key = [0u8; 16]; + OsRng.fill_bytes(&mut key); + + // Create a key object for AES-GCM + let key = GenericArray::from_slice(&key); + + // Create a new AES-128-GCM instance + let cipher = Aes128Gcm::new(key); + let data = b"Hello, world!"; + let unencrypted_packet = PacketData::<_, 1000>::from_buf_plain(data); + let encrypted_packet = unencrypted_packet.encrypt_symmetric(&cipher); + + let _overlap = longest_common_subsequence(&encrypted_packet.data, data.as_slice()); + + test_decryption(encrypted_packet, &cipher, unencrypted_packet); + } + + // Test detection of packet corruption + #[test] + fn test_encryption_decryption_corrupted() { + // Generate a random 128-bit (16 bytes) key + let mut key = [0u8; 16]; + OsRng.fill_bytes(&mut key); + + // Create a key object for AES-GCM + let key = GenericArray::from_slice(&key); + + // Create a new AES-128-GCM instance + let cipher = Aes128Gcm::new(key); + let data = b"Hello, world!"; + let unencrypted_packet = PacketData::<_, 1000>::from_buf_plain(data); + let mut encrypted_packet = unencrypted_packet.encrypt_symmetric(&cipher); + + // Corrupt the packet data + encrypted_packet.data[encrypted_packet.size / 2] = 0; + + // Ensure decryption fails + match encrypted_packet.decrypt(&cipher) { + Ok(_) => panic!("Decryption succeeded when it should have failed"), + Err(e) => assert_eq!(e, aes_gcm::Error), + } + } + + fn test_decryption( + packet_data: PacketData, + cipher: &Aes128Gcm, + original_data: PacketData, + ) { + match packet_data.decrypt(cipher) { + Ok(decrypted_data) => { + // Ensure decrypted data matches original + assert_eq!(&decrypted_data.data(), &original_data.data()); + } + Err(e) => panic!("Decryption failed with error: {:?}", e), + } + } + + fn 
longest_common_subsequence(a: &[u8], b: &[u8]) -> usize { + let m = a.len(); + let n = b.len(); + + // Initialize a 2D vector with zeros. The dimensions are (m+1) x (n+1). + let mut dp = vec![vec![0; n + 1]; m + 1]; + + // Iterate over each character in both sequences + for (i, _) in a.iter().enumerate() { + for (j, _) in b.iter().enumerate() { + if a[i] == b[j] { + // If characters match, increment the count from the previous subsequence + dp[i + 1][j + 1] = dp[i][j] + 1; + } else { + // Otherwise, the current state is the max of either omitting the current character + // from sequence 'a' or 'b'. + dp[i + 1][j + 1] = std::cmp::max(dp[i + 1][j], dp[i][j + 1]); + } + } + } + + // The value in the bottom-right cell of the matrix is the length of the LCS + dp[m][n] + } +} diff --git a/crates/core/src/transport/peer_connection.rs b/crates/core/src/transport/peer_connection.rs new file mode 100644 index 000000000..3a3e9b99b --- /dev/null +++ b/crates/core/src/transport/peer_connection.rs @@ -0,0 +1,463 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::atomic::AtomicU32; +use std::sync::Arc; +use std::time::Duration; +use std::vec::Vec; + +use crate::transport::packet_data::UnknownEncryption; +use aes_gcm::Aes128Gcm; +use futures::stream::FuturesUnordered; +use futures::{Future, StreamExt}; +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; + +mod inbound_stream; +mod outbound_stream; + +use super::{ + connection_handler::SerializedMessage, + packet_data::{self, PacketData}, + received_packet_tracker::ReceivedPacketTracker, + received_packet_tracker::ReportResult, + sent_packet_tracker::{ResendAction, SentPacketTracker}, + symmetric_message::{self, SymmetricMessage, SymmetricMessagePayload}, + TransportError, +}; +use crate::util::time_source::InstantTimeSrc; + +type Result = std::result::Result; + +// TODO: measure the space overhead of SymmetricMessage::ShortMessage since is likely less than 100 +/// The max payload we can send in a single fragment, this MUST be less than packet_data::MAX_DATA_SIZE +/// since we need to account for the space overhead of SymmetricMessage::LongMessage metadata +const MAX_DATA_SIZE: usize = packet_data::MAX_DATA_SIZE - 100; + +#[must_use] +pub(super) struct RemoteConnection { + pub outbound_packets: mpsc::Sender<(SocketAddr, Arc<[u8]>)>, + pub outbound_symmetric_key: Aes128Gcm, + pub remote_addr: SocketAddr, + pub sent_tracker: Arc>>, + pub last_packet_id: Arc, + pub inbound_packet_recv: mpsc::Receiver>, + pub inbound_symmetric_key: Aes128Gcm, + pub inbound_symmetric_key_bytes: [u8; 16], + pub my_address: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[repr(transparent)] +#[serde(transparent)] +pub(crate) struct StreamId(u32); + +impl StreamId { + pub fn next() -> Self { + static NEXT_ID: AtomicU32 = AtomicU32::new(0); + Self(NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Release)) + } +} + +impl std::fmt::Display for StreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +type InboundStreamResult = Result<(StreamId, SerializedMessage), StreamId>; +type InboundStreamFut = Pin + Send>>; + +/// The `PeerConnection` struct is responsible for managing the connection with a remote peer. +/// It provides methods for sending and receiving messages to and from the remote peer. 
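+///
+/// A rough usage sketch (the `handler.connect(..)` call and the surrounding error
+/// handling are assumptions based on this crate's transport tests, not part of this
+/// type itself):
+///
+/// ```rust,ignore
+/// // Establish the connection through the connection handler first; `connect`
+/// // returns a future that resolves once NAT traversal has succeeded.
+/// let connecting = handler
+///     .connect(remote_public_key, remote_addr, /* remote_is_gateway */ false)
+///     .await;
+/// let mut conn = connecting.await?;
+///
+/// // Any serializable value can be sent; payloads larger than a single packet
+/// // are transparently split into a fragment stream.
+/// conn.send("hello".to_string()).await?;
+///
+/// // `recv` yields the raw serialized bytes of the next complete message.
+/// let raw: Vec<u8> = conn.recv().await?;
+/// let reply: String = bincode::deserialize(&raw)?;
+/// ```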
+/// +/// The `PeerConnection` struct maintains the state of the connection, including the remote +/// connection details, trackers for received and sent packets, and futures for inbound and +/// outbound streams. +/// +/// The `send` method is used to send serialized data to the remote peer. If the data size +/// exceeds the maximum allowed size, it is sent as a stream; otherwise, it is sent as a +/// short message. +/// +/// The `recv` method is used to receive incoming packets from the remote peer. It listens for +/// incoming packets or receipts, and resends packets if necessary. +/// +/// The `process_inbound` method is used to process incoming payloads based on their type. +/// +/// The `noop`, `outbound_short_message`, and `outbound_stream` methods are used internally for +/// sending different types of messages. +/// +/// The `packet_sending` function is a helper function used to send packets to the remote peer. +#[must_use = "call await on the `recv` function to start listening for incoming messages"] +pub(crate) struct PeerConnection { + remote_conn: RemoteConnection, + received_tracker: ReceivedPacketTracker, + inbound_streams: HashMap)>>, + inbound_stream_futures: FuturesUnordered>, + outbound_stream_futures: FuturesUnordered>, +} + +impl PeerConnection { + pub fn new(remote_conn: RemoteConnection) -> Self { + Self { + remote_conn, + received_tracker: ReceivedPacketTracker::new(), + inbound_streams: HashMap::new(), + inbound_stream_futures: FuturesUnordered::new(), + outbound_stream_futures: FuturesUnordered::new(), + } + } + + pub async fn send(&mut self, data: T) -> Result + where + T: Serialize + Send + 'static, + { + let data = tokio::task::spawn_blocking(move || bincode::serialize(&data).unwrap()) + .await + .unwrap(); + if data.len() > MAX_DATA_SIZE { + self.outbound_stream(data).await; + } else { + self.outbound_short_message(data).await?; + } + Ok(()) + } + + pub async fn recv(&mut self) -> Result> { + // listen for incoming messages or receipts or wait until is time to do anything else again + let mut resend_check = Some(tokio::time::sleep(tokio::time::Duration::from_secs(1))); + + const KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(30); + const KILL_CONNECTION_AFTER: Duration = Duration::from_secs(60); + let mut keep_alive = tokio::time::interval(KEEP_ALIVE_INTERVAL); + keep_alive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + keep_alive.tick().await; + let mut last_received = std::time::Instant::now(); + + loop { + // tracing::trace!(remote = ?self.remote_conn.remote_addr, "waiting for inbound messages"); + tokio::select! 
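+ // The branches below race over: (1) the next inbound packet, decrypted and
+ // dispatched as a short message, stream fragment, receipts or keep-alive,
+ // (2) completion of an in-progress inbound stream, (3) completion of an
+ // in-progress outbound stream, (4) the keep-alive tick, which also kills
+ // connections that have been silent for too long, and (5) the resend check
+ // driven by the sent-packet tracker.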
{ + inbound = self.remote_conn.inbound_packet_recv.recv() => { + let packet_data = inbound.ok_or(TransportError::ConnectionClosed)?; + last_received = std::time::Instant::now(); + let Ok(decrypted) = packet_data.try_decrypt_sym(&self.remote_conn.inbound_symmetric_key).map_err(|error| { + tracing::debug!(%error, remote = ?self.remote_conn.remote_addr, "Failed to decrypt packet, might be an intro packet or a partial packet"); + }) else { + // just ignore this message + // TODO: maybe check how frequently this happens and decide to drop a connection based on that + // if it is partial packets being received too often + // TODO: this branch should at much happen UdpPacketsListener::NAT_TRAVERSAL_MAX_ATTEMPTS + // for intro packets will be sent than this amount, so we could be checking for that initially + tracing::trace!(remote = ?self.remote_conn.remote_addr, "ignoring packet"); + continue; + }; + let msg = SymmetricMessage::deser(decrypted.data()).unwrap(); + let SymmetricMessage { + packet_id, + confirm_receipt, + payload, + } = msg; + #[cfg(test)] + { + tracing::trace!( + remote = %self.remote_conn.remote_addr, %packet_id, %payload, ?confirm_receipt, + "received inbound packet" + ); + } + self.remote_conn + .sent_tracker + .lock() + .report_received_receipts(&confirm_receipt); + match self.received_tracker.report_received_packet(packet_id) { + ReportResult::Ok => {} + ReportResult::AlreadyReceived => { + tracing::trace!(%packet_id, "already received packet"); + continue; + } + ReportResult::QueueFull => { + let receipts = self.received_tracker.get_receipts(); + tracing::debug!(?receipts, "queue full, reporting receipts"); + self.noop(receipts).await?; + }, + } + if let Some(msg) = self.process_inbound(payload).await.map_err(|error| { + tracing::error!(%error, %packet_id, remote = %self.remote_conn.remote_addr, "error processing inbound packet"); + error + })? { + tracing::debug!(%packet_id, "returning full stream message"); + return Ok(msg); + } + } + inbound_stream = self.inbound_stream_futures.next(), if !self.inbound_stream_futures.is_empty() => { + let Some(res) = inbound_stream else { + tracing::error!("unexpected no-stream from ongoing_inbound_streams"); + continue + }; + let Ok((stream_id, msg)) = res.map_err(|e| TransportError::Other(e.into()))? else { + tracing::error!("unexpected error from ongoing_inbound_streams"); + // TODO: may leave orphan stream recvs hanging around in this case + continue; + }; + self.inbound_streams.remove(&stream_id); + tracing::trace!(%stream_id, "stream finished"); + return Ok(msg); + } + outbound_stream = self.outbound_stream_futures.next(), if !self.outbound_stream_futures.is_empty() => { + let Some(res) = outbound_stream else { + tracing::error!("unexpected no-stream from ongoing_outbound_streams"); + continue + }; + res.map_err(|e| TransportError::Other(e.into()))?? 
+ } + _ = keep_alive.tick() => { + if last_received.elapsed() > KILL_CONNECTION_AFTER { + tracing::warn!(remote = ?self.remote_conn.remote_addr, "connection timed out"); + return Err(TransportError::ConnectionClosed); + } + tracing::trace!(remote = ?self.remote_conn.remote_addr, "sending keep-alive"); + self.noop(vec![]).await?; + } + _ = resend_check.take().unwrap_or(tokio::time::sleep(Duration::from_secs(5))) => { + loop { + tracing::trace!(remote = ?self.remote_conn.remote_addr, "checking for resends"); + let maybe_resend = self.remote_conn + .sent_tracker + .lock() + .get_resend(); + match maybe_resend { + ResendAction::WaitUntil(wait_until) => { + resend_check = Some(tokio::time::sleep_until(wait_until.into())); + break; + } + ResendAction::Resend(idx, packet) => { + self.remote_conn + .outbound_packets + .send((self.remote_conn.remote_addr, packet.clone())) + .await + .map_err(|_| TransportError::ConnectionClosed)?; + self.remote_conn.sent_tracker.lock().report_sent_packet(idx, packet); + } + } + } + } + } + } + } + + /// Returns the external address of the peer holding this connection. + pub fn my_address(&self) -> Option { + self.remote_conn.my_address + } + + async fn process_inbound( + &mut self, + payload: SymmetricMessagePayload, + ) -> Result>> { + use SymmetricMessagePayload::*; + match payload { + ShortMessage { payload } => Ok(Some(payload)), + AckConnection { result: Err(cause) } => { + Err(TransportError::ConnectionEstablishmentFailure { cause }) + } + AckConnection { result: Ok(_) } => { + let packet = SymmetricMessage::ack_ok( + &self.remote_conn.outbound_symmetric_key, + self.remote_conn.inbound_symmetric_key_bytes, + self.remote_conn.remote_addr, + )?; + self.remote_conn + .outbound_packets + .send((self.remote_conn.remote_addr, packet.data().into())) + .await + .map_err(|_| TransportError::ConnectionClosed)?; + Ok(None) + } + StreamFragment { + stream_id, + total_length_bytes, + fragment_number, + payload, + } => { + if let Some(sender) = self.inbound_streams.get(&stream_id) { + sender + .send((fragment_number, payload)) + .await + .map_err(|_| TransportError::ConnectionClosed)?; + tracing::trace!(%stream_id, %fragment_number, "fragment pushed to existing stream"); + } else { + let (sender, receiver) = mpsc::channel(1); + tracing::trace!(%stream_id, %fragment_number, "new stream"); + self.inbound_streams.insert(stream_id, sender); + let mut stream = inbound_stream::InboundStream::new(total_length_bytes); + if let Some(msg) = stream.push_fragment(fragment_number, payload) { + self.inbound_streams.remove(&stream_id); + tracing::trace!(%stream_id, %fragment_number, "stream finished"); + return Ok(Some(msg)); + } + self.inbound_stream_futures + .push(tokio::spawn(inbound_stream::recv_stream( + stream_id, receiver, stream, + ))); + } + Ok(None) + } + NoOp => Ok(None), + } + } + + #[inline] + async fn noop(&mut self, receipts: Vec) -> Result<()> { + packet_sending( + self.remote_conn.remote_addr, + &self.remote_conn.outbound_packets, + self.remote_conn + .last_packet_id + .fetch_add(1, std::sync::atomic::Ordering::Release), + &self.remote_conn.outbound_symmetric_key, + receipts, + (), + &self.remote_conn.sent_tracker, + ) + .await + } + + #[inline] + async fn outbound_short_message(&mut self, data: SerializedMessage) -> Result<()> { + let receipts = self.received_tracker.get_receipts(); + let packet_id = self + .remote_conn + .last_packet_id + .fetch_add(1, std::sync::atomic::Ordering::Release); + packet_sending( + self.remote_conn.remote_addr, + 
&self.remote_conn.outbound_packets, + packet_id, + &self.remote_conn.outbound_symmetric_key, + receipts, + symmetric_message::ShortMessage(data), + &self.remote_conn.sent_tracker, + ) + .await?; + Ok(()) + } + + async fn outbound_stream(&mut self, data: SerializedMessage) { + let stream_id = StreamId::next(); + let task = tokio::spawn(outbound_stream::send_stream( + stream_id, + self.remote_conn.last_packet_id.clone(), + self.remote_conn.outbound_packets.clone(), + self.remote_conn.remote_addr, + data, + self.remote_conn.outbound_symmetric_key.clone(), + self.remote_conn.sent_tracker.clone(), + )); + self.outbound_stream_futures.push(task); + } +} + +async fn packet_sending( + remote_addr: SocketAddr, + outbound_packets: &mpsc::Sender<(SocketAddr, Arc<[u8]>)>, + packet_id: u32, + outbound_sym_key: &Aes128Gcm, + confirm_receipt: Vec, + payload: impl Into, + sent_tracker: &parking_lot::Mutex>, +) -> Result<()> { + // FIXME: here ensure that `confirm_receipt` won't make the packet exceed the max data size + // if it does, split it to send multiple noop packets with the receipts + + // tracing::trace!(packet_id, "sending packet"); + let packet = SymmetricMessage::serialize_msg_to_packet_data( + packet_id, + payload, + outbound_sym_key, + confirm_receipt, + )?; + outbound_packets + .send((remote_addr, packet.clone().prepared_send())) + .await + .map_err(|_| TransportError::ConnectionClosed)?; + sent_tracker + .lock() + .report_sent_packet(packet_id, packet.prepared_send()); + Ok(()) +} + +#[cfg(test)] +mod tests { + use aes_gcm::KeyInit; + use futures::TryFutureExt; + use std::net::{Ipv4Addr, SocketAddr}; + use tokio::sync::mpsc; + + use super::{ + inbound_stream::{recv_stream, InboundStream}, + outbound_stream::send_stream, + *, + }; + use crate::transport::packet_data::MAX_PACKET_SIZE; + + #[tokio::test] + async fn test_inbound_outbound_interaction() -> Result<(), Box> { + const MSG_LEN: usize = 1000; + let (sender, mut receiver) = mpsc::channel(1); + let remote_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080); + let message: Vec<_> = std::iter::repeat(0) + .take(MSG_LEN) + .map(|_| rand::random::()) + .collect(); + let key = rand::random::<[u8; 16]>(); + let cipher = Aes128Gcm::new(&key.into()); + let sent_tracker = Arc::new(parking_lot::Mutex::new(SentPacketTracker::new())); + + let stream_id = StreamId::next(); + // Send a long message using the outbound stream + let outbound = tokio::task::spawn(send_stream( + stream_id, + Arc::new(AtomicU32::new(0)), + sender, + remote_addr, + message.clone(), + cipher.clone(), + sent_tracker, + )) + .map_err(|e| e.into()); + + let inbound = async { + // need to take care of decrypting and deserializing the inbound data before collecting into the message + let (tx, rx) = mpsc::channel(1); + let stream = InboundStream::new(MSG_LEN as u64); + let inbound_msg = tokio::task::spawn(recv_stream(stream_id, rx, stream)); + while let Some((_, network_packet)) = receiver.recv().await { + let decrypted = PacketData::<_, MAX_PACKET_SIZE>::from_buf(&network_packet) + .try_decrypt_sym(&cipher) + .map_err(TransportError::PrivateKeyDecryptionError)?; + let SymmetricMessage { + payload: + SymmetricMessagePayload::StreamFragment { + fragment_number, + payload, + .. + }, + .. + } = SymmetricMessage::deser(decrypted.data()).expect("symmetric message") + else { + return Err("unexpected message".into()); + }; + tx.send((fragment_number, payload)).await?; + } + let (_, msg) = inbound_msg + .await? 
+ .map_err(|_| anyhow::anyhow!("stream failed"))?; + Ok::<_, Box>(msg) + }; + + let (out_res, inbound_msg) = tokio::try_join!(outbound, inbound)?; + out_res?; + assert_eq!(message, inbound_msg); + Ok(()) + } +} diff --git a/crates/core/src/transport/peer_connection/inbound_stream.rs b/crates/core/src/transport/peer_connection/inbound_stream.rs new file mode 100644 index 000000000..ddd8f1431 --- /dev/null +++ b/crates/core/src/transport/peer_connection/inbound_stream.rs @@ -0,0 +1,122 @@ +use tokio::sync::mpsc; + +use crate::transport::peer_connection::outbound_stream::SerializedStream; +use std::collections::BTreeMap; + +use super::StreamId; + +type FragmentIdx = u32; + +pub(super) async fn recv_stream( + stream_id: StreamId, + mut receiver: mpsc::Receiver<(FragmentIdx, Vec)>, + mut stream: InboundStream, +) -> Result<(StreamId, Vec), StreamId> { + while let Some((fragment_number, payload)) = receiver.recv().await { + if let Some(msg) = stream.push_fragment(fragment_number, payload) { + return Ok((stream_id, msg)); + } + } + Err(stream_id) +} + +pub(super) struct InboundStream { + total_length_bytes: u64, + /// Fragment numbers are 1-indexed + last_contiguous_fragment_idx: FragmentIdx, + non_contiguous_fragments: BTreeMap>, + payload: Vec, +} + +impl InboundStream { + pub fn new(total_length_bytes: u64) -> Self { + Self { + total_length_bytes, + last_contiguous_fragment_idx: 0, + non_contiguous_fragments: BTreeMap::new(), + payload: vec![], + } + } + + /// Returns some if the message has been completely streamed, none otherwise. + pub fn push_fragment( + &mut self, + fragment_number: FragmentIdx, + mut fragment: SerializedStream, + ) -> Option> { + // tracing::trace!( + // %fragment_number, + // last = %self.last_contiguous_fragment_idx, + // non_contig = ?self.non_contiguous_fragments.keys().collect::>(), + // "received stream fragment" + // ); + if fragment_number == self.last_contiguous_fragment_idx + 1 { + self.last_contiguous_fragment_idx = fragment_number; + self.payload.append(&mut fragment); + } else { + self.non_contiguous_fragments + .insert(fragment_number, fragment); + } + while let Some((idx, mut v)) = self.non_contiguous_fragments.pop_first() { + if idx == self.last_contiguous_fragment_idx + 1 { + self.last_contiguous_fragment_idx += 1; + self.payload.append(&mut v); + } else { + self.non_contiguous_fragments.insert(idx, v); + break; + } + } + self.get_and_clear() + } + + fn get_and_clear(&mut self) -> Option> { + if self.payload.len() as u64 == self.total_length_bytes { + Some(std::mem::take(&mut self.payload)) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::InboundStream; + + #[test] + fn test_simple_sequence() { + let mut stream = InboundStream::new(6); + assert_eq!(stream.push_fragment(1, vec![1, 2, 3]), None); + assert_eq!( + stream.push_fragment(2, vec![4, 5, 6]), + Some(vec![1, 2, 3, 4, 5, 6]) + ); + assert!(stream.non_contiguous_fragments.is_empty()); + assert!(stream.payload.is_empty()); + } + + #[test] + fn test_out_of_order_fragment_1() { + let mut stream = InboundStream::new(6); + assert_eq!(stream.push_fragment(1, vec![1, 2]), None); + assert_eq!(stream.push_fragment(3, vec![5, 6]), None); + assert_eq!( + stream.push_fragment(2, vec![3, 4]), + Some(vec![1, 2, 3, 4, 5, 6]) + ); + assert!(stream.non_contiguous_fragments.is_empty()); + assert!(stream.payload.is_empty()); + } + + #[test] + fn test_out_of_order_fragment_2() { + let mut stream = InboundStream::new(6); + assert_eq!(stream.push_fragment(2, vec![3, 4]), None); + 
assert_eq!(stream.push_fragment(3, vec![5, 6]), None); + assert_eq!( + stream.push_fragment(1, vec![1, 2]), + Some(vec![1, 2, 3, 4, 5, 6]) + ); + assert!(stream.non_contiguous_fragments.is_empty()); + assert!(stream.payload.is_empty()); + } +} diff --git a/crates/core/src/transport/peer_connection/outbound_stream.rs b/crates/core/src/transport/peer_connection/outbound_stream.rs new file mode 100644 index 000000000..a6c1b85ba --- /dev/null +++ b/crates/core/src/transport/peer_connection/outbound_stream.rs @@ -0,0 +1,147 @@ +use std::net::SocketAddr; +use std::sync::atomic::AtomicU32; +use std::sync::Arc; +use std::vec; + +use aes_gcm::Aes128Gcm; +use tokio::sync::mpsc; + +use crate::{ + transport::{ + packet_data, + sent_packet_tracker::SentPacketTracker, + symmetric_message::{self}, + TransportError, + }, + util::time_source::InstantTimeSrc, +}; + +use super::StreamId; + +pub(crate) type SerializedStream = Vec; + +// TODO: measure the space overhead of SymmetricMessage::LongMessage since is likely less than 100 +/// The max payload we can send in a single fragment, this MUST be less than packet_data::MAX_DATA_SIZE +/// since we need to account for the space overhead of SymmetricMessage::LongMessage metadata +const MAX_DATA_SIZE: usize = packet_data::MAX_DATA_SIZE - 100; + +// TODO: unit test +/// Handles sending a stream that is *not piped*. In the future this will be replaced by +/// piped streams which start forwarding before the stream has been received. +#[allow(clippy::too_many_arguments)] +pub(super) async fn send_stream( + stream_id: StreamId, + last_packet_id: Arc, + sender: mpsc::Sender<(SocketAddr, Arc<[u8]>)>, + destination_addr: SocketAddr, + mut stream_to_send: SerializedStream, + outbound_symmetric_key: Aes128Gcm, + sent_packet_tracker: Arc>>, +) -> Result<(), TransportError> { + tracing::debug!(stream_id = %stream_id.0, length = stream_to_send.len(), "sending stream"); + let total_length_bytes = stream_to_send.len() as u32; + let mut total_packets = stream_to_send.len() / MAX_DATA_SIZE; + total_packets += if stream_to_send.len() % MAX_DATA_SIZE == 0 { + 0 + } else { + 1 + }; + let mut sent_so_far = 0; + let mut next_fragment_number = 1; // Fragment numbers are 1-indexed + + loop { + if sent_so_far == total_packets { + break; + } + let rest = { + if stream_to_send.len() > MAX_DATA_SIZE { + let mut rest = stream_to_send.split_off(MAX_DATA_SIZE); + std::mem::swap(&mut stream_to_send, &mut rest); + rest + } else { + std::mem::take(&mut stream_to_send) + } + }; + let packet_id = last_packet_id.fetch_add(1, std::sync::atomic::Ordering::Release); + super::packet_sending( + destination_addr, + &sender, + packet_id, + &outbound_symmetric_key, + vec![], + symmetric_message::StreamFragment { + stream_id, + total_length_bytes: total_length_bytes as u64, + fragment_number: next_fragment_number, + payload: rest, + }, + &sent_packet_tracker, + ) + .await?; + next_fragment_number += 1; + sent_so_far += 1; + } + + // tracing::trace!(stream_id = %stream_id.0, total_packets = %sent_so_far, "stream sent"); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use aes_gcm::KeyInit; + use std::net::{Ipv4Addr, SocketAddr}; + use tests::packet_data::MAX_PACKET_SIZE; + use tokio::sync::mpsc; + + use super::{ + symmetric_message::{SymmetricMessage, SymmetricMessagePayload}, + *, + }; + use crate::transport::packet_data::PacketData; + + #[tokio::test] + async fn test_send_stream_success() -> Result<(), Box> { + let (outbound_sender, mut outbound_receiver) = mpsc::channel(1); + let remote_addr = 
SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080); + let message: Vec<_> = std::iter::repeat(()) + .take(100_000) + .map(|_| rand::random::()) + .collect(); + let cipher = { + let key = rand::random::<[u8; 16]>(); + Aes128Gcm::new(&key.into()) + }; + let sent_tracker = Arc::new(parking_lot::Mutex::new(SentPacketTracker::new())); + + let background_task = tokio::spawn(send_stream( + StreamId::next(), + Arc::new(AtomicU32::new(0)), + outbound_sender, + remote_addr, + message.clone(), + cipher.clone(), + sent_tracker, + )); + + let mut inbound_bytes = Vec::new(); + while let Some((_, packet)) = outbound_receiver.recv().await { + let decrypted_packet = PacketData::<_, MAX_PACKET_SIZE>::from_buf(packet.as_ref()) + .try_decrypt_sym(&cipher) + .map_err(TransportError::PrivateKeyDecryptionError)?; + let deserialized = SymmetricMessage::deser(decrypted_packet.data())?; + let SymmetricMessagePayload::StreamFragment { payload, .. } = deserialized.payload + else { + panic!("Expected a StreamFragment, got {:?}", deserialized.payload); + }; + inbound_bytes.extend_from_slice(payload.as_ref()); + } + + let result = background_task.await?; + assert!(result.is_ok()); + assert_eq!(&message[..10], &inbound_bytes[..10]); + assert_eq!(inbound_bytes.len(), 100_000); + assert_eq!(&message[99_990..], &inbound_bytes[99_990..]); + Ok(()) + } +} diff --git a/crates/core/src/transport/rate_limiter.rs b/crates/core/src/transport/rate_limiter.rs new file mode 100644 index 000000000..75b2c6e39 --- /dev/null +++ b/crates/core/src/transport/rate_limiter.rs @@ -0,0 +1,185 @@ +use tokio::sync::mpsc; + +use std::collections::VecDeque; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use super::Socket; +use crate::util::time_source::{InstantTimeSrc, TimeSource}; + +/// Keeps track of the bandwidth used in the last window_size. Recommend a `window_size` of +/// 10 seconds. +pub(super) struct PacketRateLimiter { + packets: VecDeque<(usize, Instant)>, + window_size: Duration, + current_bandwidth: usize, + outbound_packets: mpsc::Receiver<(SocketAddr, Arc<[u8]>)>, + time_source: T, +} + +impl PacketRateLimiter { + pub(super) fn new( + window_size: Duration, + outbound_packets: mpsc::Receiver<(SocketAddr, Arc<[u8]>)>, + ) -> Self { + PacketRateLimiter { + packets: VecDeque::new(), + window_size, + current_bandwidth: 0, + outbound_packets, + time_source: InstantTimeSrc::new(), + } + } +} + +impl PacketRateLimiter { + pub(super) async fn rate_limiter(mut self, bandwidth_limit: usize, socket: Arc) { + while let Some((socket_addr, packet)) = self.outbound_packets.recv().await { + if let Some(wait_time) = self.can_send_packet(bandwidth_limit, packet.len()) { + tokio::time::sleep(wait_time).await; + if let Err(error) = socket.send_to(&packet, socket_addr).await { + tracing::debug!("Error sending packet: {:?}", error); + } else { + self.add_packet(packet.len()); + } + } else if let Err(error) = socket.send_to(&packet, socket_addr).await { + tracing::debug!(%socket_addr, "Error sending packet: {:?}", error); + } else { + self.add_packet(packet.len()); + } + } + tracing::debug!("Rate limiter task ended unexpectedly"); + } + + /// Report that a packet was sent + fn add_packet(&mut self, packet_size: usize) { + let now = self.time_source.now(); + self.packets.push_back((packet_size, now)); + self.current_bandwidth += packet_size; + self.cleanup(); + } + + /// Removes packets that are older than the window size. 
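+ ///
+ /// Illustration (hypothetical numbers): with a 10 s `window_size` and 1500-byte
+ /// packets recorded at t = 0 s and t = 4 s, a call at t = 11 s evicts only the
+ /// first packet and leaves `current_bandwidth` at 1500 bytes.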
+ fn cleanup(&mut self) { + let now = self.time_source.now(); + while self + .packets + .front() + .map_or(false, |&(_, time)| now - time > self.window_size) + { + let expired = self.packets.pop_front(); + if let Some((size, _)) = expired { + self.current_bandwidth -= size; + } + } + } + + /// Returns none if the packet can be sent immediately without `bandwidth_limit` being + /// exceeded within the `window_size`. Otherwise returns Some(wait_time) where wait_time is the + /// amount of time that should be waited before sending the packet. + /// + /// `bandwidth_limit` should be set to 50% higher than the target upstream bandwidth the + /// [topology manager](crate::topology::TopologyManager) is aiming for, as it serves + /// as a hard limit which we'd prefer not to hit. + fn can_send_packet(&mut self, bandwidth_limit: usize, packet_size: usize) -> Option { + self.cleanup(); + + if self.current_bandwidth + packet_size <= bandwidth_limit { + return None; + } + + let mut temp_bandwidth = self.current_bandwidth; + let mut wait_time = None; + + for &(size, time) in self.packets.iter() { + temp_bandwidth -= size; + if temp_bandwidth + packet_size <= bandwidth_limit { + wait_time = Some( + self.window_size + .saturating_sub(self.time_source.now() - time), + ); + break; + } + } + + wait_time + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::util::time_source::MockTimeSource; + + fn mock_tracker(window_size: Duration) -> PacketRateLimiter { + PacketRateLimiter { + packets: VecDeque::new(), + window_size, + current_bandwidth: 0, + outbound_packets: mpsc::channel(1).1, + time_source: MockTimeSource::new(Instant::now()), + } + } + + fn verify_bandwidth_match(tracker: &PacketRateLimiter) { + let mut total_bandwidth = 0; + for &(size, _) in tracker.packets.iter() { + total_bandwidth += size; + } + assert_eq!(total_bandwidth, tracker.current_bandwidth); + } + + #[test] + fn test_adding_packets() { + let mut tracker = PacketRateLimiter::new(Duration::from_secs(1), mpsc::channel(1).1); + verify_bandwidth_match(&tracker); + tracker.add_packet(1500); + verify_bandwidth_match(&tracker); + assert_eq!(tracker.packets.len(), 1); + } + + #[test] + fn test_bandwidth_calculation() { + let mut tracker = PacketRateLimiter::new(Duration::from_secs(1), mpsc::channel(1).1); + tracker.add_packet(1500); + tracker.add_packet(2500); + verify_bandwidth_match(&tracker); + assert_eq!( + tracker.packets.iter().map(|&(size, _)| size).sum::(), + 4000 + ); + } + + #[test] + fn test_packet_expiry() { + let mut tracker = mock_tracker(Duration::from_millis(200)); + tracker.add_packet(1500); + verify_bandwidth_match(&tracker); + tracker.time_source.advance_time(Duration::from_millis(300)); + tracker.cleanup(); + verify_bandwidth_match(&tracker); + assert!(tracker.packets.is_empty()); + } + + #[test] + fn test_wait_time_calculation() { + let mut tracker = mock_tracker(Duration::from_secs(1)); + tracker.add_packet(5000); + verify_bandwidth_match(&tracker); + tracker.time_source.advance_time(Duration::from_millis(500)); + tracker.add_packet(4000); + verify_bandwidth_match(&tracker); + match tracker.can_send_packet(10000, 2000) { + None => panic!("Should require waiting"), + Some(wait_time) => assert_eq!(wait_time, Duration::from_millis(500)), + } + } + + #[test] + fn test_immediate_send() { + let mut tracker = PacketRateLimiter::new(Duration::from_millis(10), mpsc::channel(1).1); + tracker.add_packet(3000); + assert_eq!(tracker.can_send_packet(10000, 2000), None); + } +} diff --git 
a/crates/core/src/transport/received_packet_tracker.rs b/crates/core/src/transport/received_packet_tracker.rs new file mode 100644 index 000000000..9191039ba --- /dev/null +++ b/crates/core/src/transport/received_packet_tracker.rs @@ -0,0 +1,227 @@ +use crate::transport::PacketId; +use crate::util::time_source::{InstantTimeSrc, TimeSource}; +use std::collections::{HashMap, VecDeque}; +use std::mem; +use std::time::{Duration, Instant}; + +/// How long to retain packets in case they need to be retransmitted +const RETAIN_TIME: Duration = Duration::from_secs(60); +const MAX_PENDING_RECEIPTS: usize = 20; + +/// This struct is responsible for tracking received packets and deciding when to send receipts +/// from/to a specific peer. +/// +/// The caller must report when packets are received using the `report_received_packets` method. +/// The caller must also call `get_receipts` periodically to check if any receipts need to be sent. +/// +/// `get_receipts` should be called whenever a packet is sent. +/// +/// `get_receipts` must **also** be called every `MAX_CONFIRMATION_DELAY` (100ms) and +/// if the returned list is not empty, the list should be sent as receipts immediately in a noop. +/// This may look something like this: +/// +/// ```rust,no_run +/// use super::MAX_CONFIRMATION_DELAY; +/// use std::thread::sleep; +/// let mut received_packet_tracker = todo!(); +/// loop { +/// let receipts = received_packet_tracker.get_receipts(); +/// if !receipts.is_empty() { +/// // Send receipts in a noop message +/// } +/// sleep(MAX_CONFIRMATION_DELAY); +/// } +/// ``` +pub(super) struct ReceivedPacketTracker { + pending_receipts: Vec, + packet_id_time: VecDeque<(PacketId, Instant)>, + time_by_packet_id: HashMap, + time_source: T, +} + +impl ReceivedPacketTracker { + pub(super) fn new() -> Self { + ReceivedPacketTracker { + pending_receipts: Vec::new(), + packet_id_time: VecDeque::new(), + time_by_packet_id: HashMap::new(), + time_source: InstantTimeSrc::new(), + } + } +} + +impl ReceivedPacketTracker { + pub(super) fn report_received_packet(&mut self, packet_id: PacketId) -> ReportResult { + self.cleanup(); + let current_time = self.time_source.now(); + + match self.time_by_packet_id.entry(packet_id) { + std::collections::hash_map::Entry::Occupied(_) => ReportResult::AlreadyReceived, + std::collections::hash_map::Entry::Vacant(e) => { + e.insert(current_time); + self.packet_id_time.push_back((packet_id, current_time)); + self.pending_receipts.push(packet_id); + + if self.pending_receipts.len() < MAX_PENDING_RECEIPTS { + ReportResult::Ok + } else { + ReportResult::QueueFull + } + } + } + } + + /// Returns a list of packets that have been received since the last call to this function. + /// This should be called every time a packet is sent to ensure that receipts are sent + /// promptly. Every `MAX_CONFIRMATION_DELAY` (50ms) this should be called and if the returned + /// list is not empty, the list should be sent as receipts immediately in a noop packet. + pub(super) fn get_receipts(&mut self) -> Vec { + self.cleanup(); + + mem::take(self.pending_receipts.as_mut()) + } + + /// This function cleans up the `packet_id_time` and `time_by_packet_id` data structures. + /// It removes entries that are older than `RETAIN_TIME`. 
+ fn cleanup(&mut self) { + let remove_before = self.time_source.now() - RETAIN_TIME; + while self + .packet_id_time + .front() + .map_or(false, |&(_, time)| time < remove_before) + { + let expired = self.packet_id_time.pop_front(); + if let Some((packet_id, _)) = expired { + self.time_by_packet_id.remove(&packet_id); + } + } + // Note: We deliberately don't clean up the pending_receipts list because it will + // be emptied every time get_receipts is called. + } +} + +#[must_use] +#[derive(Debug, PartialEq)] +pub(super) enum ReportResult { + /// Packet was received for the first time and recorded + Ok, + + /// The packet has already been received, it will be re-acknowledged but + /// should otherwise be ignored + AlreadyReceived, + + /// The queue is full and receipts must be sent immediately + QueueFull, +} + +#[cfg(test)] +pub(in crate::transport) mod tests { + use super::*; + use crate::util::time_source::MockTimeSource; + + pub(in crate::transport) fn mock_received_packet_tracker( + ) -> ReceivedPacketTracker { + ReceivedPacketTracker { + pending_receipts: Vec::new(), + packet_id_time: VecDeque::new(), + time_by_packet_id: HashMap::new(), + time_source: MockTimeSource::new(Instant::now()), + } + } + + #[test] + fn test_initialization() { + let mut tracker = ReceivedPacketTracker { + pending_receipts: Vec::new(), + packet_id_time: VecDeque::new(), + time_by_packet_id: HashMap::new(), + time_source: MockTimeSource::new(Instant::now()), + }; + + assert_eq!(tracker.get_receipts().len(), 0); + assert_eq!(tracker.pending_receipts.len(), 0); + assert_eq!(tracker.time_by_packet_id.len(), 0); + } + + #[test] + fn test_report_receipt_ok() { + let mut tracker = ReceivedPacketTracker { + pending_receipts: Vec::new(), + packet_id_time: VecDeque::new(), + time_by_packet_id: HashMap::new(), + time_source: MockTimeSource::new(Instant::now()), + }; + + assert_eq!(tracker.report_received_packet(0), ReportResult::Ok); + assert_eq!(tracker.pending_receipts.len(), 1); + assert_eq!(tracker.time_by_packet_id.len(), 1); + } + + #[test] + fn test_report_receipt_already_received() { + let mut tracker = mock_received_packet_tracker(); + + assert_eq!(tracker.report_received_packet(0), ReportResult::Ok); + assert_eq!( + tracker.report_received_packet(0), + ReportResult::AlreadyReceived + ); + assert_eq!(tracker.pending_receipts.len(), 1); + assert_eq!(tracker.time_by_packet_id.len(), 1); + } + + #[test] + fn test_report_receipt_queue_full() { + let mut tracker = ReceivedPacketTracker { + pending_receipts: Vec::new(), + packet_id_time: VecDeque::new(), + time_by_packet_id: HashMap::new(), + time_source: MockTimeSource::new(Instant::now()), + }; + + for i in 0..(MAX_PENDING_RECEIPTS - 1) { + assert_eq!( + tracker.report_received_packet(i as PacketId), + ReportResult::Ok + ); + } + assert_eq!( + tracker.report_received_packet((MAX_PENDING_RECEIPTS as PacketId) + 1), + ReportResult::QueueFull + ); + assert_eq!(tracker.pending_receipts.len(), MAX_PENDING_RECEIPTS); + assert_eq!(tracker.time_by_packet_id.len(), MAX_PENDING_RECEIPTS); + } + + #[test] + fn test_cleanup() { + let mut tracker = ReceivedPacketTracker { + pending_receipts: Vec::new(), + packet_id_time: VecDeque::new(), + time_by_packet_id: HashMap::new(), + time_source: MockTimeSource::new(Instant::now()), + }; + + for i in 0..10 { + assert_eq!(tracker.report_received_packet(i), ReportResult::Ok); + } + assert_eq!(tracker.time_by_packet_id.len(), 10); + assert_eq!(tracker.packet_id_time.len(), 10); + + tracker + .time_source + .advance_time(RETAIN_TIME + 
Duration::from_secs(1)); + + tracker.cleanup(); + assert_eq!(tracker.time_by_packet_id.len(), 0); + assert_eq!(tracker.packet_id_time.len(), 0); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_many_trackers() { + let mut trackers = vec![]; + for _ in 1..100 { + trackers.push(ReceivedPacketTracker::new()); + } + } +} diff --git a/crates/core/src/transport/sent_packet_tracker.rs b/crates/core/src/transport/sent_packet_tracker.rs new file mode 100644 index 000000000..ba887617e --- /dev/null +++ b/crates/core/src/transport/sent_packet_tracker.rs @@ -0,0 +1,256 @@ +use super::PacketId; +use crate::util::time_source::{InstantTimeSrc, TimeSource}; +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +const NETWORK_DELAY_ALLOWANCE: Duration = Duration::from_millis(500); + +/// We can wait up to 100ms to confirm a message was received, this allows us to batch +/// receipts together and send them in a single message. +const MAX_CONFIRMATION_DELAY: Duration = Duration::from_millis(100); + +/// If we don't get a receipt for a message within 500ms, we assume the message was lost and +/// resend it. This must be significantly higher than MAX_CONFIRMATION_DELAY (100ms) to +/// account for network delay +pub(super) const MESSAGE_CONFIRMATION_TIMEOUT: Duration = { + let millis: u128 = MAX_CONFIRMATION_DELAY.as_millis() + NETWORK_DELAY_ALLOWANCE.as_millis(); + + // Check for overflow + if millis > u64::MAX as u128 { + panic!("Value too large for u64"); + } + + // Safe to convert now + Duration::from_millis(millis as u64) +}; + +/// Determines the accuracy/sensitivity of the packet loss estimate. A lower value will result +/// in a more accurate estimate, but it will take longer to converge to the true value. +const PACKET_LOSS_DECAY_FACTOR: f64 = 1.0 / 1000.0; + +/// This struct is responsible for tracking packets that have been sent but not yet acknowledged. +/// It is also responsible for deciding when to resend packets that have not been acknowledged. +/// +/// The caller must report when packets are sent and when receipts are received using the +/// `report_sent_packet` and `report_received_receipts` functions. The caller must also call +/// `get_resend` periodically to check if any packets need to be resent. +/// +/// The expectation is that get_resend will be called as part of a loop that looks something like +/// this: +/// +/// ```rust,no_run +/// let mut sent_packet_tracker = todo!(); +/// loop { +/// match sent_packet_tracker.get_resend() { +/// ResendAction::WaitUntil(wait_until) => { +/// sleep_until(wait_until).await; +/// } +/// ResendAction::Resend(packet_id, packet) => { +/// // Send packet and then call report_sent_packet again with the same packet_id. +/// } +/// } +/// } +/// ``` +pub(super) struct SentPacketTracker { + /// The list of packets that have been sent but not yet acknowledged + pending_receipts: HashMap>, + + resend_queue: VecDeque, + + packet_loss_proportion: f64, + + pub(super) time_source: T, +} + +impl SentPacketTracker { + pub(super) fn new() -> Self { + SentPacketTracker { + pending_receipts: HashMap::new(), + resend_queue: VecDeque::new(), + packet_loss_proportion: 0.0, + time_source: InstantTimeSrc::new(), + } + } +} + +impl SentPacketTracker { + /// Get an estimate of the proportion of outbound packets that were lost This is a value + /// between 0.0 and 1.0, where 0.0 means no packets are lost and 1.0 means all packets are + /// lost. 
This estimate will be biased towards 0.0 initially, and will converge to the + /// true value over time. It's accuracy will be approximately + /// `PACKET_LOSS_DECAY_FACTOR` (0.001). + pub(super) fn get_recent_packet_loss(&self) -> f64 { + self.packet_loss_proportion + } + + pub(super) fn report_sent_packet(&mut self, packet_id: PacketId, payload: Arc<[u8]>) { + self.pending_receipts.insert(packet_id, payload); + self.resend_queue.push_back(ResendQueueEntry { + timeout_at: self.time_source.now() + MESSAGE_CONFIRMATION_TIMEOUT, + packet_id, + }); + } + + pub(super) fn report_received_receipts(&mut self, packet_ids: &[PacketId]) { + for packet_id in packet_ids { + // This can be simplified but I'm leaving it like this for readability. + self.packet_loss_proportion = self.packet_loss_proportion + * (1.0 - PACKET_LOSS_DECAY_FACTOR) + + (PACKET_LOSS_DECAY_FACTOR * 0.0); + self.pending_receipts.remove(packet_id); + } + } + + /// Either get a packet that needs to be resent, or how long the caller should wait until + /// calling this function again. If a packet is resent you **must** call + /// `report_sent_packet` again with the same packet_id. + pub(super) fn get_resend(&mut self) -> ResendAction { + let now = self.time_source.now(); + + while let Some(entry) = self.resend_queue.pop_front() { + if entry.timeout_at > now { + if !self.pending_receipts.contains_key(&entry.packet_id) { + continue; + } + let wait_until = entry.timeout_at; + self.resend_queue.push_front(entry); + return ResendAction::WaitUntil(wait_until); + } else if let Some(packet) = self.pending_receipts.remove(&entry.packet_id) { + // Update packet loss proportion for a lost packet + // Resend logic + self.packet_loss_proportion = self.packet_loss_proportion + * (1.0 - PACKET_LOSS_DECAY_FACTOR) + + PACKET_LOSS_DECAY_FACTOR; + + return ResendAction::Resend(entry.packet_id, packet); + } + // If the packet is no longer in pending_receipts, it means its receipt has been received. + // No action needed, continue to check the next entry in the queue. 
+ } + + ResendAction::WaitUntil(now + MESSAGE_CONFIRMATION_TIMEOUT) + } +} + +#[derive(Debug, PartialEq)] +pub enum ResendAction { + WaitUntil(Instant), + Resend(u32, Arc<[u8]>), +} + +struct ResendQueueEntry { + timeout_at: Instant, + packet_id: u32, +} + +// Unit tests +#[cfg(test)] +pub(in crate::transport) mod tests { + use super::*; + use crate::transport::MessagePayload; + use crate::util::time_source::MockTimeSource; + + pub(in crate::transport) fn mock_sent_packet_tracker() -> SentPacketTracker { + let time_source = MockTimeSource::new(Instant::now()); + + SentPacketTracker { + pending_receipts: HashMap::new(), + resend_queue: VecDeque::new(), + packet_loss_proportion: 0.0, + time_source, + } + } + + #[test] + fn test_report_sent_packet() { + let mut tracker = mock_sent_packet_tracker(); + tracker.report_sent_packet(1, vec![1, 2, 3].into()); + assert_eq!(tracker.pending_receipts.len(), 1); + assert_eq!(tracker.resend_queue.len(), 1); + assert_eq!(tracker.packet_loss_proportion, 0.0); + } + + #[test] + fn test_report_received_receipts() { + let mut tracker = mock_sent_packet_tracker(); + tracker.report_sent_packet(1, vec![1, 2, 3].into()); + tracker.report_received_receipts(&[1]); + assert_eq!(tracker.pending_receipts.len(), 0); + assert!(tracker.resend_queue.len() <= 1); + assert_eq!(tracker.packet_loss_proportion, 0.0); + } + + #[test] + fn test_packet_lost() { + let mut tracker = mock_sent_packet_tracker(); + tracker.report_sent_packet(1, vec![1, 2, 3].into()); + tracker + .time_source + .advance_time(MESSAGE_CONFIRMATION_TIMEOUT); + let resend_action = tracker.get_resend(); + assert_eq!(resend_action, ResendAction::Resend(1, vec![1, 2, 3].into())); + assert_eq!(tracker.pending_receipts.len(), 0); + assert_eq!(tracker.resend_queue.len(), 0); + assert_eq!(tracker.packet_loss_proportion, PACKET_LOSS_DECAY_FACTOR); + } + + #[test] + fn test_immediate_receipt_then_resend() { + let mut tracker = mock_sent_packet_tracker(); + + // Report two packets sent + tracker.report_sent_packet(1, vec![1, 2, 3].into()); + tracker.report_sent_packet(2, vec![4, 5, 6].into()); + + // Immediately report receipt for the first packet + tracker.report_received_receipts(&[1]); + + // Simulate time just before the resend time for packet 2 + tracker + .time_source + .advance_time(MESSAGE_CONFIRMATION_TIMEOUT - Duration::from_millis(1)); + + // This should not trigger a resend yet + match tracker.get_resend() { + ResendAction::WaitUntil(_) => (), + _ => panic!("Expected WaitUntil, got Resend too early"), + } + + // Now advance time to trigger resend for packet 2 + tracker.time_source.advance_time(Duration::from_millis(2)); + + // This should now trigger a resend for packet 2 + match tracker.get_resend() { + ResendAction::Resend(packet_id, _) => assert_eq!(packet_id, 2), + _ => panic!("Expected Resend for message ID 2"), + } + } + + #[test] + fn test_get_resend_with_pending_receipts() { + let mut tracker = mock_sent_packet_tracker(); + + tracker.report_sent_packet(0, MessagePayload::new().into()); + + tracker.time_source.advance_time(Duration::from_millis(10)); + + tracker.report_sent_packet(1, MessagePayload::new().into()); + + let packet_1_timeout = tracker.time_source.now() + MESSAGE_CONFIRMATION_TIMEOUT; + + // Acknowledge receipt of the first packet + tracker.report_received_receipts(&[0]); + + // The next call to get_resend should calculate the wait time based on the second packet (id 1) + match tracker.get_resend() { + ResendAction::WaitUntil(wait_until) => { + assert_eq!( + wait_until, 
packet_1_timeout, + "Wait time does not match expected for second packet" + ); + } + _ => panic!("Expected ResendAction::WaitUntil"), + } + } +} diff --git a/crates/core/src/transport/symmetric_message.rs b/crates/core/src/transport/symmetric_message.rs new file mode 100644 index 000000000..f60202d00 --- /dev/null +++ b/crates/core/src/transport/symmetric_message.rs @@ -0,0 +1,288 @@ +use std::{borrow::Cow, net::SocketAddr, sync::OnceLock}; + +use crate::transport::packet_data::SymmetricAES; +use aes_gcm::Aes128Gcm; +use serde::{Deserialize, Serialize}; +use serde_with::serde_as; + +use super::{ + packet_data::MAX_DATA_SIZE, peer_connection::StreamId, MessagePayload, PacketData, PacketId, +}; + +#[serde_as] +#[derive(Serialize, Deserialize)] +#[cfg_attr(test, derive(PartialEq, Debug, Clone))] +pub(super) struct SymmetricMessage { + pub packet_id: PacketId, + // #[serde(skip_serializing_if = "Vec::is_empty")] + pub confirm_receipt: Vec, + pub payload: SymmetricMessagePayload, +} + +impl SymmetricMessage { + pub const FIRST_PACKET_ID: u32 = 0u32; + + pub fn deser(bytes: &[u8]) -> Result { + bincode::deserialize(bytes) + } + + const ACK_ERROR: SymmetricMessage = SymmetricMessage { + packet_id: Self::FIRST_PACKET_ID, + confirm_receipt: Vec::new(), + payload: SymmetricMessagePayload::AckConnection { + // TODO: change to return UnsupportedProtocolVersion + result: Err(Cow::Borrowed( + "remote is using a different protocol version", + )), + }, + }; + + pub fn ack_error( + outbound_sym_key: &Aes128Gcm, + ) -> Result, bincode::Error> { + static SERIALIZED: OnceLock> = OnceLock::new(); + let bytes = SERIALIZED.get_or_init(|| { + let mut packet = [0u8; MAX_DATA_SIZE]; + let size = bincode::serialized_size(&Self::ACK_ERROR).unwrap(); + bincode::serialize_into(packet.as_mut_slice(), &Self::ACK_ERROR).unwrap(); + (&packet[..size as usize]).into() + }); + let packet = PacketData::from_buf_plain(bytes); + Ok(packet.encrypt_symmetric(outbound_sym_key)) + } + + pub fn ack_ok( + outbound_sym_key: &Aes128Gcm, + our_inbound_key: [u8; 16], + remote_addr: SocketAddr, + ) -> Result, bincode::Error> { + let message = Self { + packet_id: Self::FIRST_PACKET_ID, + confirm_receipt: vec![], + payload: SymmetricMessagePayload::AckConnection { + result: Ok(OutboundConnection { + key: our_inbound_key, + remote_addr, + }), + }, + }; + let mut packet = [0u8; MAX_DATA_SIZE]; + let size = bincode::serialized_size(&message)?; + debug_assert!(size <= MAX_DATA_SIZE as u64); + bincode::serialize_into(packet.as_mut_slice(), &message)?; + let bytes = &packet[..size as usize]; + + let packet = PacketData::from_buf_plain(bytes); + Ok(packet.encrypt_symmetric(outbound_sym_key)) + } + + pub fn serialize_msg_to_packet_data( + packet_id: PacketId, + payload: impl Into, + outbound_sym_key: &Aes128Gcm, + confirm_receipt: Vec, + ) -> Result, bincode::Error> { + let message = Self { + packet_id, + confirm_receipt, + payload: payload.into(), + }; + let mut packet = [0u8; MAX_DATA_SIZE]; + let size = bincode::serialized_size(&message)?; + debug_assert!(size <= MAX_DATA_SIZE as u64); + bincode::serialize_into(packet.as_mut_slice(), &message)?; + let bytes = &packet[..size as usize]; + let packet = PacketData::from_buf_plain(bytes); + Ok(packet.encrypt_symmetric(outbound_sym_key)) + } +} + +impl From<()> for SymmetricMessagePayload { + fn from(_: ()) -> Self { + Self::NoOp {} + } +} + +pub(super) struct ShortMessage(pub MessagePayload); + +impl From for SymmetricMessagePayload { + fn from(short_message: ShortMessage) -> Self { + 
Self::ShortMessage { + payload: short_message.0, + } + } +} + +pub(super) struct StreamFragment { + pub stream_id: StreamId, + pub total_length_bytes: u64, + pub fragment_number: u32, + pub payload: MessagePayload, +} + +impl From for SymmetricMessagePayload { + fn from(stream_fragment: StreamFragment) -> Self { + Self::StreamFragment { + stream_id: stream_fragment.stream_id, + total_length_bytes: stream_fragment.total_length_bytes, + fragment_number: stream_fragment.fragment_number, + payload: stream_fragment.payload, + } + } +} + +#[derive(Serialize, Deserialize)] +#[cfg_attr(test, derive(PartialEq, Debug, Clone))] +pub(super) struct OutboundConnection { + pub key: [u8; 16], + pub remote_addr: SocketAddr, +} + +#[derive(Serialize, Deserialize)] +#[cfg_attr(test, derive(PartialEq, Debug, Clone))] +pub(super) enum SymmetricMessagePayload { + AckConnection { + // a remote acknowledges a connection and returns the private key to use + // for communication and the remote address + result: Result>, + }, + ShortMessage { + payload: MessagePayload, + }, + StreamFragment { + stream_id: StreamId, + total_length_bytes: u64, // we shouldn't allow messages larger than u32, that's already crazy big + fragment_number: u32, + payload: MessagePayload, + }, + NoOp, +} + +#[cfg(test)] +impl std::fmt::Display for SymmetricMessagePayload { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SymmetricMessagePayload::AckConnection { result } => { + write!( + f, + "AckConnection: {}", + result.as_ref().map(|_| "Ok").unwrap_or("Err") + ) + } + SymmetricMessagePayload::ShortMessage { .. } => { + write!(f, "ShortMessage") + } + SymmetricMessagePayload::StreamFragment { + stream_id, + fragment_number, + .. + } => write!( + f, + "StreamFragment: (stream id: {:?}, fragment no: {:?}) ", + stream_id, fragment_number + ), + SymmetricMessagePayload::NoOp => write!(f, "NoOp"), + } + } +} + +#[cfg(test)] +mod test { + use std::net::Ipv4Addr; + + use aes_gcm::KeyInit; + + use super::*; + + fn gen_key() -> Aes128Gcm { + let key = rand::random::<[u8; 16]>(); + Aes128Gcm::new(&key.into()) + } + + fn serialization_round_trip( + payload: impl Into, + key: &Aes128Gcm, + ) -> SymmetricMessagePayload { + let enc_sym_packet = + SymmetricMessage::serialize_msg_to_packet_data(1, payload, key, vec![]).unwrap(); + let dec_sym_packet = enc_sym_packet.decrypt(key).unwrap(); + SymmetricMessage::deser(dec_sym_packet.data()) + .unwrap() + .payload + } + + #[test] + fn check_symmetric_message_serialization() { + let test_cases = [ + SymmetricMessagePayload::AckConnection { + result: Ok(OutboundConnection { + key: [0; 16], + remote_addr: (Ipv4Addr::LOCALHOST, 1234).into(), + }), + }, + SymmetricMessagePayload::AckConnection { + result: Err(Cow::Borrowed("error")), + }, + SymmetricMessagePayload::ShortMessage { + payload: std::iter::repeat(()) + .take(100) + .map(|_| rand::random::()) + .collect(), + }, + SymmetricMessagePayload::StreamFragment { + stream_id: StreamId::next(), + total_length_bytes: 100, + fragment_number: 1, + payload: std::iter::repeat(()) + .take(100) + .map(|_| rand::random::()) + .collect(), + }, + SymmetricMessagePayload::NoOp, + ]; + let key = gen_key(); + + for case in test_cases { + let deserialized = serialization_round_trip(case.clone(), &key); + assert_eq!(deserialized, case); + } + } + + #[test] + fn ack_error_msg() -> Result<(), Box> { + let key = gen_key(); + let packet = SymmetricMessage::ack_error(&key)?; + let data = packet.decrypt(&key).unwrap(); + let deser = 
SymmetricMessage::deser(data.data())?; + assert!(matches!( + deser.payload, + SymmetricMessagePayload::AckConnection { result: Err(_) } + )); + Ok(()) + } + + #[test] + fn ack_ok_msg() -> Result<(), Box> { + let enc = bincode::serialize(&SymmetricMessage { + packet_id: SymmetricMessage::FIRST_PACKET_ID, + confirm_receipt: vec![], + payload: SymmetricMessagePayload::AckConnection { + result: Ok(OutboundConnection { + key: [0; 16], + remote_addr: (Ipv4Addr::LOCALHOST, 1234).into(), + }), + }, + })?; + let _dec: SymmetricMessage = bincode::deserialize(&enc)?; + + let key = gen_key(); + let packet = SymmetricMessage::ack_ok(&key, [0; 16], (Ipv4Addr::LOCALHOST, 1234).into())?; + let data = packet.decrypt(&key).unwrap(); + let deser = SymmetricMessage::deser(data.data())?; + assert!(matches!( + deser.payload, + SymmetricMessagePayload::AckConnection { result: Ok(_) } + )); + Ok(()) + } +} diff --git a/crates/core/src/util.rs b/crates/core/src/util.rs index a90cfdb7b..8aa82a107 100644 --- a/crates/core/src/util.rs +++ b/crates/core/src/util.rs @@ -1,3 +1,5 @@ +pub(crate) mod time_source; + use std::{ collections::{BTreeMap, HashSet}, time::Duration, @@ -244,8 +246,7 @@ impl<'x> Contains for &'x Vec<&PeerId> { } #[cfg(test)] -pub mod tests { - +pub(crate) mod tests { use tempfile::TempDir; /// Use this to guarantee unique directory names in case you are running multiple tests in parallel. diff --git a/crates/core/src/util/time_source.rs b/crates/core/src/util/time_source.rs new file mode 100644 index 000000000..52631d04f --- /dev/null +++ b/crates/core/src/util/time_source.rs @@ -0,0 +1,146 @@ +use std::sync::atomic::{AtomicBool, AtomicPtr}; +use std::sync::Arc; +use std::thread; +use std::thread::sleep; +use std::time::{Duration, Instant}; + +const UPDATE_CACHED_TIME_EVERY: Duration = Duration::from_millis(10); + +pub trait TimeSource { + fn now(&self) -> Instant; +} + +/// A simple time source that returns the current time using `Instant::now()`. +#[derive(Clone, Copy)] +pub struct InstantTimeSrc(()); + +impl InstantTimeSrc { + pub fn new() -> Self { + InstantTimeSrc(()) + } +} + +impl TimeSource for InstantTimeSrc { + fn now(&self) -> Instant { + Instant::now() + } +} + +/// A time source that caches the current time in a global state to reduce +/// overhead in performance-critical sections. +/// +/// **Warning**: This time source will only be accurate to within about 20ms, +/// any usage should be tested carefully to verify that this inaccuracy is acceptable. +/// In the absence of such testing use [`InstantTimeSrc`] instead. +#[derive(Clone, Copy)] +pub(crate) struct CachingSystemTimeSrc(()); + +// Global atomic pointer to the cached time. Initialized as a null pointer. +static GLOBAL_TIME_STATE: AtomicPtr = AtomicPtr::new(std::ptr::null_mut()); + +impl CachingSystemTimeSrc { + #![allow(unused)] + // Creates a new instance and ensures only one updater task is spawned. + pub(crate) fn new() -> Self { + let mut current_unix_epoch_ts = Instant::now(); + + // Attempt to set the global time state if it's currently null. + // This ensures only the first thread to execute this will spawn the updater task. + if GLOBAL_TIME_STATE + .compare_exchange( + std::ptr::null_mut(), + (&mut current_unix_epoch_ts) as *mut _, + std::sync::atomic::Ordering::Acquire, + std::sync::atomic::Ordering::Relaxed, + ) + .is_ok() + { + // Use a flag to synchronize the updater task's initialization. + let drop_guard = Arc::new(AtomicBool::new(false)); + + // Spawn the updater task asynchronously. 
+ let drop_guard_clone = drop_guard.clone(); + thread::spawn(move || Self::update_instant(drop_guard_clone)); + + // Wait until the updater task signals it's safe to proceed. + while !drop_guard.load(std::sync::atomic::Ordering::Acquire) { + std::hint::spin_loop(); + } + } + + CachingSystemTimeSrc(()) + } + + // Asynchronously updates the global time state every UPDATE_CACHED_TIME_EVERY (10ms). + fn update_instant(drop_guard: Arc) { + let mut now = Instant::now(); + + // Initially set the global time state and notify the constructor to proceed. + GLOBAL_TIME_STATE.store(&mut now, std::sync::atomic::Ordering::Release); + drop_guard.store(true, std::sync::atomic::Ordering::Release); + + loop { + // Update the time and store it in the global state. + now = Instant::now(); + GLOBAL_TIME_STATE.store(&mut now, std::sync::atomic::Ordering::Release); + + // Wait for 20ms before the next update. + sleep(UPDATE_CACHED_TIME_EVERY); + } + } +} + +impl TimeSource for CachingSystemTimeSrc { + // Returns the current time from the global state. + fn now(&self) -> Instant { + // Unsafe dereference is required for the raw pointer. + unsafe { *GLOBAL_TIME_STATE.load(std::sync::atomic::Ordering::Acquire) } + } +} + +#[cfg(test)] +#[derive(Clone)] +pub struct MockTimeSource { + current_instant: Instant, +} + +#[cfg(test)] +impl MockTimeSource { + pub fn new(start_instant: Instant) -> Self { + MockTimeSource { + current_instant: start_instant, + } + } + + pub fn advance_time(&mut self, duration: Duration) { + self.current_instant += duration; + } +} + +#[cfg(test)] +impl TimeSource for MockTimeSource { + fn now(&self) -> Instant { + self.current_instant + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn test_instant_is_updated() { + let time_source = CachingSystemTimeSrc::new(); + let first_instant = time_source.now(); + + assert!(first_instant.elapsed().as_millis() < 30); + + sleep(Duration::from_millis(120)); + let second_instant = time_source.now(); + + assert!(second_instant.elapsed().as_millis() < 30); + + assert!(second_instant > first_instant); + } +} diff --git a/crates/core/src/wasm_runtime/store.rs b/crates/core/src/wasm_runtime/store.rs index c44e4793e..eb289a083 100644 --- a/crates/core/src/wasm_runtime/store.rs +++ b/crates/core/src/wasm_runtime/store.rs @@ -23,7 +23,11 @@ pub(super) struct SafeWriter { impl SafeWriter { pub fn new(path: &Path, compact: bool) -> Result { let file = if compact { - OpenOptions::new().create(true).write(true).open(path)? + OpenOptions::new() + .create(true) + .truncate(false) + .write(true) + .open(path)? 
} else { OpenOptions::new() .create(true) diff --git a/crates/fdev/Cargo.toml b/crates/fdev/Cargo.toml index 81408fe9b..07b5b5176 100644 --- a/crates/fdev/Cargo.toml +++ b/crates/fdev/Cargo.toml @@ -17,7 +17,6 @@ chrono = { workspace = true } clap = { workspace = true, features = ["derive", "env"] } dashmap = { workspace = true } either = { workspace = true } -fastrand = { workspace = true } futures = { workspace = true } glob = "0.3" pico-args = "0.5" diff --git a/crates/fdev/src/main.rs b/crates/fdev/src/main.rs index 915e46acb..9b2aa24b6 100644 --- a/crates/fdev/src/main.rs +++ b/crates/fdev/src/main.rs @@ -40,7 +40,7 @@ fn main() -> Result<(), anyhow::Error> { let config = Config::parse(); freenet::config::Config::set_op_mode(config.additional.mode); if !config.sub_command.is_child() { - freenet::config::set_logger(); + freenet::config::set_logger(None); } tokio_rt.block_on(async move { let cwd = std::env::current_dir()?; diff --git a/crates/fdev/src/testing/multiple_process.rs b/crates/fdev/src/testing/multiple_process.rs index cfb2ce0ed..f1f7cce87 100644 --- a/crates/fdev/src/testing/multiple_process.rs +++ b/crates/fdev/src/testing/multiple_process.rs @@ -403,8 +403,8 @@ async fn child( peer_id: node_config.peer_id, }; std::env::set_var("FREENET_PEER_ID", node_config.peer_id.to_string()); - freenet::config::set_logger(); - let mut event_generator = MemoryEventsGen::::new_with_seed( + freenet::config::set_logger(None); + let mut event_generator: MemoryEventsGen = MemoryEventsGen::new_with_seed( receiver_ch.clone(), node_config.peer_id, test_config diff --git a/crates/fdev/src/testing/network.rs b/crates/fdev/src/testing/network.rs index daac5e649..deb88fc88 100644 --- a/crates/fdev/src/testing/network.rs +++ b/crates/fdev/src/testing/network.rs @@ -158,7 +158,7 @@ async fn start_supervisor(config: &TestConfig) -> anyhow::Result<(), Error> { async fn start_peer(config: &TestConfig, cmd_config: &NetworkProcessConfig) -> Result<(), Error> { std::env::set_var("FREENET_PEER_ID", cmd_config.clone().id.unwrap()); - freenet::config::set_logger(); + freenet::config::set_logger(None); if let Some(peer_id) = &cmd_config.id { let peer = NetworkPeer::new(peer_id.clone()).await?; peer.run(config, peer_id.clone()).await?; @@ -604,7 +604,7 @@ impl Runnable for NetworkPeer { let mut receiver_ch = self.receiver_ch.deref().clone(); receiver_ch.borrow_and_update(); - let mut memory_event_generator = MemoryEventsGen::::new_with_seed( + let mut memory_event_generator: MemoryEventsGen = MemoryEventsGen::new_with_seed( receiver_ch, self.config.peer_id, config.seed.expect("seed should be set for child process"), diff --git a/crates/fdev/src/testing/single_process.rs b/crates/fdev/src/testing/single_process.rs index 5c83d5221..2d58b4691 100644 --- a/crates/fdev/src/testing/single_process.rs +++ b/crates/fdev/src/testing/single_process.rs @@ -7,7 +7,7 @@ pub(super) async fn run(config: &super::TestConfig) -> anyhow::Result<(), super: let mut simulated_network = super::config_sim_network(config).await?; let join_handles = simulated_network - .start_with_rand_gen::( + .start_with_rand_gen::( config.seed(), config.max_contract_number.unwrap_or(config.nodes * 10), config.events as usize, diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 48dd5bdbc..7fbec5f87 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -11,6 +11,7 @@ - [P2P Network](architecture/p2p-network.md) - [Intelligent Routing](architecture/irouting.md) +- [Transport](architecture/transport.md) # Developer Guide diff --git 
a/docs/src/architecture/transport.md b/docs/src/architecture/transport.md
new file mode 100644
index 000000000..c958fbdf9
--- /dev/null
+++ b/docs/src/architecture/transport.md
@@ -0,0 +1,201 @@
+# Freenet Transport Protocol (FrTP)
+
+**Note**: This document is a work in progress and subject to change. It is currently out of sync
+with the codebase and should be updated to reflect the current state of the code once the
+implementation has stabilized.
+
+## Introduction
+
+The Freenet Transport Protocol (FrTP) is a UDP-based protocol designed to ensure reliable and
+encrypted message transmission. This document outlines the key elements of FrTP, including
+connection establishment, message handling, and rate limiting.
+
+## Overview
+
+- **Firewall Traversal**: FrTP allows peers behind firewalls to establish direct connections.
+- **Security**: All messages are encrypted using AES128GCM, with RSA public key exchange for
+  connection establishment, which should effectively thwart man-in-the-middle attacks.
+- **Streaming**: Large messages can be streamed, meaning that a peer can start forwarding data
+  before the entire message is received.
+- **Covert**: FrTP can run on any UDP port and FrTP packets look like random data, although more
+  sophisticated analysis of packet timing and size could be used to identify FrTP traffic. FrTP
+  can't be port-scanned as it won't respond to packets unless they are encrypted with the peer's
+  public key.
+- **Efficient**: FrTP is designed to minimize bandwidth usage, with rate limiting and confirmation
+  message batching.
+
+## Connection Establishment
+
+### Scenario 1: Both Peers Behind NAT
+
+This describes how to establish one side of a two-way connection, allowing Bob to send messages
+to Alice. The process is symmetric in the other direction.
+
+#### Actors
+
+- Alice and Bob are both peers behind firewalls.
+
+#### Terminology
+
+- `Bob_public_key`: Bob's RSA public key.
+- `Bob_private_key`: Bob's RSA private key.
+- `Alice_inbound_symmetric_key`: AES128GCM symmetric key generated by Alice, used for decrypting
+  inbound messages from Bob.
+- `hello_message(A->B)`: Message sent from Alice to Bob, containing `Alice_inbound_symmetric_key`
+  and a u16 protocol version number, encrypted using `Bob_public_key`.
+- `hello_ack(B->A)`: Message sent from Bob to Alice acknowledging `hello_message(A->B)`,
+  encrypted using `Alice_inbound_symmetric_key`.
+
+#### Steps
+
+1. **Key Generation**: Alice generates a random AES128GCM symmetric key,
+   called `Alice_inbound_symmetric_key`.
+
+2. **Outbound Hello Message**: Alice encrypts `Alice_inbound_symmetric_key` and a u16 protocol
+   version number with `Bob_public_key` to create `hello_message(A->B)`.
+
+3. **Sending Outbound Hello**: Alice repeatedly sends `hello_message(A->B)` every 200ms until
+   a `hello_ack(B->A)` from Bob is received or a 5-second timeout occurs, indicating connection
+   failure.
+
+4. **Receiving Inbound Hello**: Bob receives `hello_message(A->B)` and decrypts it using
+   `Bob_private_key`. If the protocol version is not supported, then Bob sends a `hello_ack(B->A)`
+   with an error code and terminates the connection; otherwise Bob sends a `hello_ack(B->A)` to
+   Alice, encrypted using `Alice_inbound_symmetric_key`.
+
+5. **Hello Acknowledgement**: Upon receiving `hello_ack(B->A)`, Alice stops sending
+   `hello_message(A->B)` and the inbound side of the connection is established.
+
+6. **Unexpected Hello Messages**: If Bob receives a `hello_message(A->B)` from Alice after
+   it has already sent a `hello_ack(B->A)`, then it should resend the `hello_ack(B->A)` and
+   otherwise ignore the message (this may occur if the initial `hello_ack(B->A)` is lost).
+
+### Scenario 2: Peer behind NAT connects to Gateway peer
+
+#### Actors
+
+- Alice is a peer behind a firewall; Gateway isn't behind a firewall and is configured to act as a
+  gateway peer through which new peers assimilate into the network.
+
+#### Terminology
+
+- `Gateway_public_key`: Gateway's RSA public key.
+- `Gateway_private_key`: Gateway's RSA private key.
+- `Alice_bidirectional_symmetric_key`: AES128GCM symmetric key generated by Alice, used for
+  encrypting and decrypting messages to/from Gateway.
+
+#### Steps
+
+1. **Key Generation**: Alice generates a random AES128GCM symmetric key,
+   called `Alice_bidirectional_symmetric_key`.
+
+2. **Outbound Hello Message**: Alice encrypts `Alice_bidirectional_symmetric_key` and a u16
+   protocol version number with `Gateway_public_key` to create `hello_message(A->G)`.
+
+3. **Sending Outbound Hello**: Alice repeatedly sends `hello_message(A->G)` every 200ms until
+   a `hello_ack(G->A)` from Gateway is received or a 5-second timeout occurs, indicating connection
+   failure.
+
+4. **Receiving Inbound Hello**: Gateway receives `hello_message(A->G)` and decrypts it using
+   `Gateway_private_key`. If the protocol version is not supported, then Gateway sends a
+   `hello_ack(G->A)` with an error code and terminates the connection; otherwise it sends a
+   `hello_ack(G->A)` to Alice.
+
+5. **Hello Acknowledgement**: Upon receiving `hello_ack(G->A)`, Alice stops sending
+   `hello_message(A->G)` and the connection is established. Alice should use
+   `Alice_bidirectional_symmetric_key` for both encryption and decryption of packets sent to and
+   received from Gateway.
+
+## Keep-Alive Protocol
+
+To maintain an open connection, `keep_alive` messages are exchanged every 30 seconds. A connection
+is terminated if a peer fails to receive any message within 120 seconds.
+
+## Symmetric Message Schema
+
+```rust
+pub(super) struct SymmetricMessage {
+    pub packet_id: PacketId,
+    pub confirm_receipt: Vec<PacketId>,
+    pub payload: SymmetricMessagePayload,
+}
+
+pub(super) enum SymmetricMessagePayload {
+    AckConnection {
+        // if we successfully connected to a remote we attempted to connect to initially
+        // then we return our TransportPublicKey so they can enroute other peers to us
+        result: Result<(), Cow<'static, str>>,
+    },
+    GatewayConnection {
+        // a gateway acknowledges a connection and returns the private key to use
+        // for communication
+        key: [u8; 16],
+    },
+    ShortMessage {
+        payload: MessagePayload,
+    },
+    StreamFragment {
+        stream_id: StreamId,
+        total_length_bytes: u64, // we shouldn't allow messages larger than u32, that's already crazy big
+        fragment_number: u32,
+        payload: MessagePayload,
+    },
+    NoOp,
+}
+
+pub enum HelloError {
+    UnsupportedProtocolVersion {
+        min_supported: u16,
+        max_supported: u16,
+        your_version: u16,
+    },
+}
+```
+
+## Message Handling
+
+### Dropped and Out-of-Order Messages
+
+- **Duplicate Detection**: Messages are checked for a duplicate `packet_id`. Duplicates trigger
+  an immediate `NoOp` message with a reconfirmation in `confirm_receipt`, as sketched below.
+- **Acknowledgement Timeout**: Messages are resent if not acknowledged within 2 seconds
+  (`MESSAGE_CONFIRMATION_TIMEOUT`).
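+
+To make the receipt handling above concrete, here is a deliberately simplified, non-normative
+sketch of receiver-side duplicate detection and receipt batching. The names (`ReceiptBatcher`,
+`take_receipts`) are illustrative only; the actual implementation lives in
+`crates/core/src/transport/received_packet_tracker.rs` and additionally records receipt
+timestamps and expires old entries.
+
+```rust
+use std::collections::HashSet;
+
+/// Maximum receipts held back before forcing an immediate confirmation.
+const MAX_PENDING_RECEIPTS: usize = 20;
+
+#[derive(Debug, PartialEq)]
+enum ReportResult {
+    /// First time this packet id was seen; a receipt has been queued.
+    Ok,
+    /// Duplicate packet; re-acknowledge it but otherwise ignore it.
+    AlreadyReceived,
+    /// The receipt queue is full and should be flushed immediately.
+    QueueFull,
+}
+
+#[derive(Default)]
+struct ReceiptBatcher {
+    seen: HashSet<u32>,
+    pending_receipts: Vec<u32>,
+}
+
+impl ReceiptBatcher {
+    /// Record an inbound packet id, detecting duplicates and queueing a receipt.
+    fn report_received_packet(&mut self, packet_id: u32) -> ReportResult {
+        if !self.seen.insert(packet_id) {
+            return ReportResult::AlreadyReceived;
+        }
+        self.pending_receipts.push(packet_id);
+        if self.pending_receipts.len() >= MAX_PENDING_RECEIPTS {
+            ReportResult::QueueFull
+        } else {
+            ReportResult::Ok
+        }
+    }
+
+    /// Drain the receipts accumulated since the last call; the caller piggybacks
+    /// them on the next outbound packet, or sends them in a NoOp if the
+    /// confirmation delay elapses first.
+    fn take_receipts(&mut self) -> Vec<u32> {
+        std::mem::take(&mut self.pending_receipts)
+    }
+}
+
+fn main() {
+    let mut batcher = ReceiptBatcher::default();
+    assert_eq!(batcher.report_received_packet(7), ReportResult::Ok);
+    assert_eq!(batcher.report_received_packet(7), ReportResult::AlreadyReceived);
+    assert_eq!(batcher.take_receipts(), vec![7]);
+    assert!(batcher.take_receipts().is_empty());
+}
+```
+
+In the real tracker, `QueueFull` is the signal to flush receipts immediately rather than waiting
+for the next outbound packet or the confirmation timer.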
+
+### Confirmation Batching
+
+- **Batching Strategy**: Receipts can be delayed up to 500ms (`MAX_CONFIRMATION_DELAY`) to enable
+  batch confirmation.
+- **Queue Management**: Receipt queues exceeding 20 messages prompt immediate confirmation to
+  prevent overflow.
+
+## Message Types
+
+- **Short Messages**: Contained within a single UDP packet (up to 1 KB).
+- **Long Messages**: Split into fragments for larger payloads, enabling efficient data forwarding.
+
+## Rate Limiting
+
+- **Initial Setup**: Upstream bandwidth is set 50% above the desired usage to allow for traffic
+  bursts.
+- **Dynamic Adjustment**: Future adaptations may use isotonic regression to optimize the balance
+  between bandwidth and packet loss.
+- **Implementation**: Bandwidth is monitored over 10-second windows (`BANDWIDTH_MEASUREMENT_WINDOW`).
+  Exceeding the limit triggers a 10ms sleep (`BANDWIDTH_CONTROL_SLEEP_DURATION`), with periodic
+  reassessment. A simplified sketch of this window-based accounting appears at the end of this
+  document.
+
+## Implementation Notes
+
+### Serialization
+
+- Try to avoid unnecessary copies of data, especially for large messages.
+- Ensure the serialization format is robust against untrusted data.
+- Note that there will be nested layers of serialization, both internal to FrTP and
+  by the FrTP user.
+
+Consider:
+
+- [FlatBuffers](https://crates.io/crates/flatbuffers/)
+- [Serde Bytes](https://docs.rs/serde_bytes/latest/serde_bytes/)
+- [BinCode](https://github.com/bincode-org/bincode)
+
+## Conclusion
+
+The Freenet Transport Protocol provides a robust framework for secure and efficient data
+transmission. Its design considers NAT challenges, message integrity, and bandwidth management,
+ensuring reliable communication in various network conditions.
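+
+For illustration, here is the simplified, non-normative sketch of the sliding-window bandwidth
+accounting referenced under Rate Limiting above. The names (`BandwidthWindow`,
+`delay_before_send`) are illustrative only; the actual implementation lives in
+`crates/core/src/transport/rate_limiter.rs` and uses a pluggable time source plus an async loop
+that drains queued packets onto the socket.
+
+```rust
+use std::collections::VecDeque;
+use std::time::{Duration, Instant};
+
+/// Sliding-window bandwidth accounting: remember (size, send time) pairs inside
+/// the window and delay a packet whenever sending it now would exceed the limit.
+struct BandwidthWindow {
+    window: Duration,
+    packets: VecDeque<(usize, Instant)>,
+    bytes_in_window: usize,
+}
+
+impl BandwidthWindow {
+    fn new(window: Duration) -> Self {
+        Self { window, packets: VecDeque::new(), bytes_in_window: 0 }
+    }
+
+    /// Drop entries older than the window.
+    fn cleanup(&mut self, now: Instant) {
+        while let Some(&(size, t)) = self.packets.front() {
+            if now.duration_since(t) > self.window {
+                self.packets.pop_front();
+                self.bytes_in_window -= size;
+            } else {
+                break;
+            }
+        }
+    }
+
+    /// Record a packet that was actually sent.
+    fn record_sent(&mut self, size: usize, now: Instant) {
+        self.packets.push_back((size, now));
+        self.bytes_in_window += size;
+        self.cleanup(now);
+    }
+
+    /// Returns `None` if the packet fits under `limit` right now, otherwise how
+    /// long to wait until enough old packets have aged out of the window.
+    fn delay_before_send(&mut self, limit: usize, size: usize, now: Instant) -> Option<Duration> {
+        self.cleanup(now);
+        if self.bytes_in_window + size <= limit {
+            return None;
+        }
+        let mut remaining = self.bytes_in_window;
+        for &(pkt_size, t) in &self.packets {
+            remaining -= pkt_size;
+            if remaining + size <= limit {
+                // Wait until this packet falls out of the measurement window.
+                return Some(self.window.saturating_sub(now.duration_since(t)));
+            }
+        }
+        // Defensive fallback: the packet alone exceeds the limit.
+        Some(self.window)
+    }
+}
+
+fn main() {
+    let mut bw = BandwidthWindow::new(Duration::from_secs(10));
+    let now = Instant::now();
+    bw.record_sent(9_000, now);
+    // Under a 10_000-byte limit, a 2_000-byte packet must wait for the
+    // 9_000-byte entry to leave the window; under 12_000 it can go immediately.
+    assert!(bw.delay_before_send(10_000, 2_000, now).is_some());
+    assert!(bw.delay_before_send(12_000, 2_000, now).is_none());
+}
+```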