diff --git a/.gitignore b/.gitignore index 23d5c7e9..89438de8 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,5 @@ tools/target /src/resources/public/* !/src/resources/public/content + +*.exe \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index e37cf9cf..cd4d4e3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,7 +108,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -119,7 +119,7 @@ checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -203,9 +203,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.3" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "base64ct" @@ -259,9 +259,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "byteorder" @@ -271,9 +271,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" @@ -292,16 +292,15 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.29" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d87d9d13be47a5b7c3907137f1290b0459a7f80efb26be8c52afb11963bccb02" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "serde", - "time", "wasm-bindgen", "windows-targets", ] @@ -401,7 +400,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -412,7 +411,7 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -514,7 +513,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf0fab0b584e67341bbfedce7c8d59d9cebaa9088fa494338ed4f8be92130bd3" dependencies = [ "quote", - "syn 2.0.31", + "syn 2.0.36", "walkdir", ] @@ -585,9 +584,9 @@ checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" [[package]] name = "flagset" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cda653ca797810c02f7ca4b804b40b8b95ae046eb989d356bce17919a8c25499" +checksum = "d52a7e408202050813e6f1d9addadcaafef3dca7530c7ddfb005d4081cce6779" [[package]] name = "flate2" @@ -635,6 +634,7 @@ checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -693,7 +693,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -744,7 +744,7 @@ checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", ] [[package]] @@ -976,6 +976,12 @@ dependencies = [ "hashbrown 0.14.0", ] +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + [[package]] name = "inherent" version = "1.0.10" @@ -984,30 +990,7 @@ checksum = "ce243b1bfa62ffc028f1cc3b6034ec63d649f3031bc8a4fbbb004e1ac17d1f68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", -] - -[[package]] -name = "interlink" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6a90f4b1a13663a9bfb6c75a66c4379870218d5b667910eb47df5d136bc67a" -dependencies = [ - "futures-core", - "futures-sink", - "interlink-derive", - "tokio", -] - -[[package]] -name = "interlink-derive" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f00e8918e7eb15347ee01e4423d19be3ec30b4d9fec32b76cdbed1adff8ee73" -dependencies = [ - "darling", - "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -1051,9 +1034,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.147" +version = "0.2.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" [[package]] name = "libm" @@ -1085,15 +1068,15 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.5" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" [[package]] name = "local-ip-address" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "885efb07efcd6ae1c6af70be7565544121424fa9e5b1c3e4b58bbbf141a58cef" +checksum = "3fefe707432eb6bd4704b3dacfc87aab269d56667ad05dcd6869534e8890e767" dependencies = [ "libc", "neli", @@ -1214,7 +1197,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "windows-sys", ] @@ -1356,7 +1339,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -1440,7 +1423,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -1506,7 +1489,7 @@ checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" [[package]] name = "pocket-relay" -version = "0.5.9" +version = "0.5.10" dependencies = [ "argon2", "axum", @@ -1518,9 +1501,10 @@ dependencies = [ "email_address", "embeddy", "flate2", + "futures", "futures-util", "hyper", - "interlink", + "indoc", "local-ip-address", "log", "log-panics", @@ -1570,9 +1554,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" dependencies = [ "unicode-ident", ] @@ -1782,9 +1766,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "rustix" -version = "0.38.11" +version = "0.38.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0c3dde1fc030af041adc40e79c0e7fbcf431dd24870053d187d7c66e4b87453" +checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" dependencies = [ "bitflags 2.4.0", "errno", @@ -1816,9 +1800,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.101.4" +version = "0.101.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" +checksum = "45a27e3b59326c16e23d30aeb7a36a24cc0d29e71d68ff611cdfb4a01d013bed" dependencies = [ "ring", "untrusted", @@ -1871,7 +1855,7 @@ dependencies = [ "proc-macro-error", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -1907,7 +1891,7 @@ dependencies = [ "proc-macro2", "quote", "sea-bae", - "syn 2.0.31", + "syn 2.0.36", "unicode-ident", ] @@ -1964,9 +1948,9 @@ dependencies = [ [[package]] name = "sea-schema" -version = "0.14.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3e09eb40c78cee8fef8dfbb648036a26b7ad1f618499203ad0e8b6f97593f7f" +checksum = "0cd9561232bd1b82ea748b581f15909d11de0db6563ddcf28c5d908aee8282f1" dependencies = [ "futures", "sea-query", @@ -2002,14 +1986,14 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "itoa", "ryu", @@ -2115,9 +2099,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" dependencies = [ "libc", "windows-sys", @@ -2419,9 +2403,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.31" +version = "2.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "718fa2415bcb8d8bd775917a1bf12a7931b6dfa890753378538118181e0cb398" +checksum = "91e02e55d62894af2a08aca894c6577281f76769ba47c94d5756bec8ac6e7373" dependencies = [ "proc-macro2", "quote", @@ -2436,9 +2420,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "tdf" -version = "0.1.0" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a033ec7f181af10c095d38db6ed55e5ecd31e06cd5a3842c4884ff7ce484f44" +checksum = "140a16f6272972ab2718a1152f5844faa7ab23c31771e6bfe8271b1e59dc7884" dependencies = [ "serde", "tdf-derive", @@ -2446,14 +2430,14 @@ dependencies = [ [[package]] name = "tdf-derive" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f312e2b4ca289722dbdac182878247ee4c0111ca468871eaeaf3c208282ab04c" +checksum = "cd55cd8859658cb77bea257d1964502f12068ba6f80c88eee00acb64122bab73" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -2486,7 +2470,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -2510,17 +2494,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "time" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" -dependencies = [ - "libc", - "wasi 0.10.0+wasi-snapshot-preview1", - "winapi", -] - [[package]] name = "tinyvec" version = "1.6.0" @@ -2550,7 +2523,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.3", + "socket2 0.5.4", "tokio-macros", "windows-sys", ] @@ -2563,7 +2536,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -2650,7 +2623,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", ] [[package]] @@ -2685,9 +2658,9 @@ checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-bidi" @@ -2697,9 +2670,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -2770,12 +2743,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2803,7 +2770,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", "wasm-bindgen-shared", ] @@ -2837,7 +2804,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.36", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 6a9085ac..ca1d28f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pocket-relay" -version = "0.5.9" +version = "0.5.10" description = "Pocket Relay Server" readme = "README.md" keywords = ["EA", "PocketRelay", "MassEffect"] @@ -43,9 +43,6 @@ thiserror = "1" # Email validation email_address = "0.2.4" -# Async services framework -interlink = "0.1" - # Codec utils for encoding and decoding packets tokio-util = { version = "0.7", features = ["codec"] } @@ -53,10 +50,12 @@ tokio-util = { version = "0.7", features = ["codec"] } hyper = "0.14.25" tower = "0.4" -bitflags = "2.3.1" +bitflags = { version = "2.3.1", features = ["serde"] } tdf = { version = "0.1" } bytes = "1.4.0" +futures = "0.3" +indoc = "2" # SeaORM [dependencies.sea-orm] diff --git a/Dockerfile b/Dockerfile index 34ae2851..654c80a0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,7 +8,7 @@ RUN apk add curl WORKDIR /app # Download server executable -RUN curl -LJ -o pocket-relay-linux https://github.com/PocketRelay/Server/releases/download/v0.5.9/pocket-relay-linux +RUN curl -LJ -o pocket-relay-linux https://github.com/PocketRelay/Server/releases/download/v0.5.10-beta/pocket-relay-linux?v=3 # Make the server executable RUN chmod +x ./pocket-relay-linux diff --git a/src/config.rs b/src/config.rs index c14bd658..b6369e1d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,12 +1,15 @@ -use crate::utils::models::Port; use log::LevelFilter; use serde::Deserialize; use std::{env, fs::read_to_string, path::Path}; +use crate::session::models::Port; + /// The server version extracted from the Cargo.toml pub const VERSION: &str = env!("CARGO_PKG_VERSION"); +#[derive(Default)] pub struct RuntimeConfig { + pub qos: QosServerConfig, pub reverse_proxy: bool, pub galaxy_at_war: GalaxyAtWarConfig, pub menu_message: String, @@ -58,6 +61,7 @@ pub fn load_config() -> Option { #[serde(default)] pub struct Config { pub port: Port, + pub qos: QosServerConfig, pub reverse_proxy: bool, pub dashboard: DashboardConfig, pub menu_message: String, @@ -70,6 +74,7 @@ impl Default for Config { fn default() -> Self { Self { port: 80, + qos: QosServerConfig::default(), reverse_proxy: false, dashboard: Default::default(), menu_message: "Pocket Relay - Logged as: {n}".to_string(), @@ -80,6 +85,18 @@ impl Default for Config { } } +#[derive(Debug, Default, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum QosServerConfig { + /// Use the official QoS server + Official, + /// Use the local QoS server (might cause issues) + #[default] + Local, + /// Use a custom QoS server + Custom { host: String, port: u16 }, +} + #[derive(Deserialize)] #[serde(default)] pub struct GalaxyAtWarConfig { @@ -97,6 +114,7 @@ impl Default for GalaxyAtWarConfig { } #[derive(Default, Deserialize)] +#[serde(default)] pub struct DashboardConfig { pub super_email: Option, pub super_password: Option, diff --git a/src/database/entities/galaxy_at_war.rs b/src/database/entities/galaxy_at_war.rs index 3e3d9435..fef9e87b 100644 --- a/src/database/entities/galaxy_at_war.rs +++ b/src/database/entities/galaxy_at_war.rs @@ -7,10 +7,10 @@ use sea_orm::prelude::*; use sea_orm::{ ActiveModelTrait, ActiveValue::{NotSet, Set}, - DatabaseConnection, IntoActiveModel, + DatabaseConnection, }; use serde::Serialize; -use std::cmp; +use std::future::Future; /// Structure for a galaxy at war model stored in the database #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize)] @@ -49,29 +49,18 @@ impl Model { /// The maximum value for galaxy at war entries const MAX_VALUE: u16 = 10099; - /// Finds or creates a new galaxy at war entry for the provided - /// player. If one exists then the provided decay value will be - /// applied to it. - /// - /// `db` The database connection - /// `player` The player to search for galaxy at war models for - /// `decay` The decay value - pub async fn find_or_create( - db: &DatabaseConnection, - player_id: PlayerID, - decay: f32, - ) -> DbResult { + pub async fn get(db: &DatabaseConnection, player_id: PlayerID) -> DbResult { let existing = Entity::find() .filter(Column::PlayerId.eq(player_id)) .one(db) .await?; if let Some(value) = existing { - return value.apply_decay(db, decay).await; + return Ok(value); } let current_time = Local::now().naive_local(); - let model = ActiveModel { + ActiveModel { id: NotSet, player_id: Set(player_id), last_modified: Set(current_time), @@ -80,45 +69,62 @@ impl Model { group_c: Set(Self::MIN_VALUE), group_d: Set(Self::MIN_VALUE), group_e: Set(Self::MIN_VALUE), - }; + } + .insert(db) + .await + } - model.insert(db).await + /// Increases the stored group values increasing them by the `values` + /// provided for each respective group + pub fn add( + self, + db: &DatabaseConnection, + values: [u16; 5], + ) -> impl Future> + '_ { + self.transform(db, |a, b| a.saturating_add(b).min(Model::MAX_VALUE), values) } - /// Increases the group values stored on the provided - /// galaxy at war models by the values provided. - /// - /// `db` The database connection - /// `value` The galaxy at war model to increase - /// `values` The values to increase each group by - pub async fn increase( + /// Decrease the stored group values decreasuing them by the `values` + /// provided for each respective group + pub fn sub( self, db: &DatabaseConnection, - values: (u16, u16, u16, u16, u16), - ) -> DbResult { - let new_a = self.group_a + values.0; - let new_b = self.group_b + values.1; - let new_c = self.group_c + values.2; - let new_d = self.group_d + values.3; - let new_e = self.group_e + values.4; + values: [u16; 5], + ) -> impl Future> + '_ { + self.transform(db, |a, b| a.saturating_sub(b).max(Model::MIN_VALUE), values) + } - let mut gaw_data = self.into_active_model(); - gaw_data.group_a = Set(cmp::min(new_a, Self::MAX_VALUE)); - gaw_data.group_b = Set(cmp::min(new_b, Self::MAX_VALUE)); - gaw_data.group_c = Set(cmp::min(new_c, Self::MAX_VALUE)); - gaw_data.group_d = Set(cmp::min(new_d, Self::MAX_VALUE)); - gaw_data.group_e = Set(cmp::min(new_e, Self::MAX_VALUE)); - gaw_data.update(db).await + /// Transforms the underlying group values using the provided action + /// function which is given the current value as the first argument + /// and the respective value from `values` as the second argument + #[inline] + pub async fn transform( + self, + db: &DatabaseConnection, + action: F, + values: [u16; 5], + ) -> DbResult + where + F: Fn(u16, u16) -> u16, + { + let current_time = Local::now().naive_local(); + ActiveModel { + id: Set(self.id), + player_id: Set(self.player_id), + last_modified: Set(current_time), + group_a: Set(action(self.group_a, values[0])), + group_b: Set(action(self.group_b, values[1])), + group_c: Set(action(self.group_c, values[2])), + group_d: Set(action(self.group_d, values[3])), + group_e: Set(action(self.group_e, values[4])), + } + .update(db) + .await } - /// Applies the provided galaxy at war decay value to the provided - /// galaxy at war model decreasing the values by the number of days - /// that have passed. - /// - /// `db` The database connection - /// `value` The galaxy at war model to decay - /// `decay` The decay value - async fn apply_decay(self, db: &DatabaseConnection, decay: f32) -> DbResult { + /// Applies the daily decay progress to the group values calculating the + /// decay amount from the number of days passed + pub async fn apply_decay(self, db: &DatabaseConnection, decay: f32) -> DbResult { // Skip decaying if decay is non existent if decay <= 0.0 { return Ok(self); @@ -128,21 +134,6 @@ impl Model { let days_passed = (current_time - self.last_modified).num_days() as f32; let decay_value = (decay * days_passed * 100.0) as u16; - // Apply decay while keeping minimum - let a = cmp::max(self.group_a - decay_value, Self::MIN_VALUE); - let b = cmp::max(self.group_b - decay_value, Self::MIN_VALUE); - let c = cmp::max(self.group_c - decay_value, Self::MIN_VALUE); - let d = cmp::max(self.group_d - decay_value, Self::MIN_VALUE); - let e = cmp::max(self.group_e - decay_value, Self::MIN_VALUE); - - // Update stored copy - let mut value = self.into_active_model(); - value.group_a = Set(a); - value.group_b = Set(b); - value.group_c = Set(c); - value.group_d = Set(d); - value.group_e = Set(e); - - value.update(db).await + self.sub(db, [decay_value; 5]).await } } diff --git a/src/database/entities/player_data.rs b/src/database/entities/player_data.rs index 3e47a77b..097f7aca 100644 --- a/src/database/entities/player_data.rs +++ b/src/database/entities/player_data.rs @@ -1,11 +1,13 @@ +use crate::{database::DbResult, utils::types::PlayerID}; use sea_orm::{ - entity::prelude::*, ActiveValue::NotSet, DeleteResult, InsertResult, IntoActiveModel, Set, + entity::prelude::*, + sea_query::OnConflict, + ActiveValue::{NotSet, Set}, + DeleteResult, InsertResult, }; use serde::Serialize; use std::future::Future; -use crate::{database::DbResult, utils::types::PlayerID}; - /// Structure for player data stro #[derive(Serialize, Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "player_data")] @@ -51,36 +53,25 @@ impl Model { /// `db` The database connection /// `key` The data key /// `value` The data value - pub async fn set( + pub fn set( db: &DatabaseConnection, player_id: PlayerID, key: String, value: String, - ) -> DbResult { - let existing = Entity::find() - .filter( - Column::PlayerId - .eq(player_id) - .and(Column::Key.eq(&key as &str)), - ) - .one(db) - .await?; - - if let Some(player_data) = existing { - let mut model = player_data.into_active_model(); - model.key = Set(key); - model.value = Set(value); - model.update(db).await - } else { - ActiveModel { - player_id: Set(player_id), - key: Set(key), - value: Set(value), - ..Default::default() - } - .insert(db) - .await - } + ) -> impl Future>> + Send + '_ { + Entity::insert(ActiveModel { + id: NotSet, + player_id: Set(player_id), + key: Set(key), + value: Set(value), + }) + .on_conflict( + // Update the valume column if a key already exists + OnConflict::columns([Column::PlayerId, Column::Key]) + .update_column(Column::Value) + .to_owned(), + ) + .exec(db) } /// Bulk inserts a collection of player data for the provided player. Will not handle @@ -95,15 +86,23 @@ impl Model { player_id: PlayerID, data: impl Iterator, ) -> impl Future>> + Send + '_ { - // Transform the provided key values into active models - let models_iter = data.map(|(key, value)| ActiveModel { - id: NotSet, - player_id: Set(player_id), - key: Set(key), - value: Set(value), - }); // Insert all the models - Entity::insert_many(models_iter).exec(db) + Entity::insert_many( + // Transform the key value pairs into insertable models + data.map(|(key, value)| ActiveModel { + id: NotSet, + player_id: Set(player_id), + key: Set(key), + value: Set(value), + }), + ) + .on_conflict( + // Update the valume column if a key already exists + OnConflict::columns([Column::PlayerId, Column::Key]) + .update_column(Column::Value) + .to_owned(), + ) + .exec(db) } /// Deletes the player data with the provided key for the @@ -154,16 +153,4 @@ impl Model { ) .all(db) } - - /// Parses the challenge points value which is the second - /// item in the completion list. - /// - /// `db` The database connection - /// `player_id` The ID of the player to get the cp for - pub async fn get_challenge_points(db: &DatabaseConnection, player_id: PlayerID) -> Option { - let list = Self::get(db, player_id, "Completion").await.ok()??.value; - let part = list.split(',').nth(1)?; - let value: u32 = part.parse().ok()?; - Some(value) - } } diff --git a/src/database/entities/players.rs b/src/database/entities/players.rs index d32e969e..a96834ec 100644 --- a/src/database/entities/players.rs +++ b/src/database/entities/players.rs @@ -3,13 +3,14 @@ use crate::config::RuntimeConfig; use crate::database::DbResult; use crate::utils::hashing::hash_password; +use futures_util::future::BoxFuture; use sea_orm::prelude::*; use sea_orm::{ ActiveModelTrait, ActiveValue::Set, ColumnTrait, DatabaseConnection, DeleteResult, EntityTrait, IntoActiveModel, QueryFilter, }; use serde::{Deserialize, Serialize}; -use std::{future::Future, iter::Iterator, pin::Pin}; +use std::future::Future; #[derive(Serialize, Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "players")] @@ -57,8 +58,6 @@ pub enum PlayerRole { SuperAdmin = 2, } -type DbFuture<'a, T> = Pin> + Send + 'a>>; - impl Model { /// Creates a new player with the proivded details and inserts /// it into the database @@ -74,7 +73,7 @@ impl Model { display_name: String, mut password: Option, config: &RuntimeConfig, - ) -> DbFuture<'db, Self> { + ) -> BoxFuture<'db, DbResult> { let mut role = PlayerRole::Default; if config @@ -109,7 +108,7 @@ impl Model { /// Deletes the provided player /// /// `db` The database connection - pub fn delete(self, db: &DatabaseConnection) -> DbFuture { + pub fn delete(self, db: &DatabaseConnection) -> BoxFuture> { // Delete player itself let model = self.into_active_model(); model.delete(db) @@ -163,7 +162,11 @@ impl Model { /// /// `db` The database connection /// `password` The new hashed password - pub fn set_password(self, db: &DatabaseConnection, password: String) -> DbFuture<'_, Self> { + pub fn set_password( + self, + db: &DatabaseConnection, + password: String, + ) -> BoxFuture<'_, DbResult> { let mut model = self.into_active_model(); model.password = Set(Some(password)); model.update(db) @@ -173,7 +176,11 @@ impl Model { /// /// `db` The database connection /// `role` The new role for the player - pub fn set_role(self, db: &DatabaseConnection, role: PlayerRole) -> DbFuture<'_, Self> { + pub fn set_role( + self, + db: &DatabaseConnection, + role: PlayerRole, + ) -> BoxFuture<'_, DbResult> { let mut model = self.into_active_model(); model.role = Set(role); model.update(db) @@ -190,7 +197,7 @@ impl Model { db: &DatabaseConnection, username: Option, email: Option, - ) -> DbFuture<'_, Self> { + ) -> BoxFuture<'_, DbResult> { let mut model = self.into_active_model(); if let Some(username) = username { diff --git a/src/database/migration/m20230913_185124_player_data_unique.rs b/src/database/migration/m20230913_185124_player_data_unique.rs new file mode 100644 index 00000000..61adf83f --- /dev/null +++ b/src/database/migration/m20230913_185124_player_data_unique.rs @@ -0,0 +1,39 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_index( + Index::create() + .unique() + .name("idx-pid-key") + .table(PlayerData::Table) + .col(PlayerData::PlayerId) + .col(PlayerData::Key) + .to_owned(), + ) + .await + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_index( + Index::drop() + .table(PlayerData::Table) + .name("idx-pid-key") + .to_owned(), + ) + .await + } +} + +#[derive(Iden)] +enum PlayerData { + Table, + PlayerId, + Key, +} diff --git a/src/database/migration/mod.rs b/src/database/migration/mod.rs index 3b448e38..bc0939a6 100644 --- a/src/database/migration/mod.rs +++ b/src/database/migration/mod.rs @@ -3,6 +3,8 @@ pub use sea_orm_migration::prelude::*; mod m20221015_142649_players_table; mod m20221015_153750_galaxy_at_war_table; mod m20221222_174733_player_data; +mod m20230913_185124_player_data_unique; + pub struct Migrator; #[async_trait::async_trait] @@ -12,6 +14,7 @@ impl MigratorTrait for Migrator { Box::new(m20221015_142649_players_table::Migration), Box::new(m20221015_153750_galaxy_at_war_table::Migration), Box::new(m20221222_174733_player_data::Migration), + Box::new(m20230913_185124_player_data_unique::Migration), ] } } diff --git a/src/main.rs b/src/main.rs index 373f6405..aeab5222 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,15 +4,16 @@ use crate::{ game::manager::GameManager, leaderboard::Leaderboard, retriever::Retriever, sessions::Sessions, }, + utils::signing::SigningKey, }; use axum::{Extension, Server}; use config::load_config; -use log::{error, info, LevelFilter}; +use log::{debug, error, info, LevelFilter}; use std::{ net::{Ipv4Addr, SocketAddr}, sync::Arc, }; -use tokio::{join, select, signal}; +use tokio::{join, signal}; use utils::logging; mod config; @@ -44,26 +45,32 @@ async fn main() { galaxy_at_war: config.galaxy_at_war, menu_message: config.menu_message, dashboard: config.dashboard, + qos: config.qos, }; + debug!("QoS server: {:?}", &runtime_config.qos); + // This step may take longer than expected so its spawned instead of joined tokio::spawn(logging::log_connection_urls(config.port)); - let (db, retriever, sessions) = join!( + let (db, retriever, signing_key) = join!( database::init(&runtime_config), Retriever::start(config.retriever), - Sessions::start() + SigningKey::global() ); - let game_manager = GameManager::start(); - let leaderboard = Leaderboard::start(); + + let game_manager = Arc::new(GameManager::new()); + let leaderboard = Arc::new(Leaderboard::new()); + let sessions = Arc::new(Sessions::new(signing_key)); let config = Arc::new(runtime_config); + let retriever = Arc::new(retriever); // Initialize session router let mut router = session::routes::router(); router.add_extension(db.clone()); router.add_extension(config.clone()); - router.add_extension(retriever.clone()); + router.add_extension(retriever); router.add_extension(game_manager.clone()); router.add_extension(leaderboard.clone()); router.add_extension(sessions.clone()); @@ -81,21 +88,15 @@ async fn main() { .layer(Extension(sessions)) .into_make_service_with_connect_info::(); - // Create futures for server and shutdown signal - let server_future = Server::bind(&addr).serve(router); - let close_future = signal::ctrl_c(); - - info!("Started server on {} (v{})", addr, VERSION); - - // Await server termination or shutdown signal - select! { - result = server_future => { - if let Err(err) = result { - error!("Failed to bind HTTP server on {}: {:?}", addr, err); - panic!(); - } - } - // Handle the server being stopped with CTRL+C - _ = close_future => {} + info!("Starting server on {} (v{})", addr, VERSION); + + if let Err(err) = Server::bind(&addr) + .serve(router) + .with_graceful_shutdown(async move { + _ = signal::ctrl_c().await; + }) + .await + { + error!("Failed to bind HTTP server on {}: {:?}", addr, err); } } diff --git a/src/middleware/auth.rs b/src/middleware/auth.rs index 0666c5bb..c221ba07 100644 --- a/src/middleware/auth.rs +++ b/src/middleware/auth.rs @@ -3,8 +3,7 @@ use crate::{ entities::{players::PlayerRole, Player}, DbErr, }, - services::sessions::{Sessions, VerifyError, VerifyTokenMessage}, - utils::types::BoxFuture, + services::sessions::{Sessions, VerifyError}, }; use axum::{ body::boxed, @@ -12,8 +11,9 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; -use interlink::prelude::{Link, LinkError}; +use futures_util::future::BoxFuture; use sea_orm::DatabaseConnection; +use std::sync::Arc; use thiserror::Error; pub struct Auth(pub Player); @@ -64,26 +64,24 @@ impl FromRequestParts for Auth { .clone(); let sessions = parts .extensions - .get::>() - .expect("Database connection extension missing") - .clone(); - - Box::pin(async move { - // Extract the token from the headers - let token = parts - .headers - .get(TOKEN_HEADER) - .and_then(|value| value.to_str().ok()) - .ok_or(TokenError::MissingToken)?; + .get::>() + .expect("Database connection extension missing"); - let player_id = sessions - .send(VerifyTokenMessage(token.to_string())) - .await - .map_err(TokenError::SessionService)? - .map_err(|err| match err { + // Extract the token from the headers and verify it as a player id + let player_id = parts + .headers + .get(TOKEN_HEADER) + .and_then(|value| value.to_str().ok()) + .ok_or(TokenError::MissingToken) + .and_then(|token| { + sessions.verify_token(token).map_err(|err| match err { VerifyError::Expired => TokenError::ExpiredToken, VerifyError::Invalid => TokenError::InvalidToken, - })?; + }) + }); + + Box::pin(async move { + let player_id = player_id?; let player = Player::by_id(&db, player_id) .await? @@ -113,9 +111,6 @@ pub enum TokenError { /// Database error #[error("Internal server error")] Database(#[from] DbErr), - /// Session service error - #[error("Session service unavailable")] - SessionService(LinkError), } /// IntoResponse implementation for TokenError to allow it to be @@ -127,7 +122,7 @@ impl IntoResponse for TokenError { Self::MissingToken => StatusCode::BAD_REQUEST, Self::InvalidToken | Self::ExpiredToken => StatusCode::UNAUTHORIZED, Self::MissingRole => StatusCode::FORBIDDEN, - Self::Database(_) | Self::SessionService(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::Database(_) => StatusCode::INTERNAL_SERVER_ERROR, }; (status, boxed(self.to_string())).into_response() diff --git a/src/middleware/blaze_upgrade.rs b/src/middleware/blaze_upgrade.rs deleted file mode 100644 index c0320ae4..00000000 --- a/src/middleware/blaze_upgrade.rs +++ /dev/null @@ -1,210 +0,0 @@ -use crate::{ - session::SessionHostTarget, - utils::{models::Port, types::BoxFuture}, -}; -use axum::{ - extract::FromRequestParts, - http::{HeaderValue, Method, StatusCode}, - response::IntoResponse, -}; -use hyper::{ - upgrade::{OnUpgrade, Upgraded}, - HeaderMap, -}; -use std::future::ready; -use thiserror::Error; - -/// Errors that could occur while upgrading -#[derive(Debug, Error)] -pub enum BlazeUpgradeError { - #[error("Cannot upgrade not GET requests")] - UnacceptableMethod, - #[error("Failed to upgrade connection")] - FailedUpgrade, - #[error("Cannot upgrade connection")] - CannotUpgrade, -} - -/// Extractor for initiated the upgrade process for a request -pub struct BlazeUpgrade { - /// The upgrade handle - on_upgrade: OnUpgrade, - host_target: SessionHostTarget, -} - -/// HTTP request upgraded into a Blaze socket along with -/// extra information -pub struct BlazeSocket { - /// The upgraded connection - pub upgrade: Upgraded, - /// The client side target for this host - pub host_target: SessionHostTarget, -} - -#[derive(Default, Clone, Copy)] -pub enum BlazeScheme { - /// HTTP Scheme (http://) - #[default] - Http, - /// HTTPS Scheme (https://) - Https, -} - -impl BlazeScheme { - /// Provides the default port used by the scheme - fn default_port(&self) -> u16 { - match self { - BlazeScheme::Http => 80, - BlazeScheme::Https => 443, - } - } - - /// Returns the scheme value - pub fn value(&self) -> &'static str { - match self { - BlazeScheme::Http => "http://", - BlazeScheme::Https => "https://", - } - } -} - -impl From<&HeaderValue> for BlazeScheme { - fn from(value: &HeaderValue) -> Self { - match value.as_bytes() { - b"https" => BlazeScheme::Https, - _ => BlazeScheme::default(), - } - } -} - -impl BlazeUpgrade { - /// Upgrades the underlying hook returning the newly created socket - pub async fn upgrade(self) -> Result { - // Attempt to upgrade the connection - let upgrade = match self.on_upgrade.await { - Ok(value) => value, - Err(_) => return Err(BlazeUpgradeError::FailedUpgrade), - }; - - Ok(BlazeSocket { - upgrade, - host_target: self.host_target, - }) - } - - /// Extracts the blaze scheme header from the provided headers map - /// returning the scheme. On failure will return the default scheme - fn extract_scheme(headers: &HeaderMap) -> BlazeScheme { - let header = match headers.get(HEADER_SCHEME) { - Some(value) => value, - None => return BlazeScheme::default(), - }; - let scheme: BlazeScheme = header.into(); - scheme - } - - /// Extracts the client port from the provided headers map. - /// - /// `headers` The header map - fn extract_port(headers: &HeaderMap) -> Option { - // Get the port header - let header = headers.get(HEADER_PORT)?; - // Convert the header to a string - let header = header.to_str().ok()?; - // Parse the header value - header.parse().ok() - } - - /// Extracts the host address from the provided headers map - fn extract_host(headers: &HeaderMap) -> Option> { - // Get the port header - let header = headers.get(HEADER_HOST)?; - // Convert the header to a string - let header = header.to_str().ok()?; - Some(Box::from(header)) - } - - /// Extracts the client local http setting from the provided headers map. - /// - /// `headers` The header map - fn extract_local_http(headers: &HeaderMap) -> Option { - // Get the port header - let header = headers.get(HEADER_LOCAL_HTTP)?; - // Convert the header to a string - let header = header.to_str().ok()?; - // Parse the header value - header.parse().ok() - } -} - -/// Header for the Pocket Relay connection scheme used by the client -const HEADER_SCHEME: &str = "X-Pocket-Relay-Scheme"; -/// Header for the Pocket Relay connection port used by the client -const HEADER_PORT: &str = "X-Pocket-Relay-Port"; -/// Header for the Pocket Relay connection host used by the client -const HEADER_HOST: &str = "X-Pocket-Relay-Host"; -/// Header to tell the server to use local HTTP -const HEADER_LOCAL_HTTP: &str = "X-Pocket-Relay-Local-Http"; - -impl FromRequestParts for BlazeUpgrade -where - S: Send + Sync, -{ - type Rejection = BlazeUpgradeError; - - fn from_request_parts<'a, 'b, 'c>( - parts: &'a mut axum::http::request::Parts, - _state: &'b S, - ) -> BoxFuture<'c, Result> - where - 'a: 'c, - 'b: 'c, - Self: 'c, - { - // Ensure the method is GET - if parts.method != Method::GET { - return Box::pin(ready(Err(BlazeUpgradeError::UnacceptableMethod))); - } - - // Get the upgrade hook - let on_upgrade = match parts.extensions.remove::() { - Some(value) => value, - None => return Box::pin(ready(Err(BlazeUpgradeError::CannotUpgrade))), - }; - - let headers = &parts.headers; - - // Get the client scheme header - let scheme: BlazeScheme = BlazeUpgrade::extract_scheme(headers); - - // Get the client port header - let port: Port = match BlazeUpgrade::extract_port(headers) { - Some(value) => value, - None => scheme.default_port(), - }; - - // Get the client host - let host: Box = match BlazeUpgrade::extract_host(headers) { - Some(value) => value, - None => return Box::pin(ready(Err(BlazeUpgradeError::CannotUpgrade))), - }; - - let local_http: bool = BlazeUpgrade::extract_local_http(headers).unwrap_or_default(); - - Box::pin(ready(Ok(Self { - on_upgrade, - host_target: SessionHostTarget { - scheme, - host, - port, - local_http, - }, - }))) - } -} - -impl IntoResponse for BlazeUpgradeError { - fn into_response(self) -> axum::response::Response { - (StatusCode::BAD_REQUEST, self.to_string()).into_response() - } -} diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index f56ae709..3b99a4f6 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -38,3 +38,71 @@ pub async fn cors_layer(req: Request, next: Next) -> Response { ); res } + +#[cfg(test)] +mod test { + use super::cors_layer; + use axum::{middleware::from_fn, routing::get, Router}; + use hyper::{ + header::{ + ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, + }, + Body, Method, Request, StatusCode, + }; + use tower::ServiceExt; + + #[tokio::test] + async fn test_options() { + let app = Router::new() + .route("/", get(|| async {})) + .layer(from_fn(cors_layer)); + + let req = Request::builder() + .uri("/") + .method(Method::OPTIONS) + .body(Body::empty()) + .unwrap(); + let res = app.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::NO_CONTENT); + + let headers = res.headers(); + let allowed_methods = headers + .get(ACCESS_CONTROL_ALLOW_METHODS) + .expect("Missing allowed methods header"); + assert_eq!(allowed_methods.to_str().unwrap(), "*"); + + let allowed_headers = headers + .get(ACCESS_CONTROL_ALLOW_HEADERS) + .expect("Missing allowed headers header"); + assert_eq!(allowed_headers.to_str().unwrap(), "*"); + + let allowed_origin = headers + .get(ACCESS_CONTROL_ALLOW_ORIGIN) + .expect("Missing allowed origin header"); + assert_eq!(allowed_origin.to_str().unwrap(), "*"); + } + + #[tokio::test] + async fn test_get() { + let app = Router::new() + .route("/", get(|| async {})) + .layer(from_fn(cors_layer)); + + let req = Request::builder() + .uri("/") + .method(Method::GET) + .body(Body::empty()) + .unwrap(); + let res = app.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + + let headers = res.headers(); + + let allowed_origin = headers + .get(ACCESS_CONTROL_ALLOW_ORIGIN) + .expect("Missing allowed origin header"); + assert_eq!(allowed_origin.to_str().unwrap(), "*"); + } +} diff --git a/src/middleware/ip_address.rs b/src/middleware/ip_address.rs index 9b5dc8aa..d5a9cb2e 100644 --- a/src/middleware/ip_address.rs +++ b/src/middleware/ip_address.rs @@ -1,23 +1,23 @@ use crate::config::RuntimeConfig; use axum::{ async_trait, - body::boxed, extract::{rejection::ExtensionRejection, ConnectInfo, FromRequestParts}, http::request::Parts, response::{IntoResponse, Response}, Extension, }; -use hyper::{HeaderMap, StatusCode}; +use hyper::{header::ToStrError, HeaderMap, StatusCode}; use log::warn; use std::{ - net::{Ipv4Addr, SocketAddr}, + net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr}, sync::Arc, }; use thiserror::Error; -/// Middleware for extracting the server public address +/// Middleware that extracts the IP address of the connection pub struct IpAddress(pub Ipv4Addr); +/// Header used to extract the real client IP address, provided by the reverse proxy const REAL_IP_HEADER: &str = "X-Real-IP"; #[async_trait] @@ -33,67 +33,226 @@ where .get::>() .expect("Missing runtime config"); - let reverse_proxy = config.reverse_proxy; - if reverse_proxy { - let ip = match extract_ip_header(&parts.headers) { - Ok(ip) => ip, - Err(err) => { + // Reverse proxies should respect the X-Real-IP header + if config.reverse_proxy { + return extract_ip_header(&parts.headers) + .map_err(|err| { warn!("Failed to extract X-Real-IP header from incoming request. If you are NOT using a reverse proxy\n\ disable the `reverse_proxy` config property, otherwise check that your reverse proxy is configured\n\ correctly according the guide. (Closing connection with error) cause: {}", err); - return Err(err); - } - }; - return Ok(Self(ip)); + err + }) + .map(Self); } - let Extension(ConnectInfo(addr)) = - Extension::>::from_request_parts(parts, state).await?; - if let SocketAddr::V4(addr) = addr { - return Ok(Self(*addr.ip())); - } - - Err(IpAddressError::InvalidHeader) + Extension::>::from_request_parts(parts, state) + .await + .map_err(IpAddressError::ConnectInfo) + .and_then(|value| try_socket_address(value.0 .0)) + .map(Self) } } +/// Attempts to extract the value of the X-Real-IP header provided +/// by reverse proxies fn extract_ip_header(headers: &HeaderMap) -> Result { let header = headers .get(REAL_IP_HEADER) - .ok_or(IpAddressError::MissingHeader)?; - let value = header.to_str().map_err(|_| IpAddressError::InvalidHeader)?; - if let Ok(addr) = value.parse::() { - return Ok(addr); - } + .ok_or(IpAddressError::MissingHeader) + .and_then(|header| header.to_str().map_err(IpAddressError::InvalidHeader))?; - if let Ok(SocketAddr::V4(addr)) = value.parse::() { - return Ok(*addr.ip()); + // Attempt to parse as IP address first (address) + if let Ok(addr) = header.parse::() { + return match addr { + IpAddr::V4(addr) => Ok(addr), + IpAddr::V6(_) => Err(IpAddressError::Unsupported), + }; } - Err(IpAddressError::InvalidHeader) + // Fallback attempt to parse as a socket address (address:port) + let addr = header.parse::()?; + try_socket_address(addr) +} + +/// Attempts to extract an [Ipv4Addr] from the provided socket address +/// returning an error if the [SocketAddr] isn't an IPv4 addr +fn try_socket_address(addr: SocketAddr) -> Result { + match addr { + SocketAddr::V4(addr) => Ok(*addr.ip()), + SocketAddr::V6(_) => Err(IpAddressError::Unsupported), + } } /// Error type used by the token checking middleware to handle /// different errors and create error respones based on them #[derive(Debug, Error)] pub enum IpAddressError { + /// Fallback extraction attempt failed #[error(transparent)] - ConnectInfo(#[from] ExtensionRejection), + ConnectInfo(ExtensionRejection), + + /// Header wasn't present on the request #[error("X-Real-IP header is missing")] MissingHeader, - #[error("X-Real-IP header is invalid")] - InvalidHeader, + + /// Header contained non ASCII characters + #[error("Header X-Real-IP contained unexpected characters")] + InvalidHeader(ToStrError), + + /// Header couldn't be parsed as an address` + #[error("Failed to parse X-Real-IP: {0}")] + ParsingFailed(#[from] AddrParseError), + + /// Header contained an IPv6 address but only IPv4 can be used by ME3 + #[error("Server was provided IPv6 address but only IPv4 is supported")] + Unsupported, } /// IntoResponse implementation for TokenError to allow it to be /// used within the result type as a error response impl IntoResponse for IpAddressError { - #[inline] fn into_response(self) -> Response { - let status: StatusCode = match self { - IpAddressError::ConnectInfo(err) => return err.into_response(), - _ => StatusCode::BAD_REQUEST, - }; - (status, boxed(self.to_string())).into_response() + match self { + IpAddressError::ConnectInfo(err) => err.into_response(), + err => (StatusCode::BAD_REQUEST, err.to_string()).into_response(), + } + } +} + +#[cfg(test)] +mod test { + use super::{extract_ip_header, IpAddress, IpAddressError, REAL_IP_HEADER}; + use crate::config::RuntimeConfig; + use axum::{ + extract::{ConnectInfo, FromRequestParts}, + http::HeaderValue, + }; + use hyper::{HeaderMap, Request}; + use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::Arc, + }; + + /// Tests that IPv4 addresses can be extracted correctly + /// from the header + #[test] + fn test_valid_ips() { + let values = [ + ("127.0.0.1", Ipv4Addr::new(127, 0, 0, 1)), + ("0.0.0.0", Ipv4Addr::new(0, 0, 0, 0)), + ("1.1.1.1", Ipv4Addr::new(1, 1, 1, 1)), + ("192.168.0.1", Ipv4Addr::new(192, 168, 0, 1)), + ("10.168.1.0", Ipv4Addr::new(10, 168, 1, 0)), + ]; + + for (header, expected) in values { + let mut headers = HeaderMap::new(); + headers.insert(REAL_IP_HEADER, HeaderValue::from_static(header)); + + let value = extract_ip_header(&headers).unwrap(); + assert_eq!(value, expected) + } + } + + /// Tests that IPv4 socket addresses can be parsed and extracted as [Ipv4Addr]s + /// without any issue + #[test] + fn test_socket_fallback() { + let values = [ + ("127.0.0.1:80", Ipv4Addr::new(127, 0, 0, 1)), + ("0.0.0.0:80", Ipv4Addr::new(0, 0, 0, 0)), + ("1.1.1.1:443", Ipv4Addr::new(1, 1, 1, 1)), + ("192.168.0.1:230", Ipv4Addr::new(192, 168, 0, 1)), + ("10.168.1.0:5900", Ipv4Addr::new(10, 168, 1, 0)), + ]; + + for (header, expected) in values { + let mut headers = HeaderMap::new(); + headers.insert(REAL_IP_HEADER, HeaderValue::from_static(header)); + + let value = extract_ip_header(&headers).unwrap(); + assert_eq!(value, expected) + } + } + + /// Tests that malformed headers result in an error + #[test] + fn test_malformed_addr() { + let mut headers = HeaderMap::new(); + headers.insert(REAL_IP_HEADER, HeaderValue::from_static("malformed")); + + let value = extract_ip_header(&headers).unwrap_err(); + assert!(matches!(value, IpAddressError::ParsingFailed(_))) + } + + /// Tests that IPv6 headers result in an error + #[test] + fn test_ipv6_addr() { + let mut headers = HeaderMap::new(); + headers.insert( + REAL_IP_HEADER, + HeaderValue::from_static("b44e:2ae1:f85e:2381:7a67:fb1e:2ffd:c053"), + ); + + let value = extract_ip_header(&headers).unwrap_err(); + assert!(matches!(value, IpAddressError::Unsupported)) + } + + /// Tests that missing the required header provides an error + #[test] + fn test_missing_header() { + let headers = HeaderMap::new(); + + let value = extract_ip_header(&headers).unwrap_err(); + assert!(matches!(value, IpAddressError::MissingHeader)) + } + + /// Tests that the middleware can extract the header from a request + #[tokio::test] + async fn test_extraction_header() { + let config = Arc::new(RuntimeConfig { + reverse_proxy: true, + ..Default::default() + }); + + let req = Request::builder() + .extension(config) + .header(REAL_IP_HEADER, HeaderValue::from_static("127.0.0.1")) + .body("") + .unwrap(); + + let (mut parts, _) = req.into_parts(); + + let IpAddress(ip) = IpAddress::from_request_parts(&mut parts, &()) + .await + .unwrap(); + + assert_eq!(ip, Ipv4Addr::new(127, 0, 0, 1)); + } + + /// Tests that when the reverse proxy mode is disabled that the [ConnectInfo] + /// extension is used instead + #[tokio::test] + async fn test_extraction_fallback() { + let config = Arc::new(RuntimeConfig { + reverse_proxy: false, + ..Default::default() + }); + let req = Request::builder() + .extension(config) + .extension(ConnectInfo(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(127, 0, 0, 1), + 0, + )))) + .body("") + .unwrap(); + + let (mut parts, _) = req.into_parts(); + + let IpAddress(ip) = IpAddress::from_request_parts(&mut parts, &()) + .await + .unwrap(); + + assert_eq!(ip, Ipv4Addr::new(127, 0, 0, 1)); } } diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 94a42ff6..d9732a2e 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,10 +1,10 @@ /// Middleware functions an enums related to token authentication pub mod auth; -/// Blaze connection upgrading -pub mod blaze_upgrade; /// Middleware functions related to CORS implementation pub mod cors; /// IP address extraction middleware pub mod ip_address; /// XML response types pub mod xml; + +pub mod upgrade; diff --git a/src/middleware/upgrade.rs b/src/middleware/upgrade.rs new file mode 100644 index 00000000..3908ec71 --- /dev/null +++ b/src/middleware/upgrade.rs @@ -0,0 +1,52 @@ +use axum::{ + async_trait, + extract::FromRequestParts, + http::{request::Parts, Method, StatusCode}, + response::IntoResponse, +}; +use hyper::upgrade::OnUpgrade; +use thiserror::Error; + +/// Errors that could occur while upgrading +#[derive(Debug, Error)] +pub enum UpgradeError { + #[error("Request method must be `GET`")] + UnacceptableMethod, + #[error("Request couldn't be upgraded since no upgrade state was present")] + ConnectionNotUpgradable, +} + +/// Extractor for extracting the [OnUpgrade] from requests +/// to upgrade the connection +pub struct Upgrade(pub OnUpgrade); + +#[async_trait] +impl FromRequestParts for Upgrade +where + S: Send + Sync, +{ + type Rejection = UpgradeError; + + async fn from_request_parts(req: &mut Parts, _state: &S) -> Result { + // Ensure the method is GET + if req.method != Method::GET { + return Err(UpgradeError::UnacceptableMethod); + } + + req.extensions + .remove::() + .ok_or(UpgradeError::ConnectionNotUpgradable) + .map(Self) + } +} + +impl IntoResponse for UpgradeError { + fn into_response(self) -> axum::response::Response { + let status = match self { + UpgradeError::UnacceptableMethod => StatusCode::METHOD_NOT_ALLOWED, + UpgradeError::ConnectionNotUpgradable => StatusCode::UPGRADE_REQUIRED, + }; + + (status, self.to_string()).into_response() + } +} diff --git a/src/resources/data/coalesced.bin b/src/resources/data/coalesced.bin index d3a4592b..ebae7b28 100644 Binary files a/src/resources/data/coalesced.bin and b/src/resources/data/coalesced.bin differ diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 70b830e9..6934d8dd 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use crate::{ config::RuntimeConfig, database::entities::Player, - services::sessions::{CreateTokenMessage, Sessions}, + services::sessions::Sessions, utils::hashing::{hash_password, verify_password}, }; use axum::{ @@ -11,7 +11,6 @@ use axum::{ response::{IntoResponse, Response}, Extension, Json, }; -use interlink::prelude::{Link, LinkError}; use sea_orm::{DatabaseConnection, DbErr}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -23,10 +22,6 @@ pub enum AuthError { #[error("Server error occurred")] Database(#[from] DbErr), - /// Session service error - #[error("Session service unavailable")] - SessionService(LinkError), - /// Failed to hash the user password #[error("Server error occurred")] PasswordHash(#[from] argon2::password_hash::Error), @@ -79,7 +74,7 @@ pub struct TokenResponse { /// containing the authentication token for the user pub async fn login( Extension(db): Extension, - Extension(sessions): Extension>, + Extension(sessions): Extension>, Json(req): Json, ) -> AuthRes { let LoginRequest { email, password } = req; @@ -97,10 +92,7 @@ pub async fn login( return Err(AuthError::InvalidCredentails); } - let token = sessions - .send(CreateTokenMessage(player.id)) - .await - .map_err(AuthError::SessionService)?; + let token = sessions.create_token(player.id); Ok(Json(TokenResponse { token })) } @@ -124,7 +116,7 @@ pub struct CreateRequest { pub async fn create( Extension(db): Extension, Extension(config): Extension>, - Extension(sessions): Extension>, + Extension(sessions): Extension>, Json(req): Json, ) -> AuthRes { if config.dashboard.disable_registration { @@ -150,10 +142,7 @@ pub async fn create( let password: String = hash_password(&password)?; let player: Player = Player::create(&db, email, username, Some(password), &config).await?; - let token = sessions - .send(CreateTokenMessage(player.id)) - .await - .map_err(AuthError::SessionService)?; + let token = sessions.create_token(player.id); Ok(Json(TokenResponse { token })) } @@ -161,9 +150,7 @@ pub async fn create( impl IntoResponse for AuthError { fn into_response(self) -> Response { let status_code = match &self { - Self::Database(_) | Self::PasswordHash(_) | Self::SessionService(_) => { - StatusCode::INTERNAL_SERVER_ERROR - } + Self::Database(_) | Self::PasswordHash(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::InvalidCredentails | Self::OriginAccess => StatusCode::UNAUTHORIZED, Self::EmailTaken | Self::InvalidUsername => StatusCode::BAD_REQUEST, Self::RegistrationDisabled => StatusCode::FORBIDDEN, diff --git a/src/routes/games.rs b/src/routes/games.rs index 65fc34e3..0b250c9b 100644 --- a/src/routes/games.rs +++ b/src/routes/games.rs @@ -1,10 +1,7 @@ use crate::{ database::entities::players::PlayerRole, middleware::auth::Auth, - services::game::{ - manager::{GameManager, GetGameMessage, SnapshotQueryMessage}, - GameSnapshot, SnapshotMessage, - }, + services::game::{manager::GameManager, GameSnapshot}, utils::types::GameID, }; use axum::{ @@ -13,8 +10,8 @@ use axum::{ response::{IntoResponse, Response}, Extension, Json, }; -use interlink::prelude::{Link, LinkError}; use serde::{Deserialize, Serialize}; +use std::sync::Arc; use thiserror::Error; /// Errors that could occur while working with game endpoints @@ -23,9 +20,6 @@ pub enum GamesError { /// The requested game could not be found (For specific game lookup) #[error("Game not found")] NotFound, - /// Something went wrong with the link to the games service - #[error("Failed to access games service")] - Link(#[from] LinkError), } /// Response type alias for JSON responses with GamesError @@ -63,22 +57,19 @@ pub struct GamesResponse { /// players with admin level or greater access. pub async fn get_games( Query(query): Query, - Extension(game_manager): Extension>, + Extension(game_manager): Extension>, Auth(auth): Auth, ) -> GamesRes { let GamesRequest { offset, count } = query; let count: usize = count.unwrap_or(20) as usize; let offset: usize = offset * count; + let include_net = auth.role >= PlayerRole::Admin; // Retrieve the game snapshots let (games, more) = game_manager - .send(SnapshotQueryMessage { - offset, - count, - include_net: auth.role >= PlayerRole::Admin, - }) - .await?; + .create_snapshot(offset, count, include_net) + .await; Ok(Json(GamesResponse { games, more })) } @@ -92,19 +83,15 @@ pub async fn get_games( /// players with admin level or greater access. pub async fn get_game( Path(game_id): Path, - Extension(game_manager): Extension>, + Extension(game_manager): Extension>, Auth(auth): Auth, ) -> GamesRes { let game = game_manager - .send(GetGameMessage { game_id }) - .await? + .get_game(game_id) + .await .ok_or(GamesError::NotFound)?; - - let snapshot = game - .send(SnapshotMessage { - include_net: auth.role >= PlayerRole::Admin, - }) - .await?; + let game = &*game.read().await; + let snapshot = game.snapshot(auth.role >= PlayerRole::Admin); Ok(Json(snapshot)) } @@ -114,7 +101,6 @@ impl IntoResponse for GamesError { fn into_response(self) -> Response { let status_code = match &self { Self::NotFound => StatusCode::NOT_FOUND, - Self::Link(_) => StatusCode::INTERNAL_SERVER_ERROR, }; (status_code, self.to_string()).into_response() diff --git a/src/routes/gaw.rs b/src/routes/gaw.rs index 6c499538..4eefcd5a 100644 --- a/src/routes/gaw.rs +++ b/src/routes/gaw.rs @@ -11,7 +11,7 @@ use crate::{ DatabaseConnection, DbErr, DbResult, }, middleware::xml::Xml, - services::sessions::{Sessions, VerifyTokenMessage}, + services::sessions::Sessions, utils::parsing::PlayerClass, }; use axum::{ @@ -20,7 +20,7 @@ use axum::{ response::{IntoResponse, Response}, Extension, }; -use interlink::prelude::Link; +use indoc::formatdoc; use serde::Deserialize; use std::{fmt::Display, sync::Arc}; use tokio::try_join; @@ -55,38 +55,36 @@ pub struct AuthQuery { /// /// `query` The query containing the auth token (In this case the hex player ID) pub async fn shared_token_login(Query(query): Query) -> Xml { - let response = format!( - r#" - - 0 - - 0 - - - - - 0 - {} - - - - - - - UNKNOWN - 0 - BLAZE_EXTERNAL_REF_TYPE_UNKNOWN - - - - 0 - - - -"#, - query.auth - ); - Xml(response) + Xml(formatdoc! {r#" + + + 0 + + 0 + + + + + 0 + {} + + + + + + + UNKNOWN + 0 + BLAZE_EXTERNAL_REF_TYPE_UNKNOWN + + + + 0 + + + + + "# ,query.auth}) } /// GET /galaxyatwar/getRatings/:id @@ -99,7 +97,7 @@ pub async fn get_ratings( Path(id): Path, Extension(db): Extension, Extension(config): Extension>, - Extension(sessions): Extension>, + Extension(sessions): Extension>, ) -> Result { let (gaw_data, promotions) = get_player_gaw_data(&db, sessions, &id, &config).await?; Ok(ratings_response(gaw_data, promotions)) @@ -139,11 +137,11 @@ pub async fn increase_ratings( Query(query): Query, Extension(db): Extension, Extension(config): Extension>, - Extension(sessions): Extension>, + Extension(sessions): Extension>, ) -> Result { let (gaw_data, promotions) = get_player_gaw_data(&db, sessions, &id, &config).await?; let gaw_data = gaw_data - .increase(&db, (query.a, query.b, query.c, query.d, query.e)) + .add(&db, [query.a, query.b, query.c, query.d, query.e]) .await?; Ok(ratings_response(gaw_data, promotions)) } @@ -155,14 +153,12 @@ pub async fn increase_ratings( /// `id` The hex ID of the player async fn get_player_gaw_data( db: &DatabaseConnection, - sessions: Link, + sessions: Arc, token: &str, config: &RuntimeConfig, ) -> Result<(GalaxyAtWar, u32), GAWError> { let player_id = sessions - .send(VerifyTokenMessage(token.to_string())) - .await - .map_err(|_| GAWError::ServerError)? + .verify_token(token) .map_err(|_| GAWError::InvalidToken)?; let player = Player::by_id(db, player_id) @@ -170,9 +166,11 @@ async fn get_player_gaw_data( .ok_or(GAWError::InvalidToken)?; let (gaw_data, promotions) = try_join!( - GalaxyAtWar::find_or_create(db, player.id, config.galaxy_at_war.decay), + GalaxyAtWar::get(db, player.id), get_promotions(db, &player, config) )?; + let gaw_data = gaw_data.apply_decay(db, config.galaxy_at_war.decay).await?; + Ok((gaw_data, promotions)) } @@ -208,32 +206,31 @@ fn ratings_response(ratings: GalaxyAtWar, promotions: u32) -> Xml { // Calculate the average value for the level let level = (a + b + c + d + e) / 5; - let response = format!( - r#" - - - {a} - {b} - {c} - {d} - {e} - - {level} - - {promotions} - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 - 0 - -"# - ); - Xml(response) + Xml(formatdoc! {r#" + + + + {a} + {b} + {c} + {d} + {e} + + {level} + + {promotions} + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + + + "#}) } /// Display implementation for the GAWError this will be displayed @@ -279,8 +276,7 @@ impl IntoResponse for GAWError { GAWError::ServerError => StatusCode::INTERNAL_SERVER_ERROR, }; - let mut response = self.to_string().into_response(); - *response.status_mut() = status; + let mut response = (status, self.to_string()).into_response(); response .headers_mut() .insert(header::CONTENT_TYPE, HeaderValue::from_static("text/xml")); diff --git a/src/routes/leaderboard.rs b/src/routes/leaderboard.rs index c1d518ec..9d2bcb0f 100644 --- a/src/routes/leaderboard.rs +++ b/src/routes/leaderboard.rs @@ -1,5 +1,7 @@ +use std::sync::Arc; + use crate::{ - services::leaderboard::{models::*, Leaderboard, QueryMessage}, + services::leaderboard::{models::*, Leaderboard}, utils::types::PlayerID, }; use axum::{ @@ -8,7 +10,6 @@ use axum::{ response::{IntoResponse, Response}, Extension, Json, }; -use interlink::prelude::{Link, LinkError}; use sea_orm::DatabaseConnection; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -21,15 +22,9 @@ pub enum LeaderboardError { /// The provided query range was out of bounds on the underlying query #[error("Unacceptable query range")] InvalidRange, - /// Something went wrong with the link to the leaderboard service - #[error("Failed to access leaderboard service")] - Link(#[from] LinkError), /// The requested player was not found in the leaderboard #[error("Player not found")] PlayerNotFound, - /// Error for when a unknown leaderboard is requested - #[error("Unknown leaderboard")] - UnknownLeaderboard, } /// Structure of a query requesting a specific leaderboard contains @@ -66,16 +61,13 @@ pub struct LeaderboardResponse<'a> { /// `name` The name of the leaderboard type to query /// `query` The leaderboard query pub async fn get_leaderboard( - Path(name): Path, + Path(ty): Path, Extension(db): Extension, - Extension(leaderboard): Extension>, + Extension(leaderboard): Extension>, Query(query): Query, ) -> Result { let LeaderboardQuery { offset, count } = query; - let ty: LeaderboardType = - LeaderboardType::try_parse(&name).ok_or(LeaderboardError::UnknownLeaderboard)?; - /// The default number of entries to return in a leaderboard response const DEFAULT_COUNT: u8 = 40; @@ -84,12 +76,14 @@ pub async fn get_leaderboard( // Calculate the start and ending indexes let start: usize = offset * count; - let group = leaderboard.send(QueryMessage(ty, db)).await?; + let group = leaderboard.query(ty, &db).await; - let (entries, more) = group + let entries = group .get_normal(start, count) .ok_or(LeaderboardError::InvalidRange)?; + let more = group.has_more(start, count); + let response = Json(LeaderboardResponse { total: group.values.len(), entries, @@ -107,13 +101,11 @@ pub async fn get_leaderboard( /// `name` The name of the leaderboard type to query /// `player_id` The ID of the player to find the leaderboard ranking of pub async fn get_player_ranking( + Path((ty, player_id)): Path<(LeaderboardType, PlayerID)>, Extension(db): Extension, - Extension(leaderboard): Extension>, - Path((name, player_id)): Path<(String, PlayerID)>, + Extension(leaderboard): Extension>, ) -> Result { - let ty: LeaderboardType = - LeaderboardType::try_parse(&name).ok_or(LeaderboardError::UnknownLeaderboard)?; - let group = leaderboard.send(QueryMessage(ty, db)).await?; + let group = leaderboard.query(ty, &db).await; let entry = match group.get_entry(player_id) { Some(value) => value, @@ -130,9 +122,8 @@ impl IntoResponse for LeaderboardError { #[inline] fn into_response(self) -> Response { let status = match &self { - Self::PlayerNotFound | Self::UnknownLeaderboard => StatusCode::NOT_FOUND, + Self::PlayerNotFound => StatusCode::NOT_FOUND, Self::InvalidRange => StatusCode::BAD_REQUEST, - Self::Link(_) => StatusCode::INTERNAL_SERVER_ERROR, }; (status, self.to_string()).into_response() } diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 8f11a896..11b5106e 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -6,6 +6,8 @@ use axum::{ use crate::middleware::cors::cors_layer; +use self::server::clear_log; + mod auth; mod games; mod gaw; @@ -31,7 +33,13 @@ pub fn router() -> Router { .route("/increaseRatings/:id", get(gaw::increase_ratings)), ) // Quality of service - .route("/qos/qos", get(qos::qos)) + .nest( + "/qos", + Router::new() + .route("/qos", get(qos::qos)) + .route("/firewall", get(qos::firewall)) + .route("/firetype", get(qos::firetype)), + ) // Dashboard API .nest( "/api", @@ -86,7 +94,7 @@ pub fn router() -> Router { "/server", Router::new() .route("/", get(server::server_details)) - .route("/log", get(server::get_log)) + .route("/log", get(server::get_log).delete(clear_log)) .route("/upgrade", get(server::upgrade)) .route("/telemetry", post(server::submit_telemetry)) .route("/dashboard", get(server::dashboard_details)), diff --git a/src/routes/players.rs b/src/routes/players.rs index 97e6282b..fdc944cc 100644 --- a/src/routes/players.rs +++ b/src/routes/players.rs @@ -494,15 +494,16 @@ pub async fn set_data( Path((player_id, key)): Path<(PlayerID, String)>, Extension(db): Extension, Json(req): Json, -) -> PlayersRes { +) -> PlayersResult<()> { let player: Player = find_player(&db, player_id).await?; if !auth.has_permission_over(&player) { return Err(PlayersError::InvalidPermission); } - let data = PlayerData::set(&db, player.id, key, req.value).await?; - Ok(Json(data)) + PlayerData::set(&db, player.id, key.clone(), req.value).await?; + + Ok(()) } /// DELETE /api/players/:id/data/:key @@ -541,7 +542,7 @@ pub async fn get_player_gaw( Extension(db): Extension, ) -> PlayersRes { let player = find_player(&db, player_id).await?; - let galax_at_war = GalaxyAtWar::find_or_create(&db, player.id, 0.0).await?; + let galax_at_war = GalaxyAtWar::get(&db, player.id).await?; Ok(Json(galax_at_war)) } diff --git a/src/routes/public.rs b/src/routes/public.rs index ab535212..4027b277 100644 --- a/src/routes/public.rs +++ b/src/routes/public.rs @@ -4,6 +4,7 @@ use axum::{ response::{IntoResponse, Response}, }; use embeddy::Embedded; +use futures_util::future::BoxFuture; use hyper::{header::CONTENT_TYPE, StatusCode}; use std::{ convert::Infallible, @@ -12,8 +13,6 @@ use std::{ }; use tower::Service; -use crate::utils::types::BoxFuture; - /// Resources embedded from the public data folder such as the /// dashboard static assets and the content for the ingame store. /// diff --git a/src/routes/qos.rs b/src/routes/qos.rs index 9ad4f273..1040a7a5 100644 --- a/src/routes/qos.rs +++ b/src/routes/qos.rs @@ -1,8 +1,10 @@ //! Routes for the Quality of Service server. Unknown whether any of the //! response address and ports are correct however this request must succeed //! or the client doesn't seem to know its external IP -use crate::{middleware::xml::Xml, utils::models::Port}; + +use crate::middleware::xml::Xml; use axum::extract::Query; +use indoc::formatdoc; use log::debug; use serde::Deserialize; @@ -12,6 +14,8 @@ pub struct QosQuery { /// The port the client is using #[serde(rename = "prpt")] port: u16, + + qtype: u8, } /// GET /qos/qos @@ -37,19 +41,76 @@ pub async fn qos(Query(query): Query) -> Xml { debug!("Recieved QOS query: (Port: {})", query.port); /// Port for the local Quality of Service server - const QOS_PORT: Port = 42130; + const QOS_PORT: u16 = 42130; + // const QOS_PORT: u16 = 17499; const IP: u32 = u32::from_be_bytes([127, 0, 0, 1]); + // const IP: u32 = 2733913518; + + if query.qtype == 1 { + Xml(formatdoc! {r#" + + + 0 + {} + 0 + {} + 1 + 0 + + "#, QOS_PORT, IP + }) + } else { + Xml(formatdoc! {r#" + + + 10 + {} + 1200 + {} + 1 + 1 + + "#, QOS_PORT, IP + }) + } +} - let response = format!( - r#" 0 - {} - 0 - 127.0.0.1 - {} - 1 - 0 -"#, - QOS_PORT, IP - ); - Xml(response) +/// GET /qos/firewall +/// +/// Called by game: /qos/firewall?vers=1&nint=2 +pub async fn firewall() -> Xml { + // TODO: Appears to point to two other servers + // 162.244.53.174 + // 162.244.53.175 + Xml(formatdoc! {r#" + + + + 2733913518 + 2733913519 + + 2 + + 17500 + 17501 + + 747 + 502 + + "# + }) +} + +/// GET /qos/firetype +/// +/// Called by game: /qos/firetype?vers=1&rqid=768&rqsc=526&inip=-1062708997&inpt=3659 +pub async fn firetype() -> Xml { + // TODO: Firetype + // my wifi is: 2 cloudflare vpn is: 4 + Xml(formatdoc! {r#" + + + 2 + + "#}) } diff --git a/src/routes/server.rs b/src/routes/server.rs index e57191fe..4e2f92f3 100644 --- a/src/routes/server.rs +++ b/src/routes/server.rs @@ -4,28 +4,21 @@ use crate::{ config::{RuntimeConfig, VERSION}, database::entities::players::PlayerRole, - middleware::{auth::AdminAuth, blaze_upgrade::BlazeUpgrade, ip_address::IpAddress}, - services::{game::manager::GameManager, sessions::Sessions}, - session::{packet::PacketCodec, router::BlazeRouter, Session}, + middleware::{auth::AdminAuth, ip_address::IpAddress, upgrade::Upgrade}, + services::sessions::Sessions, + session::{router::BlazeRouter, Session}, utils::logging::LOG_FILE_NAME, }; use axum::{ - body::Empty, - http::{header, HeaderValue, StatusCode}, + http::{header, StatusCode}, response::{IntoResponse, Response}, Extension, Json, }; -use interlink::{prelude::Link, service::Service}; +use hyper::upgrade::OnUpgrade; use log::{debug, error}; use serde::{Deserialize, Serialize}; -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, -}; -use tokio::{fs::read_to_string, io::split}; -use tokio_util::codec::{FramedRead, FramedWrite}; - -static SESSION_IDS: AtomicU32 = AtomicU32::new(1); +use std::{net::Ipv4Addr, sync::Arc}; +use tokio::fs::{read_to_string, OpenOptions}; /// Response detailing the information about this Pocket Relay server /// contains the version information as well as the server information @@ -77,69 +70,88 @@ pub async fn dashboard_details( pub async fn upgrade( IpAddress(addr): IpAddress, Extension(router): Extension>, - Extension(game_manager): Extension>, - Extension(sessions): Extension>, - upgrade: BlazeUpgrade, + Extension(sessions): Extension>, + Upgrade(upgrade): Upgrade, ) -> Response { - // TODO: Socket address extraction for forwarded reverse proxy - - tokio::spawn(async move { - let socket = match upgrade.upgrade().await { - Ok(value) => value, - Err(err) => { - error!("Failed to upgrade blaze socket: {}", err); - return; - } - }; - Session::create(|ctx| { - // Obtain a session ID - let session_id = SESSION_IDS.fetch_add(1, Ordering::AcqRel); - - // Attach reader and writers to the session context - let (read, write) = split(socket.upgrade); - let read = FramedRead::new(read, PacketCodec); - let write = FramedWrite::new(write, PacketCodec); - - ctx.attach_stream(read, true); - let writer = ctx.attach_sink(write); - - Session::new( - session_id, - socket.host_target, - writer, - addr, - router, - game_manager, - sessions, - ) - }); - }); - - let mut response = Empty::new().into_response(); - // Use the switching protocols status code - *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS; - - let headers = response.headers_mut(); - // Add the upgraidng headers - headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade")); - headers.insert(header::UPGRADE, HeaderValue::from_static("blaze")); - - response + // Spawn the upgrading process to its own task + tokio::spawn(handle_upgrade(upgrade, addr, router, sessions)); + + // Let the client know to upgrade its connection + ( + // Switching protocols status code + StatusCode::SWITCHING_PROTOCOLS, + // Headers required for upgrading + [(header::CONNECTION, "upgrade"), (header::UPGRADE, "blaze")], + ) + .into_response() +} + +/// Handles upgrading a connection and starting a new session +/// from the connection +pub async fn handle_upgrade( + upgrade: OnUpgrade, + addr: Ipv4Addr, + router: Arc, + sessions: Arc, +) { + let upgraded = match upgrade.await { + Ok(upgraded) => upgraded, + Err(err) => { + error!("Failed to upgrade client connection: {}", err); + return; + } + }; + + Session::start(upgraded, addr, router, sessions); } /// GET /api/server/log /// -/// Handles loading and responding with the server log file -/// contents for the log section on the super admin portion -/// of the dashboard +/// Responds with the server log file contents +/// +/// Requires super admin authentication pub async fn get_log(AdminAuth(auth): AdminAuth) -> Result { if auth.role < PlayerRole::SuperAdmin { return Err(StatusCode::FORBIDDEN); } let path = std::path::Path::new(LOG_FILE_NAME); - read_to_string(path) + read_to_string(path).await.map_err(|err| { + error!("Failed to read server log file: {}", err); + StatusCode::INTERNAL_SERVER_ERROR + }) +} + +/// DELETE /api/server/log +/// +/// Truncates the server log file, useful for long log files that +/// are starting to take up lots of space or have out-served their +/// usefulness +/// +/// Requires super admin authentication +pub async fn clear_log(AdminAuth(auth): AdminAuth) -> Result<(), StatusCode> { + if auth.role < PlayerRole::SuperAdmin { + return Err(StatusCode::FORBIDDEN); + } + + let path = std::path::Path::new(LOG_FILE_NAME); + + // Open the file + let file = OpenOptions::new() + .write(true) + .open(path) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + .map_err(|err| { + error!("Failed to open server log file: {}", err); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Truncate the file + file.set_len(0).await.map_err(|err| { + error!("Failed to truncate server log file: {}", err); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + Ok(()) } /// Structure of a telemetry message coming from a client diff --git a/src/services/game/manager.rs b/src/services/game/manager.rs index 98299cda..7b329045 100644 --- a/src/services/game/manager.rs +++ b/src/services/game/manager.rs @@ -1,50 +1,38 @@ -use super::{ - models::{DatalessContext, GameSettings, GameSetupContext, PlayerState}, - rules::RuleSet, - AddPlayerMessage, AttrMap, CheckJoinableMessage, Game, GameJoinableState, GamePlayer, - GameSnapshot, -}; +use super::{rules::RuleSet, AttrMap, Game, GameJoinableState, GamePlayer, GameRef, GameSnapshot}; use crate::{ - services::game::models::AsyncMatchmakingStatus, - session::{packet::Packet, PushExt}, + session::{ + models::game_manager::{ + AsyncMatchmakingStatus, GameSettings, GameSetupContext, MatchmakingResult, + }, + packet::Packet, + }, utils::{ components::game_manager, + hashing::IntHashMap, types::{GameID, PlayerID}, }, }; -use interlink::prelude::*; -use log::debug; +use log::{debug, warn}; use std::{ - collections::{HashMap, VecDeque}, - ops::DerefMut, - sync::Arc, - time::SystemTime, + collections::VecDeque, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::{Duration, SystemTime}, }; use tokio::{sync::RwLock, task::JoinSet}; /// Manager which controls all the active games on the server /// commanding them to do different actions and removing them /// once they are no longer used -#[derive(Service)] pub struct GameManager { /// The map of games to the actual game address - games: HashMap>, + games: RwLock>, /// Stored value for the ID to give the next game - next_id: GameID, + next_id: AtomicU32, /// Matchmaking entry queue - queue: Arc>>, -} - -impl GameManager { - /// Starts a new game manager service returning its link - pub fn start() -> Link { - let this = GameManager { - games: Default::default(), - next_id: 1, - queue: Default::default(), - }; - this.start() - } + queue: RwLock>, } /// Entry into the matchmaking queue @@ -57,351 +45,260 @@ struct MatchmakingEntry { started: SystemTime, } -/// Message to remove a player from the matchmaking queue -#[derive(Message)] -pub struct RemoveQueueMessage { - /// The player ID of the player to remove - pub player_id: PlayerID, -} - -impl Handler for GameManager { - /// Empty response type - type Response = Fr; - - fn handle( - &mut self, - msg: RemoveQueueMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - let queue_handle = self.queue.clone(); - Fr::new(Box::pin(async move { - let mut queue = queue_handle.write().await; - queue.retain(|value| value.player.player.id != msg.player_id); - })) - } -} - -/// Message to add a new player to the matchmaking queue -#[derive(Message)] -pub struct QueuePlayerMessage { - /// The player to add to the queue - pub player: GamePlayer, - /// The rules for the player - pub rule_set: Arc, -} +const DEFAULT_FIT: u16 = 21600; -impl Handler for GameManager { - /// Empty response type - type Response = Fr; +impl GameManager { + /// Max number of times to poll a game for shutdown before erroring + const MAX_RELEASE_ATTEMPTS: u8 = 5; - fn handle( - &mut self, - msg: QueuePlayerMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - let started = SystemTime::now(); - let queue_handle = self.queue.clone(); - Fr::new(Box::pin(async move { - let mut queue = queue_handle.write().await; - queue.push_back(MatchmakingEntry { - player: msg.player, - rule_set: msg.rule_set, - started, - }); - })) + /// Starts a new game manager service returning its link + pub fn new() -> Self { + Self { + games: Default::default(), + next_id: AtomicU32::new(1), + queue: Default::default(), + } } -} - -/// Message for taking a snapshot of multiple games -/// within the specified query range -#[derive(Message)] -#[msg(rtype = "(Vec, bool)")] -pub struct SnapshotQueryMessage { - /// The offset to start querying games from - pub offset: usize, - /// The number of games to query - pub count: usize, - /// Whether to include sensitively player net info - pub include_net: bool, -} - -/// Handler for snapshot query messages -impl Handler for GameManager { - type Response = Fr; - - fn handle( - &mut self, - msg: SnapshotQueryMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - let SnapshotQueryMessage { - offset, - count, - include_net, - } = msg; + pub async fn create_snapshot( + &self, + offset: usize, + count: usize, + include_net: bool, + ) -> (Vec, bool) { // Create the futures using the handle action before passing // them to a future to be awaited let mut join_set = JoinSet::new(); - // Obtained an order set of the keys from the games map - let mut keys: Vec<&GameID> = self.games.keys().collect(); - keys.sort(); - - // Whether there is more keys that what was requested - let more = keys.len() > offset + count; - - // Spawn tasks for obtaining snapshots to each game - keys.into_iter() - // Skip to the desired offset - .skip(offset) - // Take the desired number of keys - .take(count) - // Take the game links for the keys - .filter_map(|key| self.games.get(key)) - // Clone the obtained game links - .cloned() - // Spawn the snapshot tasks - .for_each(|game| { - join_set - .spawn(async move { game.send(super::SnapshotMessage { include_net }).await }); - }); - - Fr::new(Box::pin(async move { - // Allocate a list for the snapshots - let mut snapshots = Vec::with_capacity(join_set.len()); - - // Recieve all the snapshots from their tasks - while let Some(result) = join_set.join_next().await { - if let Ok(Ok(snapshot)) = result { - snapshots.push(snapshot); - } - } - - (snapshots, more) - })) - } -} - -/// Message for creating a new game using the game manager -/// responds with a link to the created game and its ID -#[derive(Message)] -#[msg(rtype = "(Link, GameID)")] -pub struct CreateMessage { - /// The initial game attributes - pub attributes: AttrMap, - /// The initial game setting - pub setting: GameSettings, - /// The host player for the game - pub host: GamePlayer, -} + let more = { + let games = &*self.games.read().await; + + // Create an ordered set + let mut items: Vec<(&GameID, &GameRef)> = games.iter().collect(); + items.sort_by_key(|(key, _)| *key); + + // Whether there is more keys that what was requested + let more = items.len() > offset + count; + + // Spawn tasks for obtaining snapshots to each game + items + .into_iter() + // Skip to the desired offset + .skip(offset) + // Take the desired number of keys + .take(count) + // Iterate over the game links + .map(|(_, value)| value.clone()) + // Spawn the snapshot tasks + .for_each(|game| { + join_set.spawn(async move { + let game = &*game.read().await; + game.snapshot(include_net) + }); + }); -/// Handler for creating games -impl Handler for GameManager { - type Response = Mr; + more + }; - fn handle(&mut self, mut msg: CreateMessage, ctx: &mut ServiceContext) -> Self::Response { - let id = self.next_id; + // Allocate a list for the snapshots + let mut snapshots = Vec::with_capacity(join_set.len()); - self.next_id = self.next_id.wrapping_add(1); + // Recieve all the snapshots from their tasks + while let Some(result) = join_set.join_next().await { + if let Ok(snapshot) = result { + snapshots.push(snapshot); + } + } - msg.host.state = PlayerState::ActiveConnected; + (snapshots, more) + } - let link = Game::start(id, msg.attributes, msg.setting, ctx.link()); - self.games.insert(id, link.clone()); + pub async fn remove_queue(&self, player_id: PlayerID) { + let queue = &mut *self.queue.write().await; + queue.retain(|value| value.player.player.id != player_id); + } - let _ = link.do_send(AddPlayerMessage { - player: msg.host, - context: GameSetupContext::Dataless(DatalessContext::CreateGameSetup), + pub async fn queue(&self, player: GamePlayer, rule_set: Arc) { + let started = SystemTime::now(); + let queue = &mut *self.queue.write().await; + queue.push_back(MatchmakingEntry { + player, + rule_set, + started, }); - - Mr((link, id)) } -} - -/// Message for requesting a link to a game with the provided -/// ID responds with a link to the game if it exists -#[derive(Message)] -#[msg(rtype = "Option>")] -pub struct GetGameMessage { - /// The ID of the game to get a link to - pub game_id: GameID, -} -/// Handler for getting a specific game -impl Handler for GameManager { - type Response = Mr; + pub async fn add_to_game( + &self, + game_ref: GameRef, + player: GamePlayer, + context: GameSetupContext, + ) { + let player_link = player.link.clone(); + + // Add the player to the game + let game_id = { + let game = &mut *game_ref.write().await; + game.add_player(player, context); + game.id + }; - fn handle(&mut self, msg: GetGameMessage, _ctx: &mut ServiceContext) -> Self::Response { - let link = self.games.get(&msg.game_id).cloned(); - Mr(link) + // Update the player current game + player_link.set_game(game_id, game_ref).await; } -} -/// Message for attempting to add a player to any existing -/// games within this game manager -#[derive(Message)] -#[msg(rtype = "TryAddResult")] -pub struct TryAddMessage { - /// The player to attempt to add - pub player: GamePlayer, - /// The set of rules the player requires the game has - pub rule_set: Arc, -} + pub async fn add_from_matchmaking(&self, game_ref: GameRef, player: GamePlayer) { + let msid = player.player.id; + + // MUST be sent to players atleast once when matchmaking otherwise it may fail + player.link.push(Packet::notify( + game_manager::COMPONENT, + game_manager::MATCHMAKING_ASYNC_STATUS, + AsyncMatchmakingStatus { player_id: msid }, + )); + + // Add the player to the game + self.add_to_game( + game_ref, + player, + GameSetupContext::Matchmaking { + fit_score: DEFAULT_FIT, + max_fit_score: DEFAULT_FIT, + session_id: msid, + result: MatchmakingResult::JoinedExistingGame, + player_id: msid, + }, + ) + .await; + } -/// Result of attempting to add a player. Success will -/// consume the game player and Failure will return the -/// game player back -pub enum TryAddResult { - /// The player was added to the game - Success, - /// The player failed to be added and was returned back - Failure(GamePlayer), -} + pub async fn create_game( + self: &Arc, + attributes: AttrMap, + setting: GameSettings, + ) -> (GameRef, GameID) { + let id = self.next_id.fetch_add(1, Ordering::AcqRel); + let game = Game::new(id, attributes, setting, self.clone()); + let link = Arc::new(RwLock::new(game)); + { + let games = &mut *self.games.write().await; + games.insert(id, link.clone()); + } -/// Handler for attempting to add a player -impl Handler for GameManager { - type Response = Fr; + (link, id) + } - fn handle(&mut self, msg: TryAddMessage, _ctx: &mut ServiceContext) -> Self::Response { - // Take a copy of the current games list - let games = self.games.clone(); + pub async fn get_game(&self, game_id: GameID) -> Option { + let games = &*self.games.read().await; + games.get(&game_id).cloned() + } - Fr::new(Box::pin(async move { - let player = msg.player; + pub async fn try_add(&self, player: GamePlayer, rule_set: &RuleSet) -> Result<(), GamePlayer> { + let games = &*self.games.read().await; - // Message asking for the game joinable state - let msg = CheckJoinableMessage { - rule_set: Some(msg.rule_set), + // Attempt to find a game thats joinable + for (id, link) in games { + let join_state = { + let link = &*link.read().await; + link.joinable_state(Some(rule_set)) }; - // Attempt to find a game thats joinable - for (id, link) in games { - // Check if the game is joinable - if let Ok(GameJoinableState::Joinable) = link.send(msg.clone()).await { - debug!("Found matching game (GID: {})", id); - let msid = player.player.id; - let _ = link.do_send(AddPlayerMessage { - player, - context: GameSetupContext::Matchmaking(msid), - }); - return TryAddResult::Success; - } + // Check if the game is joinable + if let GameJoinableState::Joinable = join_state { + debug!("Found matching game (GID: {})", id); + + // Add the player to the game + self.add_from_matchmaking(link.clone(), player).await; + + return Ok(()); } + } - TryAddResult::Failure(player) - })) + Err(player) } -} -/// Message for removing a game from the manager -#[derive(Message)] -pub struct RemoveGameMessage { - /// The ID of the game to remove - pub game_id: GameID, -} + pub async fn remove_game(&self, game_id: GameID) { + let games = &mut *self.games.write().await; + if let Some(mut game) = games.remove(&game_id) { + let mut attempt: u8 = 1; + + // Attempt to obtain the owned game + let game = loop { + if attempt > Self::MAX_RELEASE_ATTEMPTS { + let references = Arc::strong_count(&game); + warn!( + "Failed to stop game {} there are still {} references to it", + game_id, references + ); + return; + } -/// Handler for removing a game -impl Handler for GameManager { - type Response = (); + match Arc::try_unwrap(game) { + Ok(value) => break value, + Err(arc) => { + let wait = 5 * attempt as u64; + let references = Arc::strong_count(&arc); + debug!( + "Game {} still has {} references to it, waiting {}s", + game_id, references, wait + ); + tokio::time::sleep(Duration::from_secs(wait)).await; + game = arc; + attempt += 1; + continue; + } + } + }; - fn handle(&mut self, msg: RemoveGameMessage, _ctx: &mut ServiceContext) { - // Remove the game - if let Some(value) = self.games.remove(&msg.game_id) { - value.stop(); + let game = game.into_inner(); + game.stopped(); } } -} -/// Process the contents of the matchmaking queue against -/// a game link -#[derive(Message)] -pub struct ProcessQueueMessage { - pub link: Link, - pub game_id: GameID, -} - -impl Handler for GameManager { - type Response = Fr; - - fn handle( - &mut self, - msg: ProcessQueueMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - let queue_handle = self.queue.clone(); - - Fr::new(Box::pin(async move { - let mut queue = queue_handle.write().await; - let queue = queue.deref_mut(); - if queue.is_empty() { - return; - } + pub async fn process_queue(&self, link: GameRef, game_id: GameID) { + let queue = &mut *self.queue.write().await; + if queue.is_empty() { + return; + } - let link = msg.link; - - while let Some(entry) = queue.front() { - let join_state = match link - .send(CheckJoinableMessage { - rule_set: Some(entry.rule_set.clone()), - }) - .await - { - Ok(value) => value, - // Game is no longer available stop checking - Err(_) => break, - }; - - // TODO: If player has been in queue long enough create - // a game matching their specifics - - match join_state { - GameJoinableState::Joinable => { - let entry = queue - .pop_front() - .expect("Expecting matchmaking entry but nothing was present"); + while let Some(entry) = queue.front() { + let join_state = { + let link = &*link.read().await; + link.joinable_state(Some(&entry.rule_set)) + }; - debug!( - "Found player from queue adding them to the game (GID: {})", - msg.game_id - ); - let time = SystemTime::now(); - let elapsed = time.duration_since(entry.started); - if let Ok(elapsed) = elapsed { - debug!("Matchmaking time elapsed: {}s", elapsed.as_secs()) - } - - let msid = entry.player.player.id; - - // Send the async update (TODO: Do this at intervals) - entry.player.link.push(Packet::notify( - game_manager::COMPONENT, - game_manager::MATCHMAKING_ASYNC_STATUS, - AsyncMatchmakingStatus { player_id: msid }, - )); - - // Add the player to the game - if link - .do_send(AddPlayerMessage { - player: entry.player, - context: GameSetupContext::Matchmaking(msid), - }) - .is_err() - { - break; - } - } - GameJoinableState::Full => { - // If the game is not joinable push the entry back to the - // front of the queue and early return - break; - } - GameJoinableState::NotMatch => { - // TODO: Check started time and timeout - // player if they've been waiting too long + // TODO: If player has been in queue long enough create + // a game matching their specifics + + match join_state { + GameJoinableState::Joinable => { + let entry = queue + .pop_front() + .expect("Expecting matchmaking entry but nothing was present"); + + debug!( + "Found player from queue adding them to the game (GID: {})", + game_id + ); + let time = SystemTime::now(); + let elapsed = time.duration_since(entry.started); + if let Ok(elapsed) = elapsed { + debug!("Matchmaking time elapsed: {}s", elapsed.as_secs()) } + + // Add the player to the game + self.add_from_matchmaking(link.clone(), entry.player).await; + } + GameJoinableState::Full | GameJoinableState::Stopping => { + // If the game is not joinable push the entry back to the + // front of the queue and early return + break; + } + GameJoinableState::NotMatch => { + // TODO: Check started time and timeout + // player if they've been waiting too long } } - })) + } } } diff --git a/src/services/game/mod.rs b/src/services/game/mod.rs index ff515739..8475deb8 100644 --- a/src/services/game/mod.rs +++ b/src/services/game/mod.rs @@ -1,77 +1,49 @@ use self::{manager::GameManager, rules::RuleSet}; use crate::{ database::entities::Player, - services::game::manager::{ProcessQueueMessage, RemoveGameMessage}, session::{ - packet::Packet, router::RawBlaze, DetailsMessage, InformSessions, PushExt, Session, - SetGameMessage, + models::game_manager::{ + AdminListChange, AdminListOperation, AttributesChange, GameSettings, GameSetupContext, + GameSetupResponse, GameState, GetGameDetails, HostMigrateFinished, HostMigrateStart, + JoinComplete, PlayerJoining, PlayerRemoved, PlayerState, PlayerStateChange, + RemoveReason, SettingChange, StateChange, + }, + packet::Packet, + router::RawBlaze, + NetData, SessionLink, }, utils::{ - components::{game_manager, user_sessions}, - models::NetData, + components::game_manager, types::{GameID, PlayerID}, }, }; -use interlink::prelude::*; -use log::debug; -use models::*; +use log::{debug, warn}; use serde::Serialize; use std::sync::Arc; -use tdf::{ObjectId, TdfMap, TdfSerialize, TdfSerializer}; +use tdf::{ObjectId, TdfMap, TdfSerializer}; +use tokio::sync::RwLock; pub mod manager; -pub mod models; pub mod rules; +pub type GameRef = Arc>; + /// Game service running within the server pub struct Game { /// Unique ID for this game pub id: GameID, + /// The current game state pub state: GameState, /// The current game setting - pub setting: GameSettings, + pub settings: GameSettings, /// The game attributes pub attributes: AttrMap, - /// The list of players in this game - pub players: Vec, - /// Services access - pub game_manager: Link, -} - -impl Service for Game { - fn stopping(&mut self) { - debug!("Game is stopping (GID: {})", self.id); - // Remove the stopping game - let _ = self - .game_manager - .do_send(RemoveGameMessage { game_id: self.id }); - } -} -impl Game { - /// Starts a new game service with the provided initial state - /// - /// `id` The unique ID for the game - /// `attributes` The initial game attributes - /// `setting` The initial game setting value - pub fn start( - id: GameID, - attributes: AttrMap, - setting: GameSettings, - game_manager: Link, - ) -> Link { - let this = Game { - id, - state: GameState::Initializing, - setting, - attributes, - players: Vec::with_capacity(4), - game_manager, - }; + pub players: Vec, - this.start() - } + /// Services access + pub game_manager: Arc, } /// Snapshot of the current game state and players @@ -96,11 +68,11 @@ pub type AttrMap = TdfMap; /// within a game pub struct GamePlayer { /// Session player - pub player: Player, + pub player: Arc, /// Session address - pub link: Link, + pub link: SessionLink, /// Networking information for the player - pub net: NetData, + pub net: Arc, /// The mesh state of the player pub state: PlayerState, } @@ -114,7 +86,7 @@ pub struct GamePlayerSnapshot { /// The player name of the snapshot pub display_name: Box, /// The player net data of the snapshot if collected - pub net: Option, + pub net: Option>, } impl GamePlayer { @@ -124,7 +96,7 @@ impl GamePlayer { /// `player` The session player /// `net` The player networking details /// `addr` The session address - pub fn new(player: Player, net: NetData, link: Link) -> Self { + pub fn new(player: Arc, net: Arc, link: SessionLink) -> Self { Self { player, link, @@ -133,8 +105,9 @@ impl GamePlayer { } } - pub fn set_game(&self, game: Option) { - let _ = self.link.do_send(SetGameMessage { game }); + #[inline] + pub fn push(&self, packet: Packet) { + self.link.push(packet) } /// Takes a snapshot of the current player state @@ -152,48 +125,57 @@ impl GamePlayer { } pub fn encode(&self, game_id: GameID, slot: usize, w: &mut S) { - w.tag_blob_empty(b"BLOB"); - w.tag_u8(b"EXID", 0); - w.tag_owned(b"GID", game_id); - w.tag_u32(b"LOC", 0x64654445); - w.tag_str(b"NAME", &self.player.display_name); - w.tag_u32(b"PID", self.player.id); - w.tag_ref(b"PNET", &self.net.addr); - w.tag_owned(b"SID", slot); - w.tag_u8(b"SLOT", 0); - w.tag_ref(b"STAT", &self.state); - w.tag_u16(b"TIDX", 0xffff); - w.tag_u8(b"TIME", 0); /* Unix timestamp in millseconds */ - w.tag_alt(b"UGID", ObjectId::new_raw(0, 0, 0)); - w.tag_u32(b"UID", self.player.id); - w.tag_group_end(); + w.group_body(|w| { + w.tag_blob_empty(b"BLOB"); + w.tag_u8(b"EXID", 0); + w.tag_owned(b"GID", game_id); + w.tag_u32(b"LOC", 0x64654445); + w.tag_str(b"NAME", &self.player.display_name); + w.tag_u32(b"PID", self.player.id); + w.tag_ref(b"PNET", &self.net.addr); + w.tag_owned(b"SID", slot); + w.tag_u8(b"SLOT", 0); + w.tag_ref(b"STAT", &self.state); + w.tag_u16(b"TIDX", 0xffff); + w.tag_u8(b"TIME", 0); /* Unix timestamp in millseconds */ + w.tag_alt(b"UGID", ObjectId::new_raw(0, 0, 0)); + w.tag_u32(b"UID", self.player.id); + }); } } -impl Drop for GamePlayer { - fn drop(&mut self) { - // Clear player game when game player is dropped - self.set_game(None); - } +/// Different results for checking if a game is +/// joinable +pub enum GameJoinableState { + /// Game is currenlty joinable + Joinable, + /// Game is full + Full, + /// The game doesn't match the provided rules + NotMatch, + /// The game is stopping + Stopping, } -/// Message to add a new player to this game -#[derive(Message)] -pub struct AddPlayerMessage { - /// The player to add to the game - pub player: GamePlayer, - /// Context to which the player should be added - pub context: GameSetupContext, -} +impl Game { + /// Constant for the maximum number of players allowed in + /// a game at one time. Used to determine a games full state + const MAX_PLAYERS: usize = 4; -/// Handler for adding a player to the game -impl Handler for Game { - type Response = (); + pub async fn game_data(&self) -> RawBlaze { + let data = GetGameDetails { game: self }; + data.into() + } - fn handle(&mut self, msg: AddPlayerMessage, _ctx: &mut ServiceContext) { + pub fn add_player(&mut self, mut player: GamePlayer, context: GameSetupContext) { let slot = self.players.len(); - self.players.push(msg.player); + // Player is the host player (They are connected) + if slot == 0 { + player.state = PlayerState::ActiveConnected; + } + + self.players.push(player); // Obtain the player that was just added let player = self @@ -201,12 +183,10 @@ impl Handler for Game { .last() .expect("Player was added but is missing from players"); - // Whether the player was not the host player - let is_other = slot != 0; - - if is_other { + // Player isn't the host player + if slot != 0 { // Notify other players of the joined player - self.notify_all( + self.push_all(&Packet::notify( game_manager::COMPONENT, game_manager::PLAYER_JOINING, PlayerJoining { @@ -214,143 +194,25 @@ impl Handler for Game { player, game_id: self.id, }, - ); + )); // Update other players with the client details - self.update_clients(player); + self.add_user_sub(player.player.id, player.link.clone()); } // Notify the joiner of the game details - self.notify_game_setup(player, msg.context); - - // Set current game of this player - player.set_game(Some(self.id)); - - if is_other { - // Provide the new players session details to the other players - let links: Vec> = self - .players - .iter() - .map(|player| player.link.clone()) - .collect(); - let _ = player.link.do_send(InformSessions { links }); - } - } -} - -/// Message to alter the current game state -#[derive(Message)] -pub struct SetStateMessage { - /// The new game state - pub state: GameState, -} - -/// Handler for setting the game state -impl Handler for Game { - type Response = (); - - fn handle(&mut self, msg: SetStateMessage, _ctx: &mut ServiceContext) { - self.state = msg.state; - self.notify_state(); - } -} - -/// Message for setting the current game setting value -#[derive(Message)] -pub struct SetSettingMessage { - /// The new setting value - pub setting: GameSettings, -} - -/// Handler for setting the game setting -impl Handler for Game { - type Response = (); - - fn handle(&mut self, msg: SetSettingMessage, _ctx: &mut ServiceContext) { - let setting = msg.setting; - debug!("Updating game setting (Value: {:?})", &setting); - self.setting = setting; - self.notify_all( - game_manager::COMPONENT, - game_manager::GAME_SETTINGS_CHANGE, - SettingChange { - id: self.id, - setting, - }, - ); - } -} - -/// Message for setting the game attributes -#[derive(Message)] -pub struct SetAttributesMessage { - /// The new attributes - pub attributes: AttrMap, -} - -/// Handler for setting the game attributes -impl Handler for Game { - type Response = (); - - fn handle(&mut self, msg: SetAttributesMessage, ctx: &mut ServiceContext) { - let attributes = msg.attributes; - - debug!("Updating game attributes"); - let packet = Packet::notify( - game_manager::COMPONENT, - game_manager::GAME_ATTRIB_CHANGE, - AttributesChange { - id: self.id, - attributes: &attributes, - }, - ); - - self.attributes.insert_presorted(attributes.into_inner()); - self.push_all(&packet); - - // Don't update matchmaking for full games - if self.players.len() < Self::MAX_PLAYERS { - let _ = self.game_manager.do_send(ProcessQueueMessage { - link: ctx.link(), - game_id: self.id, - }); - } + self.notify_game_setup(player, context); } -} - -/// Message to update the mesh connection state between -/// clients -#[derive(Message)] -pub struct UpdateMeshMessage { - /// The ID of the session updating its connection - pub id: PlayerID, - /// The target player that its updating with - pub target: PlayerID, - /// The player mesh state - pub state: PlayerState, -} - -/// Handler for updating mesh connections -impl Handler for Game { - type Response = (); - fn handle(&mut self, msg: UpdateMeshMessage, _ctx: &mut ServiceContext) { - let state = msg.state; + pub fn update_mesh(&mut self, id: PlayerID, target: PlayerID, state: PlayerState) { if let PlayerState::ActiveConnecting = state { // Ensure the target player is in the game - if !self - .players - .iter() - .any(|value| value.player.id == msg.target) - { + if !self.players.iter().any(|value| value.player.id == target) { return; } // Find the index of the session player - let session = self - .players - .iter_mut() - .find(|value| value.player.id == msg.id); + let session = self.players.iter_mut().find(|value| value.player.id == id); let session = match session { Some(value) => value, @@ -367,55 +229,39 @@ impl Handler for Game { state: session.state, }; + // TODO: Move into a "connection complete" function + // Notify players of the player state change - self.notify_all( + self.push_all(&Packet::notify( game_manager::COMPONENT, game_manager::GAME_PLAYER_STATE_CHANGE, state_change, - ); + )); // Notify players of the completed connection - self.notify_all( + self.push_all(&Packet::notify( game_manager::COMPONENT, game_manager::PLAYER_JOIN_COMPLETED, JoinComplete { game_id: self.id, player_id, }, - ); + )); // Add the player to the admin list self.modify_admin_list(player_id, AdminListOperation::Add); } } -} - -/// Message for removing a player from the game -#[derive(Message)] -#[msg(rtype = "()")] -pub struct RemovePlayerMessage { - /// The ID of the player/session to remove - pub id: u32, - /// The reason for removing the player - pub reason: RemoveReason, -} -/// Handler for removing a player from the game -impl Handler for Game { - type Response = (); - fn handle( - &mut self, - msg: RemovePlayerMessage, - ctx: &mut ServiceContext, - ) -> Self::Response { + pub fn remove_player(&mut self, id: u32, reason: RemoveReason) { // Already empty game handling if self.players.is_empty() { - ctx.stop(); + self.stop(); return; } // Find the player index - let index = self.players.iter().position(|v| v.player.id == msg.id); + let index = self.players.iter().position(|v| v.player.id == id); let index = match index { Some(value) => value, @@ -425,12 +271,15 @@ impl Handler for Game { // Remove the player let player = self.players.remove(index); - // Set current game of this player - player.set_game(None); + // Clear current game of this player + let clear_link = player.link.clone(); + tokio::spawn(async move { + let _ = clear_link.clear_game().await; + }); // Update the other players - self.notify_player_removed(&player, msg.reason); - self.notify_fetch_data(&player); + self.notify_player_removed(&player, reason); + self.rem_user_sub(player.player.id, player.link.clone()); self.modify_admin_list(player.player.id, AdminListOperation::Remove); debug!( @@ -445,107 +294,83 @@ impl Handler for Game { if self.players.is_empty() { // Game is empty stop it - ctx.stop(); + self.stop(); } } -} -/// Handler for checking if a game is joinable -#[derive(Message, Clone)] -#[msg(rtype = "GameJoinableState")] -pub struct CheckJoinableMessage { - /// The player rule set if one is provided - pub rule_set: Option>, -} + pub fn new( + id: GameID, + attributes: AttrMap, + settings: GameSettings, + game_manager: Arc, + ) -> Game { + Game { + id, + attributes, + settings, + state: Default::default(), + players: Default::default(), + game_manager, + } + } -/// Different results for checking if a game is -/// joinable -pub enum GameJoinableState { - /// Game is currenlty joinable - Joinable, - /// Game is full - Full, - /// The game doesn't match the provided rules - NotMatch, -} + /// Called by the game manager service once this game has been stopped and + /// removed from the game list + fn stopped(self) { + debug!("Game is stopped (GID: {})", self.id); + } + + fn stop(&mut self) { + // Mark the game as stopping + self.state = GameState::Destructing; + + if !self.players.is_empty() { + warn!("Game {} was stopped with players still present", self.id); + } + + // Remove the stopping game + let game_manager = self.game_manager.clone(); + let game_id = self.id; + tokio::spawn(async move { + game_manager.remove_game(game_id).await; + }); + } -/// Handler for checking if a game is joinable -impl Handler for Game { - type Response = Mr; + pub fn joinable_state(&self, rule_set: Option<&RuleSet>) -> GameJoinableState { + if let GameState::Destructing = self.state { + return GameJoinableState::Stopping; + } - fn handle( - &mut self, - msg: CheckJoinableMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { // Handle full game if self.players.len() >= Self::MAX_PLAYERS { - return Mr(GameJoinableState::Full); + return GameJoinableState::Full; } // Check ruleset matches - if let Some(rule_set) = msg.rule_set { + if let Some(rule_set) = rule_set { if !rule_set.matches(&self.attributes) { - return Mr(GameJoinableState::NotMatch); + return GameJoinableState::NotMatch; } } - Mr(GameJoinableState::Joinable) + GameJoinableState::Joinable } -} - -/// Message to take a snapshot of the game and its state -#[derive(Message)] -#[msg(rtype = "GameSnapshot")] -pub struct SnapshotMessage { - /// Whether to include the networking details in the snapshot - pub include_net: bool, -} - -/// Handler for taking snapshots of the game and its state -impl Handler for Game { - type Response = Mr; - fn handle(&mut self, msg: SnapshotMessage, _ctx: &mut ServiceContext) -> Self::Response { + pub fn snapshot(&self, include_net: bool) -> GameSnapshot { let players = self .players .iter() - .map(|value| value.snapshot(msg.include_net)) + .map(|value| value.snapshot(include_net)) .collect(); - Mr(GameSnapshot { + + GameSnapshot { id: self.id, state: self.state, - setting: self.setting.bits(), + setting: self.settings.bits(), attributes: self.attributes.clone(), players, - }) - } -} - -/// Message for getting an encoded packet body of the game data -#[derive(Message)] -#[msg(rtype = "RawBlaze")] -pub struct GetGameDataMessage; - -/// Handler for getting an encoded packet body of the game data -impl Handler for Game { - type Response = Mr; - - fn handle( - &mut self, - _msg: GetGameDataMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - let data = GetGameDetails { game: self }; - let data: RawBlaze = data.into(); - Mr(data) + } } -} - -impl Game { - /// Constant for the maximum number of players allowed in - /// a game at one time. Used to determine a games full state - const MAX_PLAYERS: usize = 4; /// Writes the provided packet to all connected sessions. /// Does not wait for the write to complete just waits for @@ -555,48 +380,97 @@ impl Game { fn push_all(&self, packet: &Packet) { self.players .iter() - .for_each(|value| value.link.push(packet.clone())); + .for_each(|value| value.push(packet.clone())); } - /// Sends a notification packet to all the connected session - /// with the provided component and contents - /// - /// `component` The packet component - /// `contents` The packet contents - fn notify_all(&self, component: u16, command: u16, contents: C) { - let packet = Packet::notify(component, command, contents); - self.push_all(&packet); - } + pub fn set_state(&mut self, state: GameState) { + self.state = state; + + debug!("Updated game state (Value: {:?})", &state); - /// Notifies all players of the current game state - fn notify_state(&self) { - self.notify_all( + self.push_all(&Packet::notify( game_manager::COMPONENT, game_manager::GAME_STATE_CHANGE, - StateChange { + StateChange { id: self.id, state }, + )); + } + + pub fn set_settings(&mut self, settings: GameSettings) { + self.settings = settings; + + debug!("Updated game setting (Value: {:?})", &settings); + + self.push_all(&Packet::notify( + game_manager::COMPONENT, + game_manager::GAME_SETTINGS_CHANGE, + SettingChange { + id: self.id, + settings, + }, + )); + } + + pub fn set_attributes(&mut self, attributes: AttrMap) { + let packet = Packet::notify( + game_manager::COMPONENT, + game_manager::GAME_ATTRIB_CHANGE, + AttributesChange { id: self.id, - state: self.state, + attributes: &attributes, }, ); + + self.attributes.insert_presorted(attributes.into_inner()); + + debug!("Updated game attributes"); + + self.push_all(&packet); } - /// Updates all the client details for the provided session. - /// Tells each client to send session updates to the session - /// and the session to send them as well. - /// - /// `session` The session to update for - fn update_clients(&self, player: &GamePlayer) { - debug!("Updating clients with new session details"); - self.players.iter().for_each(|value| { - if value.player.id != player.player.id { - let addr1 = player.link.clone(); - let addr2 = value.link.clone(); - - // Queue the session details to be sent to this client - let _ = player.link.do_send(DetailsMessage { link: addr2 }); - let _ = value.link.do_send(DetailsMessage { link: addr1 }); - } - }); + /// Creates a subscription between all the users and the the target player + fn add_user_sub(&self, target_id: PlayerID, target_link: SessionLink) { + debug!("Adding user subscriptions"); + + // Subscribe all the clients to eachother + self.players + .iter() + .filter(|other| other.player.id.ne(&target_id)) + .for_each(|other| { + let other_id = other.player.id; + let other_link = other.link.clone(); + let target_link = target_link.clone(); + + tokio::spawn(async move { + target_link + .add_subscriber(other_id, other_link.clone()) + .await; + other_link + .add_subscriber(target_id, target_link.clone()) + .await; + }); + }); + } + + /// Notifies the provided player and all other players + /// in the game that they should remove eachother from + /// their player data list + fn rem_user_sub(&self, target_id: PlayerID, target_link: SessionLink) { + debug!("Removing user subscriptions"); + + // Unsubscribe all the clients from eachother + self.players + .iter() + .filter(|other| other.player.id.ne(&target_id)) + .for_each(|other| { + let other_id = other.player.id; + let other_link = other.link.clone(); + let target_link = target_link.clone(); + + tokio::spawn(async move { + target_link.remove_subscriber(other_id).await; + other_link.remove_subscriber(target_id).await; + }); + }); } /// Notifies the provided player that the game has been setup and @@ -608,12 +482,12 @@ impl Game { let packet = Packet::notify( game_manager::COMPONENT, game_manager::GAME_SETUP, - GameDetails { + GameSetupResponse { game: self, context, }, ); - player.link.push(packet); + player.push(packet); } /// Modifies the psudo admin list this list doesn't actually exist in @@ -628,7 +502,7 @@ impl Game { None => return, }; - self.notify_all( + self.push_all(&Packet::notify( game_manager::COMPONENT, game_manager::ADMIN_LIST_CHANGE, AdminListChange { @@ -637,7 +511,7 @@ impl Game { operation, host_id: host.player.id, }, - ); + )); } /// Notifies all the session and the removed session that a @@ -657,35 +531,7 @@ impl Game { }, ); self.push_all(&packet); - player.link.push(packet); - } - - /// Notifies all the sessions in the game to fetch the player data - /// for the provided session and the session to fetch the extended - /// data for all the other sessions. Will early return if there - /// are no players left. - /// - /// `session` The session to update with the other clients - /// `player_id` The player id of the session to update - fn notify_fetch_data(&self, player: &GamePlayer) { - self.notify_all( - user_sessions::COMPONENT, - user_sessions::FETCH_EXTENDED_DATA, - FetchExtendedData { - player_id: player.player.id, - }, - ); - - for other_player in &self.players { - let packet = Packet::notify( - user_sessions::COMPONENT, - user_sessions::FETCH_EXTENDED_DATA, - FetchExtendedData { - player_id: other_player.player.id, - }, - ); - player.link.push(packet) - } + player.push(packet); } /// Attempts to migrate the host of this game if there are still players @@ -693,8 +539,8 @@ impl Game { fn try_migrate_host(&mut self) { // TODO: With more than one player this fails - // Obtain the new player at the first index - let new_host = match self.players.first() { + // Obtain the new host player + let host_id = match self.players.first().map(|player| player.player.id) { Some(value) => value, None => return, }; @@ -702,29 +548,25 @@ impl Game { debug!("Starting host migration (GID: {})", self.id); // Start host migration - self.state = GameState::Migrating; - self.notify_state(); - self.notify_all( + self.set_state(GameState::Migrating); + self.push_all(&Packet::notify( game_manager::COMPONENT, game_manager::HOST_MIGRATION_START, HostMigrateStart { game_id: self.id, - host_id: new_host.player.id, + host_id, pmig: 2, slot: 0, }, - ); + )); // Finished host migration - self.state = GameState::InGame; - self.notify_state(); - self.notify_all( + self.set_state(GameState::InGame); + self.push_all(&Packet::notify( game_manager::COMPONENT, game_manager::HOST_MIGRATION_FINISHED, HostMigrateFinished { game_id: self.id }, - ); - - self.update_clients(new_host); + )); debug!("Finished host migration (GID: {})", self.id); } diff --git a/src/services/game/models.rs b/src/services/game/models.rs deleted file mode 100644 index 62b53165..00000000 --- a/src/services/game/models.rs +++ /dev/null @@ -1,629 +0,0 @@ -use super::{AttrMap, Game, GamePlayer}; -use crate::utils::{ - models::NetworkAddress, - types::{GameID, GameSlot, PlayerID}, -}; -use bitflags::bitflags; - -use serde::Serialize; -use tdf::{ - TdfDeserialize, TdfDeserializeOwned, TdfSerialize, TdfSerializeOwned, TdfType, TdfTyped, -}; - -/// Different states the game can be in -#[derive( - Default, Debug, Serialize, Clone, Copy, PartialEq, Eq, TdfSerialize, TdfDeserialize, TdfTyped, -)] -#[repr(u8)] -pub enum GameState { - NewState = 0x0, - #[tdf(default)] - #[default] - Initializing = 0x1, - Virtual = 0x2, - PreGame = 0x82, - InGame = 0x83, - PostGame = 0x4, - Migrating = 0x5, - Destructing = 0x6, - Resetable = 0x7, - ReplaySetup = 0x8, -} - -bitflags! { - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] - pub struct GameSettings: u16 { - const NONE = 0; - const OPEN_TO_BROWSING = 1; - const OPEN_TO_MATCHMAKING = 2; - const OPEN_TO_INVITES = 4; - const OPEN_TO_JOIN_BY_PLAYER = 8; - const HOST_MIGRATABLE = 0x10; - const RANKED = 0x20; - const ADMIN_ONLY_INVITES = 0x40; - const ENFORCE_SINGLE_GROUP_JOIN = 0x80; - const JOIN_IN_PROGRESS_SUPPORTED = 0x100; - const ADMIN_INVITE_ONLY_IGNORE_ENTRY_CHECKS = 0x200; - const IGNORE_ENTRY_CRITERIA_WITH_INVITE = 0x400; - const ENABLE_PERSISTED_GAME_ID = 0x800; - const ALLOW_SAME_TEAM_ID = 0x1000; - const VIRTUALIZED = 0x2000; - const SEND_ORPHANDED_GAME_REPORT_EVENT = 0x4000; - const ALLOW_ANY_REPUTATION = 0x8000; - } -} - -impl From for u16 { - fn from(value: GameSettings) -> Self { - value.bits() - } -} - -impl TdfSerialize for GameSettings { - fn serialize(&self, w: &mut S) { - ::serialize_owned(self.bits(), w) - } -} - -impl TdfDeserializeOwned for GameSettings { - fn deserialize_owned(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult { - let value = u16::deserialize_owned(r)?; - Ok(GameSettings::from_bits_retain(value)) - } -} - -impl TdfTyped for GameSettings { - const TYPE: TdfType = TdfType::VarInt; -} - -#[derive( - Default, Debug, Serialize, Clone, Copy, PartialEq, Eq, TdfDeserialize, TdfSerialize, TdfTyped, -)] -#[repr(u8)] -pub enum PlayerState { - /// Link between the mesh points is not connected - #[default] - #[tdf(default)] - Reserved = 0x0, - Queued = 0x1, - /// Link is being formed between two mesh points - ActiveConnecting = 0x2, - ActiveMigrating = 0x3, - /// Link is connected between two mesh points - ActiveConnected = 0x4, - ActiveKickPending = 0x5, -} - -/// Message for a game state changing -#[derive(TdfSerialize)] -pub struct StateChange { - /// The ID of the game - #[tdf(tag = "GID")] - pub id: GameID, - /// The game state - #[tdf(tag = "GSTA")] - pub state: GameState, -} - -/// Message for a game setting changing -#[derive(TdfSerialize)] -pub struct SettingChange { - /// The game setting - #[tdf(tag = "ATTR", into = u16)] - pub setting: GameSettings, - /// The ID of the game - #[tdf(tag = "GID")] - pub id: GameID, -} - -/// Packet for game attribute changes -pub struct AttributesChange<'a> { - /// Borrowed game attributes map - pub attributes: &'a AttrMap, - /// The id of the game the attributes have changed for - pub id: GameID, -} - -impl TdfSerialize for AttributesChange<'_> { - fn serialize(&self, w: &mut S) { - w.tag_ref(b"ATTR", self.attributes); - w.tag_owned(b"GID", self.id) - } -} - -/// Message for a player joining notification -pub struct PlayerJoining<'a> { - /// The ID of the game - pub game_id: GameID, - /// The slot the player is joining into - pub slot: GameSlot, - /// The player that is joining - pub player: &'a GamePlayer, -} - -impl TdfSerialize for PlayerJoining<'_> { - fn serialize(&self, w: &mut S) { - w.tag_u32(b"GID", self.game_id); - - w.tag_group(b"PDAT"); - self.player.encode(self.game_id, self.slot, w); - } -} - -const VSTR: &str = "ME3-295976325-179181965240128"; - -pub enum GameSetupContext { - /// Context without additional data - Dataless(DatalessContext), - /// Context added from matchmaking - Matchmaking(u32), -} - -#[derive(Debug, Copy, Clone)] -#[repr(u8)] -pub enum DatalessContext { - /// Session created the game - CreateGameSetup = 0x0, - /// Session joined by ID - JoinGameSetup = 0x1, - // IndirectJoinGameFromQueueSetup = 0x2, - // IndirectJoinGameFromReservationContext = 0x3, - // HostInjectionSetupContext = 0x4, -} - -pub struct GameDetails<'a> { - pub game: &'a Game, - pub context: GameSetupContext, -} - -impl TdfSerialize for GameDetails<'_> { - fn serialize(&self, w: &mut S) { - let game = self.game; - let host_player = match game.players.first() { - Some(value) => value, - None => return, - }; - - // Game details - w.group(b"GAME", |w| { - w.tag_list_iter_owned(b"ADMN", game.players.iter().map(|player| player.player.id)); - w.tag_ref(b"ATTR", &game.attributes); - - w.tag_list_slice(b"CAP", &[4u8, 0u8]); - - w.tag_u32(b"GID", game.id); - w.tag_str(b"GNAM", &host_player.player.display_name); - - w.tag_u64(b"GPVH", 0x5a4f2b378b715c6); - w.tag_u16(b"GSET", game.setting.bits()); - w.tag_u64(b"GSID", 0x4000000a76b645); - w.tag_ref(b"GSTA", &game.state); - - w.tag_str_empty(b"GTYP"); - { - w.tag_list_start(b"HNET", TdfType::Group, 1); - w.write_byte(2); - if let NetworkAddress::AddressPair(pair) = &host_player.net.addr { - TdfSerialize::serialize(pair, w) - } - } - - w.tag_u32(b"HSES", host_player.player.id); - w.tag_zero(b"IGNO"); - w.tag_u8(b"MCAP", 4); - w.tag_ref(b"NQOS", &host_player.net.qos); - w.tag_zero(b"NRES"); - w.tag_zero(b"NTOP"); - w.tag_str_empty(b"PGID"); - w.tag_blob_empty(b"PGSR"); - - w.group(b"PHST", |w| { - w.tag_u32(b"HPID", host_player.player.id); - w.tag_zero(b"HSLT"); - }); - - w.tag_u8(b"PRES", 0x1); - w.tag_str_empty(b"PSAS"); - w.tag_u8(b"QCAP", 0x0); - w.tag_u32(b"SEED", 0x4cbc8585); - w.tag_u8(b"TCAP", 0x0); - - w.group(b"THST", |w| { - w.tag_u32(b"HPID", host_player.player.id); - w.tag_u8(b"HSLT", 0x0); - }); - - w.tag_str(b"UUID", "286a2373-3e6e-46b9-8294-3ef05e479503"); - w.tag_u8(b"VOIP", 0x2); - w.tag_str(b"VSTR", VSTR); - w.tag_blob_empty(b"XNNC"); - w.tag_blob_empty(b"XSES"); - }); - - // Player list - w.tag_list_start(b"PROS", TdfType::Group, game.players.len()); - for (slot, player) in game.players.iter().enumerate() { - player.encode(game.id, slot, w); - } - - match &self.context { - GameSetupContext::Dataless(context) => { - w.tag_union_start(b"REAS", 0x0); - w.group(b"VALU", |writer| { - writer.tag_u8(b"DCTX", (*context) as u8); - }); - } - GameSetupContext::Matchmaking(id) => { - w.tag_union_start(b"REAS", 0x3); - w.group(b"VALU", |writer| { - const FIT: u16 = 21600; - - writer.tag_u16(b"FIT", FIT); - writer.tag_u16(b"MAXF", FIT); - writer.tag_u32(b"MSID", *id); - // TODO: Matchmaking result - // SUCCESS_CREATED_GAME = 0 - // SUCCESS_JOINED_NEW_GAME = 1 - // SUCCESS_JOINED_EXISTING_GAME = 2 - // SESSION_TIMED_OUT = 3 - // SESSION_CANCELED = 4 - // SESSION_TERMINATED = 5 - // SESSION_ERROR_GAME_SETUP_FAILED = 6 - writer.tag_u8(b"RSLT", 0x2); - writer.tag_u32(b"USID", *id); - }); - } - } - } -} - -pub struct GetGameDetails<'a> { - pub game: &'a Game, -} - -impl TdfSerialize for GetGameDetails<'_> { - fn serialize(&self, w: &mut S) { - let game = self.game; - let host_player = match game.players.first() { - Some(value) => value, - None => return, - }; - - w.tag_list_start(b"GDAT", TdfType::Group, 1); - w.group_body(|w| { - w.tag_list_iter_owned(b"ADMN", game.players.iter().map(|player| player.player.id)); - w.tag_ref(b"ATTR", &game.attributes); - w.tag_list_slice(b"CAP", &[4u8, 0u8]); - - w.tag_u32(b"GID", game.id); - w.tag_str(b"GNAM", &host_player.player.display_name); - w.tag_u16(b"GSET", game.setting.bits()); - w.tag_ref(b"GSTA", &game.state); - { - w.tag_list_start(b"HNET", TdfType::Group, 1); - w.write_byte(2); - if let NetworkAddress::AddressPair(pair) = &host_player.net.addr { - TdfSerialize::serialize(pair, w) - } - } - w.tag_u32(b"HOST", host_player.player.id); - w.tag_zero(b"NTOP"); - - w.tag_list_slice(b"PCNT", &[1u8, 0u8]); - - w.tag_u8(b"PRES", 0x2); - w.tag_str(b"PSAS", "ea-sjc"); - w.tag_str_empty(b"PSID"); - w.tag_zero(b"QCAP"); - w.tag_zero(b"QCNT"); - w.tag_zero(b"SID"); - w.tag_zero(b"TCAP"); - w.tag_u8(b"VOIP", 0x2); - w.tag_str(b"VSTR", VSTR); - }); - } -} - -#[derive(TdfSerialize)] -pub struct PlayerStateChange { - #[tdf(tag = "GID")] - pub gid: GameID, - #[tdf(tag = "PID")] - pub pid: PlayerID, - #[tdf(tag = "STAT")] - pub state: PlayerState, -} - -#[derive(TdfSerialize)] -pub struct JoinComplete { - #[tdf(tag = "GID")] - pub game_id: GameID, - #[tdf(tag = "PID")] - pub player_id: PlayerID, -} - -#[derive(TdfSerialize)] -pub struct AdminListChange { - #[tdf(tag = "ALST")] - pub player_id: PlayerID, - #[tdf(tag = "GID")] - pub game_id: GameID, - #[tdf(tag = "OPER")] - pub operation: AdminListOperation, - #[tdf(tag = "UID")] - pub host_id: PlayerID, -} - -/// Different operations that can be performed on -/// the admin list -#[derive(Debug, Clone, Copy, TdfSerialize, TdfTyped)] -#[repr(u8)] -pub enum AdminListOperation { - Add = 0, - Remove = 1, -} - -#[derive(TdfSerialize)] -pub struct PlayerRemoved { - #[tdf(tag = "CNTX")] - pub cntx: u8, - #[tdf(tag = "GID")] - pub game_id: GameID, - #[tdf(tag = "PID")] - pub player_id: PlayerID, - #[tdf(tag = "REAS")] - pub reason: RemoveReason, -} - -#[derive(Default, Debug, Clone, Copy, TdfSerialize, TdfDeserialize, TdfTyped)] -#[repr(u8)] -pub enum RemoveReason { - /// Hit timeout while joining - JoinTimeout = 0x0, - /// Player lost PTP conneciton - PlayerConnectionLost = 0x1, - /// Player lost connection with the Pocket Relay server - ServerConnectionLost = 0x2, - /// Game migration failed - MigrationFailed = 0x3, - GameDestroyed = 0x4, - GameEnded = 0x5, - /// Generic player left the game reason - #[tdf(default)] - #[default] - PlayerLeft = 0x6, - GroupLeft = 0x7, - /// Player kicked - PlayerKicked = 0x8, - /// Player kicked and banned - PlayerKickedWithBan = 0x9, - /// Failed to join from the queue - PlayerJoinFromQueueFailed = 0xA, - PlayerReservationTimeout = 0xB, - HostEjected = 0xC, -} - -#[derive(TdfSerialize)] -pub struct FetchExtendedData { - #[tdf(tag = "BUID")] - pub player_id: PlayerID, -} - -#[derive(TdfSerialize)] -pub struct HostMigrateStart { - #[tdf(tag = "GID")] - pub game_id: GameID, - #[tdf(tag = "HOST")] - pub host_id: PlayerID, - #[tdf(tag = "PMIG")] - pub pmig: u32, - #[tdf(tag = "SLOT")] - pub slot: u8, -} - -#[derive(TdfSerialize)] -pub struct HostMigrateFinished { - #[tdf(tag = "GID")] - pub game_id: GameID, -} - -/// -/// # Example -/// ``` -/// Content: { -/// "ASIL": List [ -/// { -/// "CGS": { -/// "EVST": 6, -/// "MMSN": 1, -/// "NOMP": 0, -/// }, -/// "CUST": { -/// }, -/// "DNFS": { -/// "MDNF": 0, -/// "XDNF": 0, -/// }, -/// "FGS": { -/// "GNUM": 0, -/// }, -/// "GEOS": { -/// "DIST": 0, -/// }, -/// "GRDA": Map { -/// "ME3_gameDifficultyRule": { -/// "NAME": "ME3_gameDifficultyRule", -/// "VALU": List ["difficulty0"], -/// } -/// "ME3_gameEnemyTypeRule": { -/// "NAME": "ME3_gameEnemyTypeRule", -/// "VALU": List ["enemy0", "enemy1", "enemy2", "enemy3", "enemy4", "enemy5", "enemy6", "enemy7", "enemy8", "enemy9", "random", "abstain"], -/// } -/// "ME3_gameMapMatchRule": { -/// "NAME": "ME3_gameMapMatchRule", -/// "VALU": List ["map0", "map1", "map2", "map3", "map4", "map5", "map6", "map7", "map8", "map9", "map10", "map11", "map12", "map13", "map14", "map15", "map16", "map17", "map18", "map19", "map20", "map21", "map22", "map23", "map24", "map25", "map26", "map27", "map28", "map29", "random", "abstain"], -/// } -/// "ME3_gameStateMatchRule": { -/// "NAME": "ME3_gameStateMatchRule", -/// "VALU": List ["IN_LOBBY", "IN_LOBBY_LONGTIME", "IN_GAME_STARTING", "abstain"], -/// } -/// "ME3_rule_dlc2300": { -/// "NAME": "ME3_rule_dlc2300", -/// "VALU": List ["required", "preferred"], -/// } -/// "ME3_rule_dlc2500": { -/// "NAME": "ME3_rule_dlc2500", -/// "VALU": List ["required", "preferred"], -/// } -/// "ME3_rule_dlc2700": { -/// "NAME": "ME3_rule_dlc2700", -/// "VALU": List ["required", "preferred"], -/// } -/// "ME3_rule_dlc3050": { -/// "NAME": "ME3_rule_dlc3050", -/// "VALU": List ["required", "preferred"], -/// } -/// "ME3_rule_dlc3225": { -/// "NAME": "ME3_rule_dlc3225", -/// "VALU": List ["required", "preferred"], -/// } -/// }, -/// "GSRD": { -/// "PMAX": 4, -/// "PMIN": 2, -/// }, -/// "HBRD": { -/// "BVAL": 1, -/// }, -/// "HVRD": { -/// "VVAL": 0, -/// }, -/// "PSRS": { -/// }, -/// "RRDA": { -/// "RVAL": 0, -/// }, -/// "TSRS": { -/// "TMAX": 0, -/// "TMIN": 0, -/// }, -/// "UEDS": Map { -/// "ME3_characterSkill_Rule": { -/// "AMAX": 500, -/// "AMIN": 0, -/// "MUED": 500, -/// "NAME": "ME3_characterSkill_Rule", -/// } -/// }, -/// "VGRS": { -/// "VVAL": 0, -/// }, -/// } -/// ], -/// "MSID": 0x1, -/// "USID": 0x1, -///} -/// ``` -pub struct AsyncMatchmakingStatus { - pub player_id: PlayerID, -} - -impl TdfSerialize for AsyncMatchmakingStatus { - fn serialize(&self, w: &mut S) { - w.tag_list_start(b"ASIL", TdfType::Group, 1); - w.group_body(|w| { - // Create game status - w.group(b"CGS", |w| { - // Evaluate status - // PlayerCountSufficient = 1, - // AcceptableHostFound = 2, - // TeamSizesSufficient = 4 - w.tag_u8(b"EVST", 2 | 4); - // Number of matchmaking sessions - w.tag_u8(b"MMSN", 1); - // Number of matched players - w.tag_u8(b"NOMP", 0); - }); - - // Custom async status - w.tag_group_empty(b"CUST"); - - // DNF rule status - w.group(b"DNFS", |w| { - // My DNF value - w.tag_zero(b"MDNF"); - // Max DNF value - w.tag_zero(b"XDNF"); - }); - - // Find game status - w.group(b"FGS", |w| { - // Number of games - w.tag_zero(b"GNUM"); - }); - - // Geo location rule status - w.group(b"GEOS", |w| { - // Max distance - w.tag_zero(b"DIST"); - }); - - // Generic rule status dictionary (TODO: RULES HERE) - w.tag_map_start(b"GRDA", TdfType::String, TdfType::Group, 0); - - // Game size rule status - w.group(b"GSRD", |w| { - // Max player count accepted - w.tag_u8(b"PMAX", 4); - // Min player count accepted - w.tag_u8(b"PMIN", 2); - }); - - // Host balance rule status - w.group(b"HBRD", |w| { - // Host balance values - // HOSTS_STRICTLY_BALANCED = 0, - // HOSTS_BALANCED = 1, - // HOSTS_UNBALANCED = 2, - - w.tag_u8(b"BVAL", 1); - }); - - // Host viability rule status - w.group(b"HVRD", |w| { - // Host viability values - // CONNECTION_ASSURED = 0, - // CONNECTION_LIKELY = 1, - // CONNECTION_FEASIBLE = 2, - // CONNECTION_UNLIKELY = 3, - - w.tag_zero(b"VVAL"); - }); - - // Ping site rule status - w.group(b"PSRS", |_| {}); - - // Rank rule status - w.group(b"RRDA", |w| { - // Matched rank flags - w.tag_zero(b"RVAL"); - }); - - // Team size rule status - w.group(b"TSRS", |w| { - // Max team size accepted - w.tag_zero(b"TMAX"); - // Min team size accepted - w.tag_zero(b"TMIN"); - }); - - // UED rule status - w.tag_map_empty(b"GRDA", TdfType::String, TdfType::Group); - // Virtual game rule status - w.group(b"VGRS", |w| w.tag_zero(b"VVAL")); - }); - - w.tag_owned(b"MSID", self.player_id); - w.tag_owned(b"USID", self.player_id); - } -} diff --git a/src/services/leaderboard/mod.rs b/src/services/leaderboard/mod.rs index f6a97d21..44ee43a1 100644 --- a/src/services/leaderboard/mod.rs +++ b/src/services/leaderboard/mod.rs @@ -5,20 +5,22 @@ use crate::{ entities::{Player, PlayerData}, DatabaseConnection, DbResult, }, - utils::parsing::{KitNameDeployed, PlayerClass}, + utils::{ + parsing::{KitNameDeployed, PlayerClass}, + types::PlayerID, + }, }; -use interlink::prelude::*; +use futures_util::future::BoxFuture; use log::{debug, error}; use sea_orm::{EntityTrait, PaginatorTrait, QueryOrder}; use std::{collections::HashMap, sync::Arc, time::Instant}; -use tokio::task::JoinSet; +use tokio::{sync::RwLock, task::JoinSet}; pub mod models; -#[derive(Service)] pub struct Leaderboard { /// Map between the group types and the actual leaderboard group content - groups: HashMap, + groups: RwLock>, } /// Extra state wrapper around a leaderboard group which @@ -31,89 +33,60 @@ struct GroupState { group: Arc, } -/// Message for requesting access to a leaderborad -/// of the specific leaderboard type -#[derive(Message)] -#[msg(rtype = "Arc")] -pub struct QueryMessage(pub LeaderboardType, pub DatabaseConnection); - -impl Handler for Leaderboard { - type Response = Fr; - - fn handle(&mut self, msg: QueryMessage, ctx: &mut ServiceContext) -> Self::Response { - let ty = msg.0; +impl Leaderboard { + /// Starts a new leaderboard service + pub fn new() -> Leaderboard { + Leaderboard { + groups: Default::default(), + } + } - // If the group already exists and is not expired we can respond with it - if let Some(group) = self.groups.get_mut(&ty) { - let inner = &group.group; + pub async fn query( + &self, + ty: LeaderboardType, + db: &DatabaseConnection, + ) -> Arc { + { + let groups = &mut *self.groups.write().await; + // If the group already exists and is not expired we can respond with it + if let Some(group) = groups.get_mut(&ty) { + let inner = &group.group; + + // Response with current values if the group isn't expired or is computing + if group.computing || !inner.is_expired() { + // Value is not expired respond immediately + return inner.clone(); + } - // Response with current values if the group isn't expired or is computing - if group.computing || !inner.is_expired() { - // Value is not expired respond immediately - return Fr::ready(inner.clone()); + // Mark the group as currently being computed + group.computing = true; + } else { + // Create dummy empty group to hand out while computing + let dummy = GroupState { + computing: true, + group: Arc::new(LeaderboardGroup::dummy()), + }; + groups.insert(ty, dummy); } - - // Mark the group as currently being computed - group.computing = true; - } else { - // Create dummy empty group to hand out while computing - let dummy = GroupState { - computing: true, - group: Arc::new(LeaderboardGroup::dummy()), - }; - self.groups.insert(ty, dummy); } - let link = ctx.link(); + // Compute new leaderboard values + let values = Self::compute(&ty, db).await; + let group = Arc::new(LeaderboardGroup::new(values)); - Fr::new(Box::pin(async move { - // Compute new leaderboard values - let values = Self::compute(&ty, msg.1).await; - let group = Arc::new(LeaderboardGroup::new(values)); - - // Store the group and respond to the request - let _ = link.do_send(SetGroupMessage { - group: group.clone(), + // Store the updated group + { + let groups = &mut *self.groups.write().await; + groups.insert( ty, - }); - - group - })) - } -} - -/// Message used internally to update group state with -/// a new group value once a leaderboard has been -/// computed -#[derive(Message)] -struct SetGroupMessage { - /// The leaderboard type to set - ty: LeaderboardType, - /// The new leaderboard value - group: Arc, -} - -impl Handler for Leaderboard { - type Response = (); - - fn handle(&mut self, msg: SetGroupMessage, _ctx: &mut ServiceContext) -> Self::Response { - self.groups.insert( - msg.ty, - GroupState { - computing: false, - group: msg.group, - }, - ); - } -} + GroupState { + computing: false, + group: group.clone(), + }, + ); + } -impl Leaderboard { - /// Starts a new leaderboard service - pub fn start() -> Link { - let this = Leaderboard { - groups: Default::default(), - }; - this.start() + group } /// Computes the ranking values for the provided `ty` this consists of @@ -122,7 +95,7 @@ impl Leaderboard { /// on their value. /// /// `ty` The leaderboard type - async fn compute(ty: &LeaderboardType, db: DatabaseConnection) -> Box<[LeaderboardEntry]> { + async fn compute(ty: &LeaderboardType, db: &DatabaseConnection) -> Box<[LeaderboardEntry]> { let start_time = Instant::now(); // The amount of players to process in each database request @@ -134,7 +107,7 @@ impl Leaderboard { let mut paginator = players::Entity::find() .order_by_asc(players::Column::Id) - .paginate(&db, BATCH_COUNT); + .paginate(db, BATCH_COUNT); // Function pointer to the computing function for the desired type let fun: fn(DatabaseConnection, Player) -> Lf = match ty { @@ -240,9 +213,7 @@ fn compute_n7_player(db: DatabaseConnection, player: Player) -> Lf { /// `player` The player to rank fn compute_cp_player(db: DatabaseConnection, player: Player) -> Lf { Box::pin(async move { - let value = PlayerData::get_challenge_points(&db, player.id) - .await - .unwrap_or(0); + let value = get_challenge_points(&db, player.id).await.unwrap_or(0); Ok(LeaderboardEntry { player_id: player.id, player_name: player.display_name.into_boxed_str(), @@ -252,3 +223,13 @@ fn compute_cp_player(db: DatabaseConnection, player: Player) -> Lf { }) }) } + +async fn get_challenge_points(db: &DatabaseConnection, player_id: PlayerID) -> Option { + let list = PlayerData::get(db, player_id, "Completion") + .await + .ok()?? + .value; + let part = list.split(',').nth(1)?; + let value: u32 = part.parse().ok()?; + Some(value) +} diff --git a/src/services/leaderboard/models.rs b/src/services/leaderboard/models.rs index 89c66f91..07c8cd94 100644 --- a/src/services/leaderboard/models.rs +++ b/src/services/leaderboard/models.rs @@ -1,12 +1,9 @@ use crate::utils::types::PlayerID; -use serde::Serialize; -use std::{ - fmt::Display, - time::{Duration, SystemTime}, -}; +use serde::{Deserialize, Serialize}; +use std::time::{Duration, SystemTime}; /// Structure for an entry in a leaderboard group -#[derive(Serialize, Clone)] +#[derive(Serialize)] pub struct LeaderboardEntry { /// The ID of the player this entry is for pub player_id: PlayerID, @@ -55,30 +52,37 @@ impl LeaderboardGroup { now.ge(&self.expires) } + /// Checks whether there are more items after the provided offset and size + pub fn has_more(&self, start: usize, count: usize) -> bool { + let length = self.values.len(); + start + count < length + } + /// Gets a normal collection of leaderboard entries at the start offset of the /// provided count. Will return the slice of entires as well as whether there are /// more entries after the desired offset /// /// `start` The start offset index /// `count` The number of leaderboard entries - pub fn get_normal(&self, start: usize, count: usize) -> Option<(&[LeaderboardEntry], bool)> { - let values = &self.values; - let values_len = values.len(); - - // The index to stop at - let end_index = (start + count).min(values_len); - - values - .get(start..end_index) - .map(|value| (value, values_len > end_index)) + pub fn get_normal(&self, start: usize, count: usize) -> Option<&[LeaderboardEntry]> { + let end_index = (start + count).min(self.values.len()); + self.values.get(start..end_index) } /// Gets a leaderboard entry for the provided player ID if one is present /// /// `player_id` The ID of the player to find the entry for pub fn get_entry(&self, player_id: PlayerID) -> Option<&LeaderboardEntry> { - let values = &self.values; - values.iter().find(|value| value.player_id == player_id) + self.values + .iter() + .find(|value| value.player_id == player_id) + } + + pub fn get_filtered(&self, players: &[PlayerID]) -> Vec<&LeaderboardEntry> { + self.values + .iter() + .filter(move |entry| players.contains(&entry.player_id)) + .collect() } /// Gets a collection of leaderboard entries centered on the provided player with @@ -87,70 +91,48 @@ impl LeaderboardGroup { /// `player_id` The ID of the player to center on /// `count` The total number of players to center on pub fn get_centered(&self, player_id: PlayerID, count: usize) -> Option<&[LeaderboardEntry]> { - let values = &self.values; - let values_len = values.len(); + if count == 0 { + return None; + } + // The number of items before the center index let before = if count % 2 == 0 { - count / 2 + 1 + (count / 2).saturating_add(1) } else { count / 2 }; + // The number of items after the center index let after = count / 2; // The index of the centered player - let player_index = values + let player_index = self + .values .iter() .position(|value| value.player_id == player_id)?; // The index of the first item - let start_index = player_index - before.min(player_index); + let start_index = player_index.saturating_sub(before).min(player_index); // The index of the last item - let end_index = (player_index + after).min(values_len); + let end_index = player_index.saturating_add(after).min(self.values.len()); - values.get(start_index..end_index) + self.values.get(start_index..end_index) } } /// Type of leaderboard entity -#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq, Deserialize)] pub enum LeaderboardType { /// Leaderboard based on the player N7 ratings + #[serde(rename = "n7")] N7Rating, /// Leaderboard based on the player challenge point number + #[serde(rename = "cp")] ChallengePoints, } -impl Display for LeaderboardType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - Self::N7Rating => "N7 Rating", - Self::ChallengePoints => "Challenge Points", - }) - } -} - -impl LeaderboardType { - /// Attempts to parse the leaderboard type from the provided value - /// - /// `value` The value to attempt to parse from - pub fn try_parse(value: &str) -> Option { - if value.eq_ignore_ascii_case("n7") { - Some(LeaderboardType::N7Rating) - } else if value.eq_ignore_ascii_case("cp") { - Some(LeaderboardType::ChallengePoints) - } else { - None - } - } - - /// Gets the leaderboard type from the value provided - /// by a Mass Effect client this would be either N7Rating - /// or ChallangePoints along with the locale which in this - /// case is ignored - /// - /// `value` The value to parse from - pub fn from_value(value: &str) -> Self { +impl From<&str> for LeaderboardType { + fn from(value: &str) -> Self { if value.starts_with("N7Rating") { Self::N7Rating } else { diff --git a/src/services/retriever/mod.rs b/src/services/retriever/mod.rs index 26f5dfe4..33e5c672 100644 --- a/src/services/retriever/mod.rs +++ b/src/services/retriever/mod.rs @@ -1,54 +1,45 @@ //! Retriever system for connecting and retrieving data from the official //! Mass Effect 3 servers. -use std::{ - fmt::{Debug, Display}, - ops::Add, - time::{Duration, SystemTime}, -}; - use self::origin::OriginFlowService; use crate::{ config::RetrieverConfig, - session::packet::{Packet, PacketCodec, PacketDebug, PacketHeader, PacketType}, - utils::{ - components::redirector, + session::{ models::{InstanceDetails, InstanceNet, Port}, + packet::{FireFrame, FrameType, Packet, PacketCodec, PacketDebug}, }, + utils::components::redirector, }; - use blaze_ssl_async::{stream::BlazeStream, BlazeError}; use futures_util::{SinkExt, StreamExt}; -use interlink::prelude::*; use log::{debug, error, log_enabled}; use models::InstanceRequest; use origin::OriginFlow; use reqwest; use serde::Deserialize; +use std::{ + fmt::Display, + ops::Add, + time::{Duration, SystemTime}, +}; use tdf::{DecodeError, TdfDeserialize, TdfSerialize}; use thiserror::Error; -use tokio::io; +use tokio::{io, sync::RwLock}; use tokio_util::codec::Framed; mod models; - pub mod origin; /// Structure for the retrievier system which contains the host address /// for the official game server in order to make further connections -#[derive(Service)] pub struct Retriever { // Optional official instance if fetching is possible - instance: Option, + instance: RwLock>, /// Optional service for creating origin flows if enabled origin_flow: Option, } -#[derive(Message)] -#[msg(rtype = "Result")] -pub struct GetOriginFlow; - #[derive(Debug, Error)] pub enum GetFlowError { #[error("Retriever is disabled or unavailable")] @@ -61,45 +52,6 @@ pub enum GetFlowError { OriginDisabled, } -impl Handler for Retriever { - type Response = Sfr; - - fn handle( - &mut self, - _msg: GetOriginFlow, - _ctx: &mut interlink::service::ServiceContext, - ) -> Self::Response { - Sfr::new(move |act: &mut Retriever, _ctx| { - Box::pin(async move { - let mut instance = act.instance.as_ref().ok_or(GetFlowError::Unavailable)?; - - // Obtain a new instance if the current one is expired - if instance.expiry < SystemTime::now() { - debug!("Current official instance is outdated.. retrieving a new instance"); - - instance = match OfficialInstance::obtain().await { - Ok(value) => act.instance.insert(value), - Err(err) => { - act.instance = None; - error!("Official server instance expired but failed to obtain new instance: {}", err); - return Err(GetFlowError::Instance); - } - }; - } - - let session = instance.session().await.ok_or(GetFlowError::Session)?; - let flow = act - .origin_flow - .as_ref() - .ok_or(GetFlowError::OriginDisabled)? - .create(session); - - Ok(flow) - }) - }) - } -} - /// Connection details for an official server instance struct OfficialInstance { /// The host address of the official server @@ -247,7 +199,7 @@ impl Retriever { /// ip address of the gosredirector.ea.com host and then creates a /// connection to the redirector server and obtains the IP and Port /// of the Official server. - pub async fn start(config: RetrieverConfig) -> Link { + pub async fn start(config: RetrieverConfig) -> Retriever { let instance = if config.enabled { match OfficialInstance::obtain().await { Ok(value) => Some(value), @@ -268,12 +220,52 @@ impl Retriever { None }; - let this = Retriever { - instance, + Retriever { + instance: RwLock::new(instance), origin_flow, + } + } + + pub async fn origin_flow(&self) -> Result { + let flow = self + .origin_flow + .as_ref() + .ok_or(GetFlowError::OriginDisabled)?; + + let read_guard = self.instance.read().await; + let instance = read_guard.as_ref().ok_or(GetFlowError::Unavailable)?; + let is_expired = instance.expiry < SystemTime::now(); + + let guard = if is_expired { + // Drop the read instance and guard + let _ = instance; + drop(read_guard); + + debug!("Current official instance is outdated.. retrieving a new instance"); + let mut write_guard = self.instance.write().await; + + let official = match OfficialInstance::obtain().await { + Ok(value) => Some(value), + Err(err) => { + error!( + "Official server instance expired but failed to obtain new instance: {}", + err + ); + None + } + }; + + *write_guard = official; + + write_guard.downgrade() + } else { + read_guard }; - this.start() + let instance = guard.as_ref().ok_or(GetFlowError::Instance)?; + let session = instance.session().await.ok_or(GetFlowError::Session)?; + + Ok(flow.create(session)) } } @@ -342,13 +334,13 @@ impl OfficialSession { ) -> RetrieverResult { let request = Packet::request(self.id, component, command, contents); - debug_log_packet(&request, "Sending to Official"); - let header = request.header; + debug_log_packet(&request, "Send"); + let frame = request.frame.clone(); self.stream.send(request).await?; self.id += 1; - self.expect_response(&header).await + self.expect_response(&frame).await } /// Writes a request packet and waits until the response packet is @@ -371,8 +363,8 @@ impl OfficialSession { command: u16, ) -> RetrieverResult { let request = Packet::request_empty(self.id, component, command); - debug_log_packet(&request, "Sent to Official"); - let header = request.header; + debug_log_packet(&request, "Send"); + let header = request.frame.clone(); self.stream.send(request).await?; self.id += 1; self.expect_response(&header).await @@ -380,21 +372,23 @@ impl OfficialSession { /// Waits for a response packet to be recieved any notification packets /// that are recieved are handled in the handle_notify function. - async fn expect_response(&mut self, request: &PacketHeader) -> RetrieverResult { + async fn expect_response(&mut self, request: &FireFrame) -> RetrieverResult { loop { let response = match self.stream.next().await { Some(value) => value?, None => return Err(RetrieverError::EarlyEof), }; - debug_log_packet(&response, "Received from Official"); - let header = &response.header; - - if let PacketType::Response = header.ty { - if header.path_matches(request) { - return Ok(response); + debug_log_packet(&response, "Receive"); + let header = &response.frame; + + match &header.ty { + FrameType::Response => { + if header.path_matches(request) { + return Ok(response); + } } - } else if let PacketType::Error = header.ty { - return Err(RetrieverError::Packet(ErrorPacket(response))); + FrameType::Error => return Err(RetrieverError::Packet(ErrorPacket(response))), + _ => {} } } } @@ -411,11 +405,8 @@ fn debug_log_packet(packet: &Packet, action: &str) { if !log_enabled!(log::Level::Debug) { return; } - let debug = PacketDebug { - packet, - minified: false, - }; - debug!("\n{}\n{:?}", action, debug); + let debug = PacketDebug { packet }; + debug!("\nOfficial: {}\n{:?}", action, debug); } /// Wrapping structure for packets to allow them to be @@ -427,7 +418,7 @@ impl std::error::Error for ErrorPacket {} impl Display for ErrorPacket { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:#X}", self.0.header.error) + write!(f, "{:#X}", self.0.frame.error) } } diff --git a/src/services/sessions/mod.rs b/src/services/sessions/mod.rs index 056abc44..d679ef3f 100644 --- a/src/services/sessions/mod.rs +++ b/src/services/sessions/mod.rs @@ -1,56 +1,36 @@ //! Service for storing links to all the currenly active //! authenticated sessions on the server -use crate::{session::Session, utils::types::PlayerID}; -use argon2::password_hash::rand_core::{OsRng, RngCore}; +use crate::utils::hashing::IntHashMap; +use crate::utils::types::PlayerID; +use crate::{session::SessionLink, utils::signing::SigningKey}; use base64ct::{Base64UrlUnpadded, Encoding}; -use interlink::prelude::*; -use interlink::service::ServiceContext; -use log::error; -use ring::hmac::{self, Key, HMAC_SHA256}; -use std::collections::HashMap; -use std::{ - path::Path, - time::{Duration, SystemTime, UNIX_EPOCH}, -}; -use thiserror::Error; -use tokio::{ - fs::{write, File}, - io::{self, AsyncReadExt}, -}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::sync::RwLock; /// Service for storing links to authenticated sessions and /// functionality for authenticating sessions -#[derive(Service)] pub struct Sessions { /// Map of the authenticated players to their session links - values: HashMap>, + sessions: RwLock>, /// HMAC key used for computing signatures - key: Key, + key: SigningKey, } -/// Message for creating a new authentication token for the provided -/// [PlayerID] -#[derive(Message)] -#[msg(rtype = "String")] -pub struct CreateTokenMessage(pub PlayerID); - -/// Message for verifying the provided token -#[derive(Message)] -#[msg(rtype = "Result")] -pub struct VerifyTokenMessage(pub String); - -impl Handler for Sessions { - type Response = Mr; +impl Sessions { + /// Expiry time for tokens + const EXPIRY_TIME: Duration = Duration::from_secs(60 * 60 * 24 * 30 /* 30 Days */); - fn handle( - &mut self, - msg: CreateTokenMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - let id = msg.0; + /// Starts a new service returning its link + pub fn new(key: SigningKey) -> Self { + Self { + sessions: Default::default(), + key, + } + } + pub fn create_token(&self, player_id: PlayerID) -> String { // Compute expiry timestamp let exp = SystemTime::now() .checked_add(Self::EXPIRY_TIME) @@ -61,7 +41,7 @@ impl Handler for Sessions { // Create encoded token value let mut data = [0u8; 12]; - data[..4].copy_from_slice(&id.to_be_bytes()); + data[..4].copy_from_slice(&player_id.to_be_bytes()); data[4..].copy_from_slice(&exp.to_be_bytes()); let data = &data; @@ -69,82 +49,14 @@ impl Handler for Sessions { let msg = Base64UrlUnpadded::encode_string(data); // Create a signature from the raw message bytes - let sig = hmac::sign(&self.key, data); + let sig = self.key.sign(data); let sig = Base64UrlUnpadded::encode_string(sig.as_ref()); // Join the message and signature to create the token - let token = [msg, sig].join("."); - - Mr(token) - } -} - -impl Handler for Sessions { - type Response = Mr; - - fn handle( - &mut self, - msg: VerifyTokenMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - Mr(self.verify(&msg.0)) + [msg, sig].join(".") } -} - -impl Sessions { - /// Starts a new service returning its link - pub async fn start() -> Link { - let key = Self::create_key().await; - let this = Self { - values: Default::default(), - key, - }; - this.start() - } - - /// Expiry time for tokens - const EXPIRY_TIME: Duration = Duration::from_secs(60 * 60 * 24 * 30 /* 30 Days */); - - /// Creates a new instance of the tokens structure loading/creating - /// the secret bytes that are used for signing authentication tokens - pub async fn create_key() -> Key { - // Path to the file containing the server secret value - let secret_path = Path::new("data/secret.bin"); - - // The bytes of the secret - let mut secret = [0u8; 64]; - // Attempt to load existing secret - if secret_path.exists() { - if let Err(err) = Self::read_secret(&mut secret, secret_path).await { - error!("Failed to read secrets file: {:?}", err); - } else { - return Key::new(HMAC_SHA256, &secret); - } - } - - // Generate random secret bytes - OsRng.fill_bytes(&mut secret); - - // Save the created secret - if let Err(err) = write(secret_path, &secret).await { - error!("Failed to write secrets file: {:?}", err); - } - - Key::new(HMAC_SHA256, &secret) - } - - /// Reads the secret from the secrets file into the provided buffer - /// returning whether the entire secret could be read - /// - /// `out` The buffer to read the secret to - async fn read_secret(out: &mut [u8], path: &Path) -> io::Result<()> { - let mut file = File::open(path).await?; - file.read_exact(out).await?; - Ok(()) - } - - fn verify(&self, token: &str) -> Result { + pub fn verify_token(&self, token: &str) -> Result { // Split the token parts let (msg_raw, sig_raw) = match token.split_once('.') { Some(value) => value, @@ -153,14 +65,14 @@ impl Sessions { // Decode the 12 byte token message let mut msg = [0u8; 12]; - Base64UrlUnpadded::decode(msg_raw, &mut msg)?; + Base64UrlUnpadded::decode(msg_raw, &mut msg).map_err(|_| VerifyError::Invalid)?; // Decode 32byte signature (SHA256) let mut sig = [0u8; 32]; - Base64UrlUnpadded::decode(sig_raw, &mut sig)?; + Base64UrlUnpadded::decode(sig_raw, &mut sig).map_err(|_| VerifyError::Invalid)?; // Verify the signature - if hmac::verify(&self.key, &msg, &sig).is_err() { + if !self.key.verify(&msg, &sig) { return Err(VerifyError::Invalid); } @@ -185,76 +97,48 @@ impl Sessions { Ok(id) } -} -/// Message for removing players from the authenticated -/// sessions list -#[derive(Message)] -pub struct RemoveMessage { - /// The ID of the player to remove - pub player_id: PlayerID, -} - -/// Message for adding a player to the authenticated -/// sessions list -#[derive(Message)] -pub struct AddMessage { - /// The ID of the player the link belongs to - pub player_id: PlayerID, - /// The link to the player session - pub link: Link, -} - -/// Message for finding a session by a player ID returning -/// a link to the player if one is found -#[derive(Message)] -#[msg(rtype = "Option>")] -pub struct LookupMessage { - /// The ID of the player to lookup - pub player_id: PlayerID, -} - -/// Handle messages to add authenticated sessions -impl Handler for Sessions { - type Response = (); - - fn handle(&mut self, msg: AddMessage, _ctx: &mut ServiceContext) -> Self::Response { - self.values.insert(msg.player_id, msg.link); + pub async fn remove_session(&self, player_id: PlayerID) { + let sessions = &mut *self.sessions.write().await; + sessions.remove(&player_id); } -} - -/// Handle messages to remove authenticated sessions -impl Handler for Sessions { - type Response = (); - fn handle(&mut self, msg: RemoveMessage, _ctx: &mut ServiceContext) -> Self::Response { - self.values.remove(&msg.player_id); + pub async fn add_session(&self, player_id: PlayerID, link: SessionLink) { + let sessions = &mut *self.sessions.write().await; + sessions.insert(player_id, link); } -} - -/// Handle messages to lookup authenticated sessions -impl Handler for Sessions { - type Response = Mr; - fn handle(&mut self, msg: LookupMessage, _ctx: &mut ServiceContext) -> Self::Response { - let value = self.values.get(&msg.player_id).cloned(); - Mr(value) + pub async fn lookup_session(&self, player_id: PlayerID) -> Option { + let sessions = &*self.sessions.read().await; + sessions.get(&player_id).cloned() } } /// Errors that can occur while verifying a token -#[derive(Debug, Error)] +#[derive(Debug)] pub enum VerifyError { /// The token is expired - #[error("Expired token")] Expired, /// The token is invalid - #[error("Invalid token")] Invalid, } -impl From for VerifyError { - fn from(_: base64ct::Error) -> Self { - Self::Invalid +#[cfg(test)] +mod test { + use crate::utils::signing::SigningKey; + + use super::Sessions; + + /// Tests that tokens can be created and verified correctly + #[test] + fn test_token() { + let (key, _) = SigningKey::generate(); + let sessions = Sessions::new(key); + + let player_id = 32; + let token = sessions.create_token(player_id); + let claim = sessions.verify_token(&token).unwrap(); + + assert_eq!(player_id, claim) } } diff --git a/src/session/mod.rs b/src/session/mod.rs index 93732794..728e89e2 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,475 +2,424 @@ //! data such as player data for when they become authenticated and //! networking data. +use self::{ + models::{ + game_manager::RemoveReason, + user_sessions::{ + HardwareFlags, LookupResponse, NotifyUserAdded, NotifyUserRemoved, NotifyUserUpdated, + UserDataFlags, UserIdentification, UserSessionExtendedData, + UserSessionExtendedDataUpdate, + }, + }, + packet::{Packet, PacketCodec, PacketDebug}, + router::BlazeRouter, +}; use crate::{ database::entities::Player, - middleware::blaze_upgrade::BlazeScheme, services::{ - game::{ - manager::{GameManager, GetGameMessage, RemoveQueueMessage}, - models::RemoveReason, - GamePlayer, RemovePlayerMessage, - }, - sessions::{AddMessage, RemoveMessage, Sessions}, + game::{Game, GameRef}, + sessions::Sessions, }, + session::models::{NetworkAddress, QosNetworkData}, utils::{ - components::{self, game_manager::GAME_TYPE, user_sessions}, - models::{NetData, NetworkAddress, Port, QosNetworkData, UpdateExtDataAttr}, + components::{component_key, user_sessions, DEBUG_IGNORED_PACKETS}, types::{GameID, PlayerID, SessionID}, }, }; - -use interlink::prelude::*; -use log::{debug, log_enabled}; -use std::{fmt::Debug, io, net::Ipv4Addr, sync::Arc}; -use tdf::{ObjectId, TdfSerialize, TdfType, TdfTyped}; - -use self::{ - packet::{Packet, PacketDebug}, - router::BlazeRouter, +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; +use hyper::upgrade::Upgraded; +use log::{debug, log_enabled, warn}; +use serde::Serialize; +use std::{ + fmt::Debug, + net::Ipv4Addr, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::Duration, }; +use tokio::{ + sync::{mpsc, RwLock}, + task::JoinSet, +}; +use tokio_util::codec::Framed; pub mod models; pub mod packet; pub mod router; pub mod routes; -/// Structure for storing a client session. This includes the -/// network stream for the client along with global state and -/// other session state. +pub type SessionLink = Arc; + pub struct Session { - /// Unique identifier for this session. id: SessionID, - - /// Connection socket addr addr: Ipv4Addr, - - /// Packet writer sink for the session - writer: SinkLink, - - /// The session scheme - host_target: SessionHostTarget, - - /// Data associated with this session - data: SessionData, - + writer: mpsc::UnboundedSender, + data: RwLock>, router: Arc, - - game_manager: Link, - sessions: Link, + sessions: Arc, } -#[derive(Default, Clone)] -pub struct SessionData { - /// If the session is authenticated it will have a linked - /// player model from the database - player: Option, - /// Networking information - net: NetData, - /// The id of the game if connected to one - game: Option, +pub struct SessionExtData { + player: Arc, + net: Arc, + game: Option, + subscribers: Vec<(PlayerID, SessionLink)>, } -impl Service for Session { - fn stopping(&mut self) { - self.clear_auth(); - debug!("Session stopped (SID: {})", self.id); - } +struct SessionGameData { + game_id: GameID, + game_ref: Arc>, } -pub type SessionLink = Link; - -#[derive(Message)] -#[msg(rtype = "Option")] -pub struct GetPlayerMessage; - -impl Handler for Session { - type Response = Mr; - - fn handle( - &mut self, - _msg: GetPlayerMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - Mr(self.data.player.clone()) +impl SessionExtData { + pub fn new(player: Player) -> Self { + Self { + player: Arc::new(player), + net: Default::default(), + game: Default::default(), + subscribers: Default::default(), + } } -} - -#[derive(Message)] -#[msg(rtype = "SessionHostTarget")] -pub struct GetHostTarget; - -impl Handler for Session { - type Response = Mr; - fn handle(&mut self, _msg: GetHostTarget, _ctx: &mut ServiceContext) -> Self::Response { - Mr(self.host_target.clone()) + fn ext(&self) -> UserSessionExtendedData { + UserSessionExtendedData { + net: self.net.clone(), + game: self.game.as_ref().map(|game| game.game_id), + } } -} -#[derive(Clone)] -pub struct SessionHostTarget { - pub scheme: BlazeScheme, - pub host: Box, - pub port: Port, - pub local_http: bool, -} - -#[derive(Message)] -#[msg(rtype = "Option")] -pub struct GetPlayerIdMessage; + fn add_subscriber(&mut self, player_id: PlayerID, subscriber: SessionLink) { + // Create the details packets + let added_notify = Packet::notify( + user_sessions::COMPONENT, + user_sessions::USER_ADDED, + NotifyUserAdded { + session_data: self.ext(), + user: UserIdentification::from_player(&self.player), + }, + ); -impl Handler for Session { - type Response = Mr; + // Create update notifying the user of the subscription + let update_notify = Packet::notify( + user_sessions::COMPONENT, + user_sessions::USER_UPDATED, + NotifyUserUpdated { + flags: UserDataFlags::SUBSCRIBED | UserDataFlags::ONLINE, + player_id: self.player.id, + }, + ); - fn handle( - &mut self, - _msg: GetPlayerIdMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - Mr(self.data.player.as_ref().map(|value| value.id)) + self.subscribers.push((player_id, subscriber.clone())); + subscriber.push(added_notify); + subscriber.push(update_notify); } -} -#[derive(Message)] -#[msg(rtype = "Option")] -pub struct GetGamePlayerMessage; - -impl Handler for Session { - type Response = Mr; - fn handle( - &mut self, - _msg: GetGamePlayerMessage, - ctx: &mut ServiceContext, - ) -> Self::Response { - let player = match self.data.player.clone() { + fn remove_subscriber(&mut self, player_id: PlayerID) { + let index = match self.subscribers.iter().position(|(id, _)| player_id.eq(id)) { Some(value) => value, - None => return Mr(None), + None => return, }; - Mr(Some(GamePlayer::new( - player, - self.data.net.clone(), - ctx.link(), - ))) + + let (_, subscriber) = self.subscribers.swap_remove(index); + + // Create the details packets + let removed_notify = Packet::notify( + user_sessions::COMPONENT, + user_sessions::USER_REMOVED, + NotifyUserRemoved { player_id }, + ); + + subscriber.push(removed_notify) } -} -#[derive(Message)] -pub struct SetPlayerMessage(pub Option); + /// Publishes changes of the session data to all the + /// subscribed session links + fn publish_update(&self) { + let packet = Packet::notify( + user_sessions::COMPONENT, + user_sessions::USER_SESSION_EXTENDED_DATA_UPDATE, + UserSessionExtendedDataUpdate { + user_id: self.player.id, + data: self.ext(), + }, + ); -impl Handler for Session { - type Response = (); - fn handle(&mut self, msg: SetPlayerMessage, ctx: &mut ServiceContext) -> Self::Response { - // Clear the current authentication - self.clear_auth(); - - // If we are setting a new player - if let Some(player) = msg.0 { - // Add the session to authenticated sessions - let _ = self.sessions.do_send(AddMessage { - player_id: player.id, - link: ctx.link(), - }); - self.data.player = Some(player); + for (_, subscriber) in &self.subscribers { + subscriber.push(packet.clone()); } } } -/// Extension for links to push packets for session links -pub trait PushExt { - fn push(&self, packet: Packet); +#[derive(Debug, Default, Clone, Serialize)] +pub struct NetData { + pub addr: NetworkAddress, + pub qos: QosNetworkData, + pub hardware_flags: HardwareFlags, } -impl PushExt for Link { - fn push(&self, packet: Packet) { - let _ = self.do_send(WriteMessage(packet)); +impl NetData { + // Re-creates the current net data using the provided address and QOS data + pub fn with_basic(&self, addr: NetworkAddress, qos: QosNetworkData) -> Self { + Self { + addr, + qos, + hardware_flags: self.hardware_flags, + } + } + + /// Re-creates the current net data using the provided hardware flags + pub fn with_hardware_flags(&self, flags: HardwareFlags) -> Self { + Self { + addr: self.addr.clone(), + qos: self.qos, + hardware_flags: flags, + } } } -#[derive(Message)] -#[msg(rtype = "Option")] -pub struct GetPlayerGameMessage; +static SESSION_IDS: AtomicU32 = AtomicU32::new(1); -impl Handler for Session { - type Response = Mr; +impl Session { + /// Max number of times to poll a session for shutdown before erroring + const MAX_RELEASE_ATTEMPTS: u8 = 5; - fn handle( - &mut self, - _msg: GetPlayerGameMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - Mr(self.data.game) - } -} + pub fn start(io: Upgraded, addr: Ipv4Addr, router: Arc, sessions: Arc) { + // Obtain a session ID + let id = SESSION_IDS.fetch_add(1, Ordering::AcqRel); -#[derive(Message)] -#[msg(rtype = "Option")] -pub struct GetLookupMessage; + let framed = Framed::new(io, PacketCodec); + let (write, read) = framed.split(); + let (tx, rx) = mpsc::unbounded_channel(); -impl Handler for Session { - type Response = Mr; + let session = Arc::new(Self { + id, + writer: tx, + data: Default::default(), + addr, + router, + sessions, + }); - fn handle( - &mut self, - _msg: GetLookupMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - let data = &self.data; - let player = match &data.player { - Some(value) => value, - None => return Mr(None), + let reader = SessionReader { + link: session.clone(), + inner: read, }; - let response = LookupResponse { - session_data: data.clone(), - player_id: player.id, - display_name: player.display_name.clone(), + let writer = SessionWriter { + link: session.clone(), + rx, + inner: write, }; - Mr(Some(response)) + tokio::spawn(reader.process()); + tokio::spawn(writer.process()); } -} -#[derive(Message)] -pub struct WriteMessage(pub Packet); + /// Handles routing a packet + async fn handle_packet(self: Arc, packet: Packet) { + let route_link = self.clone(); + let this = &*self; -impl Handler for Session { - type Response = (); + this.debug_log_packet("Receive", &packet).await; + let response = match this.router.handle(route_link, packet) { + // Await the handler response future + Ok(fut) => fut.await, - fn handle(&mut self, msg: WriteMessage, _ctx: &mut ServiceContext) -> Self::Response { - self.push(msg.0); + // Handle no handler for packet + Err(packet) => { + debug!("Missing packet handler"); + Packet::response_empty(&packet) + } + }; + // Push the response to the client + this.push(response); } -} -impl StreamHandler> for Session { - fn handle(&mut self, msg: io::Result, ctx: &mut ServiceContext) { - if let Ok(packet) = msg { - self.debug_log_packet("Read", &packet); - let addr = ctx.link(); - let router = self.router.clone(); - tokio::spawn(async move { - let response = match router.handle(addr.clone(), packet) { - // Await the handler response future - Ok(fut) => fut.await, - - // Handle no handler for packet - Err(packet) => { - debug!("Missing packet handler"); - Packet::response_empty(&packet) - } - }; - // Push the response to the client - addr.push(response); - }); - } else { - ctx.stop(); - } + /// Internal session stopped function called by the reader when + /// the connection is terminated, cleans up any references and + /// asserts only 1 strong reference exists + async fn stop(self: Arc) { + // Tell the write half to close and wait until its closed + _ = self.writer.send(WriteMessage::Close); + self.writer.closed().await; + + // Clear authentication + self.clear_player().await; + + let mut attempt: u8 = 1; + + let mut arc = self; + let session = loop { + if attempt > Self::MAX_RELEASE_ATTEMPTS { + let references = Arc::strong_count(&arc); + warn!( + "Failed to stop session {} there are still {} references to it", + arc.id, references + ); + return; + } + match Arc::try_unwrap(arc) { + Ok(value) => break value, + Err(value) => { + let wait = 5 * attempt as u64; + let references = Arc::strong_count(&value); + debug!( + "Session {} still has {} references to it, waiting {}s", + value.id, references, wait + ); + tokio::time::sleep(Duration::from_secs(wait)).await; + arc = value; + attempt += 1; + continue; + } + } + }; + + debug!("Session stopped (SID: {})", session.id); } -} -impl ErrorHandler for Session { - fn handle(&mut self, _err: io::Error, _ctx: &mut ServiceContext) -> ErrorAction { - ErrorAction::Continue + pub async fn add_subscriber(&self, player_id: PlayerID, subscriber: SessionLink) { + let data = &mut *self.data.write().await; + let data = match data { + Some(value) => value, + // TODO: Handle this as an error for unauthenticated + None => return, + }; + data.add_subscriber(player_id, subscriber); } -} -/// Message telling the session to inform the clients of -/// a change in session data -#[derive(Message)] -pub struct UpdateClientMessage; - -impl Handler for Session { - type Response = (); - - fn handle(&mut self, _msg: UpdateClientMessage, _ctx: &mut ServiceContext) { - if let Some(player) = &self.data.player { - let packet = Packet::notify( - user_sessions::COMPONENT, - user_sessions::SET_SESSION, - SetSession { - player_id: player.id, - session: &self.data, - }, - ); - self.push(packet); - } + pub async fn remove_subscriber(&self, player_id: PlayerID) { + let data = &mut *self.data.write().await; + let data = match data { + Some(value) => value, + // TODO: Handle this as an error for unauthenticated + None => return, + }; + data.remove_subscriber(player_id); } -} -#[derive(Message)] -#[msg(rtype = "Ipv4Addr")] -pub struct GetSocketAddrMessage; + pub async fn set_player(&self, player: Player) -> Arc { + // Clear the current authentication + self.clear_player().await; -impl Handler for Session { - type Response = Mr; + let data = &mut *self.data.write().await; + let data = data.insert(SessionExtData::new(player)); - fn handle( - &mut self, - _msg: GetSocketAddrMessage, - _ctx: &mut ServiceContext, - ) -> Self::Response { - Mr(self.addr) + data.player.clone() } -} -/// Creates a set session packet and sends it to all the -/// provided session links -#[derive(Message)] -pub struct InformSessions { - /// The link to send the set session to - pub links: Vec>, -} + /// Clears the current game returning the game data if + /// the player was in a game + /// + /// Called by the game itself when the player has been removed + pub async fn clear_game(&self) -> Option<(PlayerID, GameRef)> { + // Check that theres authentication + let data = &mut *self.data.write().await; + let data = data.as_mut()?; + let game = data.game.take(); + data.publish_update(); + let game = game?; -impl Handler for Session { - type Response = (); - - fn handle(&mut self, msg: InformSessions, _ctx: &mut ServiceContext) -> Self::Response { - if let Some(player) = &self.data.player { - let packet = Packet::notify( - user_sessions::COMPONENT, - user_sessions::SET_SESSION, - SetSession { - player_id: player.id, - session: &self.data, - }, - ); - - for link in msg.links { - link.push(packet.clone()); - } - } + Some((data.player.id, game.game_ref)) } -} - -/// Message to update the hardware flag of a session -#[derive(Message)] -pub struct HardwareFlagMessage { - /// The new value for the hardware flag - pub value: u16, -} -impl Handler for Session { - type Response = (); - - fn handle(&mut self, msg: HardwareFlagMessage, ctx: &mut ServiceContext) { - self.data.net.hardware_flags = msg.value; - - // Notify the client of the change via a message rather than - // directly so its sent after the response - let _ = ctx.shared_link().do_send(UpdateClientMessage); + /// Called to remove the player from its current game + pub async fn remove_from_game(&self) { + if let Some((player_id, game_ref)) = self.clear_game().await { + let game = &mut *game_ref.write().await; + game.remove_player(player_id, RemoveReason::PlayerLeft); + } } -} -#[derive(Message)] -pub struct NetworkInfoMessage { - pub address: NetworkAddress, - pub qos: QosNetworkData, -} + pub async fn clear_player(&self) { + self.remove_from_game().await; -impl Handler for Session { - type Response = (); + // Check that theres authentication + let data = &mut *self.data.write().await; + let data = match data { + Some(value) => value, + None => return, + }; - fn handle(&mut self, msg: NetworkInfoMessage, ctx: &mut ServiceContext) { - let net = &mut &mut self.data.net; - net.qos = msg.qos; - net.addr = msg.address; + // Existing sessions must be unsubscribed + data.subscribers.clear(); - // Notify the client of the change via a message rather than - // directly so its sent after the response - let _ = ctx.shared_link().do_send(UpdateClientMessage); + // Remove the session from the sessions service + self.sessions.remove_session(data.player.id).await; } -} - -#[derive(Message)] -pub struct SetGameMessage { - pub game: Option, -} - -impl Handler for Session { - type Response = (); - fn handle(&mut self, msg: SetGameMessage, ctx: &mut ServiceContext) { - self.data.game = msg.game; - - // Notify the client of the change via a message rather than - // directly so its sent after the response - let _ = ctx.shared_link().do_send(UpdateClientMessage); + pub async fn get_game(&self) -> Option<(GameID, GameRef)> { + let data = &*self.data.read().await; + data.as_ref() + .and_then(|value| value.game.as_ref()) + .map(|value| (value.game_id, value.game_ref.clone())) } -} - -/// Message to send the details of this session to -/// the provided session link -#[derive(Message)] -pub struct DetailsMessage { - pub link: Link, -} -impl Handler for Session { - type Response = (); + pub async fn get_lookup(&self) -> Option { + let data = &*self.data.read().await; + data.as_ref().map(|data| LookupResponse { + player: data.player.clone(), + extended_data: data.ext(), + }) + } - fn handle(&mut self, msg: DetailsMessage, _ctx: &mut ServiceContext) { - let player = match self.data.player.as_ref() { - Some(value) => value, - None => return, - }; + #[inline] + async fn update_data(&self, update: F) + where + F: FnOnce(&mut SessionExtData), + { + let data = &mut *self.data.write().await; + if let Some(data) = data { + update(data); + data.publish_update(); + } + } - // Create the details packets - let a = Packet::notify( - user_sessions::COMPONENT, - user_sessions::SESSION_DETAILS, - SessionUpdate { - session: self, - player_id: player.id, - display_name: &player.display_name, - }, - ); + pub async fn set_game(&self, game_id: GameID, game_ref: GameRef) { + // Set the current game + self.update_data(|data| { + // Remove the player from the game if they are already present in one + if let Some(game) = data.game.take() { + let player_id = data.player.id; + tokio::spawn(async move { + let game = &mut *game.game_ref.write().await; + game.remove_player(player_id, RemoveReason::PlayerLeft); + }); + } - let b = Packet::notify( - user_sessions::COMPONENT, - user_sessions::UPDATE_EXTENDED_DATA_ATTRIBUTE, - UpdateExtDataAttr { - flags: 0x3, - player_id: player.id, - }, - ); + data.game = Some(SessionGameData { game_id, game_ref }); + }) + .await; + } - // Push the message to the session link - msg.link.push(a); - msg.link.push(b); + #[inline] + pub async fn set_hardware_flags(&self, value: HardwareFlags) { + self.update_data(|data| { + data.net = Arc::new(data.net.with_hardware_flags(value)); + }) + .await; } -} -impl Session { - pub fn new( - id: SessionID, - host_target: SessionHostTarget, - writer: SinkLink, - addr: Ipv4Addr, - router: Arc, - game_manager: Link, - sessions: Link, - ) -> Self { - Self { - id, - writer, - data: SessionData::default(), - host_target, - addr, - router, - game_manager, - sessions, - } + #[inline] + pub async fn set_network_info(&self, address: NetworkAddress, qos: QosNetworkData) { + self.update_data(|data| { + data.net = Arc::new(data.net.with_basic(address, qos)); + }) + .await; } /// Pushes a new packet to the back of the packet buffer /// and sends a flush notification /// /// `packet` The packet to push to the buffer - pub fn push(&mut self, packet: Packet) { - self.debug_log_packet("Queued Write", &packet); - if self.writer.sink(packet).is_err() { - // TODO: Handle failing to send contents to writer - } + pub fn push(&self, packet: Packet) { + _ = self.writer.send(WriteMessage::Write(packet)) + // TODO: Handle failing to send contents to writer } /// Logs the contents of the provided packet to the debug output along with @@ -479,207 +428,95 @@ impl Session { /// `action` The name of the action this packet is undergoing. /// (e.g. Writing or Reading) /// `packet` The packet that is being logged - fn debug_log_packet(&self, action: &'static str, packet: &Packet) { + async fn debug_log_packet(&self, action: &'static str, packet: &Packet) { // Skip if debug logging is disabled if !log_enabled!(log::Level::Debug) { return; } - // Ping messages are ignored from debug logging as they are very frequent - let ignored = packet.header.component == components::util::COMPONENT - && (packet.header.command == components::util::PING - || packet.header.command == components::util::SUSPEND_USER_PING); - + let key = component_key(packet.frame.component, packet.frame.command); + let ignored = DEBUG_IGNORED_PACKETS.contains(&key); if ignored { return; } - let debug = SessionPacketDebug { + let data = &*self.data.read().await; + let debug_data = DebugSessionData { action, - packet, - session: self, - }; - - debug!("\n{:?}", debug); - } - - /// Removes the session from any connected games and the - /// matchmaking queue - pub fn remove_games(&mut self) { - // Don't attempt to remove if theres no active player - let player_id = match &self.data.player { - Some(value) => value.id, - None => return, + id: self.id, + data, }; + let debug_packet = PacketDebug { packet }; - if let Some(game_id) = self.data.game.take() { - let game_manager = self.game_manager.clone(); - // Remove the player from the game - tokio::spawn(async move { - // Obtain the current game - let game = match game_manager.send(GetGameMessage { game_id }).await { - Ok(Some(value)) => value, - _ => return, - }; - - // Send the remove message - let _ = game - .send(RemovePlayerMessage { - id: player_id, - reason: RemoveReason::PlayerLeft, - }) - .await; - }); - } else { - // Remove the player from matchmaking if present - let _ = self.game_manager.do_send(RemoveQueueMessage { player_id }); - } - } - - /// Removes the player from the authenticated sessions list - /// if the player is authenticated - pub fn clear_auth(&mut self) { - self.remove_games(); - - // Check that theres authentication - let player = match self.data.player.take() { - Some(value) => value, - None => return, - }; - - // Send the remove session message - let _ = self.sessions.do_send(RemoveMessage { - player_id: player.id, - }); + debug!("\n{:?}{:?}", debug_data, debug_packet); } } -/// Structure for wrapping session details around a debug -/// packet message for logging -struct SessionPacketDebug<'a> { +struct DebugSessionData<'a> { + id: SessionID, + data: &'a Option, action: &'static str, - packet: &'a Packet, - session: &'a Session, } -impl Debug for SessionPacketDebug<'_> { +impl Debug for DebugSessionData<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "Session {} Packet", self.action)?; + writeln!(f, "Session ({}): {}", self.id, self.action)?; - if let Some(player) = &self.session.data.player { + if let Some(data) = self.data.as_ref() { writeln!( f, - "Info: (Name: {}, ID: {}, SID: {})", - &player.display_name, &player.id, &self.session.id + "Auth ({}): (Name: {})", + data.player.id, &data.player.display_name, )?; - } else { - writeln!(f, "Info: ( SID: {})", &self.session.id)?; } - let header = &self.packet.header; - - let minified = (header.component == components::authentication::COMPONENT - && header.command == components::authentication::LIST_USER_ENTITLEMENTS_2) - || (header.component == components::util::COMPONENT - && (header.command == components::util::FETCH_CLIENT_CONFIG - || header.command == components::util::USER_SETTINGS_LOAD_ALL)); - - PacketDebug { - packet: self.packet, - minified, - } - .fmt(f) + Ok(()) } } -impl TdfSerialize for SessionData { - fn serialize(&self, w: &mut S) { - w.group_body(|w| { - w.tag_ref(b"ADDR", &self.net.addr); - w.tag_str(b"BPS", "ea-sjc"); - w.tag_str_empty(b"CTY"); - w.tag_var_int_list_empty(b"CVAR"); - - w.tag_map_tuples(b"DMAP", &[(0x70001, 0x409a)]); - - w.tag_u16(b"HWFG", self.net.hardware_flags); - - // Ping latency to the Quality of service servers - w.tag_list_slice(b"PSLM", &[0xfff0fff]); - - w.tag_ref(b"QDAT", &self.net.qos); - w.tag_u8(b"UATT", 0); - if let Some(game_id) = &self.game { - w.tag_list_slice(b"ULST", &[ObjectId::new(GAME_TYPE, *game_id as u64)]); - } - }); - } +// Writer for writing packets +struct SessionWriter { + inner: SplitSink, Packet>, + rx: mpsc::UnboundedReceiver, + link: SessionLink, } -impl TdfTyped for SessionData { - const TYPE: TdfType = TdfType::Group; +pub enum WriteMessage { + Write(Packet), + Close, } -/// Session update for a session other than ourselves -/// which contains the details for that session -struct SessionUpdate<'a> { - /// The session this update is for - session: &'a Session, - /// The player ID the update is for - player_id: PlayerID, - /// The display name of the player the update is - display_name: &'a str, -} +impl SessionWriter { + pub async fn process(mut self) { + while let Some(msg) = self.rx.recv().await { + let packet = match msg { + WriteMessage::Write(packet) => packet, + WriteMessage::Close => break, + }; -impl TdfSerialize for SessionUpdate<'_> { - fn serialize(&self, w: &mut S) { - w.tag_ref(b"DATA", &self.session.data); - - w.group(b"USER", |writer| { - writer.tag_owned(b"AID", self.player_id); - writer.tag_u32(b"ALOC", 0x64654445); - writer.tag_blob_empty(b"EXBB"); - writer.tag_u8(b"EXID", 0); - writer.tag_owned(b"ID", self.player_id); - writer.tag_str(b"NAME", self.display_name); - }); + self.link.debug_log_packet("Send", &packet).await; + if self.inner.send(packet).await.is_err() { + break; + } + } } } -pub struct LookupResponse { - session_data: SessionData, - player_id: PlayerID, - display_name: String, +struct SessionReader { + inner: SplitStream>, + link: SessionLink, } -impl TdfSerialize for LookupResponse { - fn serialize(&self, w: &mut S) { - w.tag_ref(b"EDAT", &self.session_data); - - w.tag_u8(b"FLGS", 2); +impl SessionReader { + pub async fn process(mut self) { + let mut tasks = JoinSet::new(); - w.group(b"USER", |w| { - w.tag_owned(b"AID", self.player_id); - w.tag_u32(b"ALOC", 0x64654445); - w.tag_blob_empty(b"EXBB"); - w.tag_u8(b"EXID", 0); - w.tag_owned(b"ID", self.player_id); - w.tag_str(b"NAME", &self.display_name); - }); - } -} - -/// Session update for ourselves -struct SetSession<'a> { - /// The session this update is for - session: &'a SessionData, - /// The player ID the update is for - player_id: PlayerID, -} + while let Some(Ok(packet)) = self.inner.next().await { + let link = self.link.clone(); + tasks.spawn(link.handle_packet(packet)); + } -impl TdfSerialize for SetSession<'_> { - fn serialize(&self, w: &mut S) { - w.tag_ref(b"DATA", self.session); - w.tag_owned(b"USID", self.player_id) + tasks.shutdown().await; + self.link.stop().await; } } diff --git a/src/session/models/auth.rs b/src/session/models/auth.rs index c93b4893..ec41bca3 100644 --- a/src/session/models/auth.rs +++ b/src/session/models/auth.rs @@ -1,7 +1,7 @@ use tdf::{TdfDeserialize, TdfSerialize, TdfSerializer, TdfType, TdfTyped}; use crate::{database::entities::Player, utils::types::PlayerID}; -use std::borrow::Cow; +use std::{borrow::Cow, sync::Arc}; #[derive(Debug, Clone)] #[repr(u16)] @@ -101,7 +101,7 @@ fn encode_persona(w: &mut S, id: PlayerID, display_name: &str) /// Structure for the response to an authentication request. pub struct AuthResponse { /// The authenticated player - pub player: Player, + pub player: Arc, /// The session token for the completed authentication pub session_token: String, /// Whether the authentication proccess was silent @@ -166,7 +166,7 @@ pub struct CreateAccountRequest { /// player details pub struct PersonaResponse { /// The player - pub player: Player, + pub player: Arc, } impl TdfSerialize for PersonaResponse { diff --git a/src/session/models/errors.rs b/src/session/models/errors.rs index 8973b6e5..ef9080d5 100644 --- a/src/session/models/errors.rs +++ b/src/session/models/errors.rs @@ -1,22 +1,24 @@ -use interlink::prelude::LinkError; use log::error; use sea_orm::DbErr; use crate::session::{packet::Packet, router::IntoPacketResponse}; -use super::{auth::AuthenticationError, game_manager::GameManagerError, util::UtilError}; +use super::{ + auth::AuthenticationError, game_manager::GameManagerError, user_sessions::UserSessionsError, + util::UtilError, +}; pub type ServerResult = Result; -// #[test] -// fn decode_error() { -// let value: i32 = 19791881; -// let bytes = value.to_le_bytes(); -// let mut out = [0u8; 2]; -// out.copy_from_slice(&bytes[2..]); -// let out = u16::from_le_bytes(out); -// println!("{:#00x}", out); -// } +#[test] +fn decode_error() { + let value: i32 = 96258; + let bytes = value.to_le_bytes(); + let mut out = [0u8; 2]; + out.copy_from_slice(&bytes[2..]); + let out = u16::from_le_bytes(out); + println!("{:#00x}", out); +} #[derive(Debug, Clone)] #[repr(u16)] @@ -49,12 +51,6 @@ pub enum DatabaseError { /// Response type for some blaze error code pub struct BlazeError(u16); -impl From for BlazeError { - fn from(_: LinkError) -> Self { - GlobalError::System.into() - } -} - impl From for BlazeError { fn from(value: DbErr) -> Self { error!("Database error: {}", value); @@ -73,6 +69,12 @@ impl From for BlazeError { } } +impl From for BlazeError { + fn from(value: UserSessionsError) -> Self { + BlazeError(value as u16) + } +} + impl From for BlazeError { fn from(value: GlobalError) -> Self { BlazeError(value as u16) diff --git a/src/session/models/game_manager.rs b/src/session/models/game_manager.rs index 8b494a23..7d1f13bc 100644 --- a/src/session/models/game_manager.rs +++ b/src/session/models/game_manager.rs @@ -1,14 +1,14 @@ +use bitflags::bitflags; +use serde::Serialize; use tdf::{Blob, GroupSlice, TdfDeserialize, TdfDeserializeOwned, TdfSerialize, TdfType, TdfTyped}; use crate::{ - services::game::{ - models::{GameSettings, GameState, PlayerState, RemoveReason}, - rules::RuleSet, - AttrMap, - }, + services::game::{rules::RuleSet, AttrMap, Game, GamePlayer}, utils::types::{GameID, PlayerID, SessionID}, }; +use super::NetworkAddress; + #[derive(Debug, Clone)] #[repr(u16)] #[allow(unused)] @@ -31,7 +31,7 @@ pub struct CreateGameRequest { #[tdf(tag = "ATTR")] pub attributes: AttrMap, /// The games initial setting - #[tdf(tag = "GSET")] + #[tdf(tag = "GSET", into = u16)] pub setting: GameSettings, } @@ -80,7 +80,7 @@ pub struct SetStateRequest { pub struct SetSettingRequest { #[tdf(tag = "GID")] pub game_id: GameID, - #[tdf(tag = "GSET")] + #[tdf(tag = "GSET", into = u16)] pub setting: GameSettings, } @@ -181,3 +181,597 @@ pub enum JoinGameState { // InQueue = 1, // GroupPartiallyJoined = 2, } + +#[derive(TdfSerialize)] +pub struct PlayerRemoved { + #[tdf(tag = "CNTX")] + pub cntx: u8, + #[tdf(tag = "GID")] + pub game_id: GameID, + #[tdf(tag = "PID")] + pub player_id: PlayerID, + #[tdf(tag = "REAS")] + pub reason: RemoveReason, +} + +#[derive(Default, Debug, Clone, Copy, TdfSerialize, TdfDeserialize, TdfTyped)] +#[repr(u8)] +pub enum RemoveReason { + /// Hit timeout while joining + JoinTimeout = 0x0, + /// Player lost PTP conneciton + PlayerConnectionLost = 0x1, + /// Player lost connection with the Pocket Relay server + ServerConnectionLost = 0x2, + /// Game migration failed + MigrationFailed = 0x3, + GameDestroyed = 0x4, + GameEnded = 0x5, + /// Generic player left the game reason + #[tdf(default)] + #[default] + PlayerLeft = 0x6, + GroupLeft = 0x7, + /// Player kicked + PlayerKicked = 0x8, + /// Player kicked and banned + PlayerKickedWithBan = 0x9, + /// Failed to join from the queue + PlayerJoinFromQueueFailed = 0xA, + PlayerReservationTimeout = 0xB, + HostEjected = 0xC, +} + +#[derive(TdfSerialize)] +pub struct AdminListChange { + #[tdf(tag = "ALST")] + pub player_id: PlayerID, + #[tdf(tag = "GID")] + pub game_id: GameID, + #[tdf(tag = "OPER")] + pub operation: AdminListOperation, + #[tdf(tag = "UID")] + pub host_id: PlayerID, +} + +/// Different operations that can be performed on +/// the admin list +#[derive(Debug, Clone, Copy, TdfSerialize, TdfTyped)] +#[repr(u8)] +pub enum AdminListOperation { + Add = 0, + Remove = 1, +} + +#[derive(TdfSerialize)] +pub struct PlayerStateChange { + #[tdf(tag = "GID")] + pub gid: GameID, + #[tdf(tag = "PID")] + pub pid: PlayerID, + #[tdf(tag = "STAT")] + pub state: PlayerState, +} + +#[derive(TdfSerialize)] +pub struct JoinComplete { + #[tdf(tag = "GID")] + pub game_id: GameID, + #[tdf(tag = "PID")] + pub player_id: PlayerID, +} + +#[derive(TdfSerialize)] +pub struct HostMigrateStart { + #[tdf(tag = "GID")] + pub game_id: GameID, + #[tdf(tag = "HOST")] + pub host_id: PlayerID, + #[tdf(tag = "PMIG")] + pub pmig: u32, + #[tdf(tag = "SLOT")] + pub slot: u8, +} + +#[derive(TdfSerialize)] +pub struct HostMigrateFinished { + #[tdf(tag = "GID")] + pub game_id: GameID, +} + +/// +/// # Example +/// ``` +/// Content: { +/// "ASIL": List [ +/// { +/// "CGS": { +/// "EVST": 6, +/// "MMSN": 1, +/// "NOMP": 0, +/// }, +/// "CUST": { +/// }, +/// "DNFS": { +/// "MDNF": 0, +/// "XDNF": 0, +/// }, +/// "FGS": { +/// "GNUM": 0, +/// }, +/// "GEOS": { +/// "DIST": 0, +/// }, +/// "GRDA": Map { +/// "ME3_gameDifficultyRule": { +/// "NAME": "ME3_gameDifficultyRule", +/// "VALU": List ["difficulty0"], +/// } +/// "ME3_gameEnemyTypeRule": { +/// "NAME": "ME3_gameEnemyTypeRule", +/// "VALU": List ["enemy0", "enemy1", "enemy2", "enemy3", "enemy4", "enemy5", "enemy6", "enemy7", "enemy8", "enemy9", "random", "abstain"], +/// } +/// "ME3_gameMapMatchRule": { +/// "NAME": "ME3_gameMapMatchRule", +/// "VALU": List ["map0", "map1", "map2", "map3", "map4", "map5", "map6", "map7", "map8", "map9", "map10", "map11", "map12", "map13", "map14", "map15", "map16", "map17", "map18", "map19", "map20", "map21", "map22", "map23", "map24", "map25", "map26", "map27", "map28", "map29", "random", "abstain"], +/// } +/// "ME3_gameStateMatchRule": { +/// "NAME": "ME3_gameStateMatchRule", +/// "VALU": List ["IN_LOBBY", "IN_LOBBY_LONGTIME", "IN_GAME_STARTING", "abstain"], +/// } +/// "ME3_rule_dlc2300": { +/// "NAME": "ME3_rule_dlc2300", +/// "VALU": List ["required", "preferred"], +/// } +/// "ME3_rule_dlc2500": { +/// "NAME": "ME3_rule_dlc2500", +/// "VALU": List ["required", "preferred"], +/// } +/// "ME3_rule_dlc2700": { +/// "NAME": "ME3_rule_dlc2700", +/// "VALU": List ["required", "preferred"], +/// } +/// "ME3_rule_dlc3050": { +/// "NAME": "ME3_rule_dlc3050", +/// "VALU": List ["required", "preferred"], +/// } +/// "ME3_rule_dlc3225": { +/// "NAME": "ME3_rule_dlc3225", +/// "VALU": List ["required", "preferred"], +/// } +/// }, +/// "GSRD": { +/// "PMAX": 4, +/// "PMIN": 2, +/// }, +/// "HBRD": { +/// "BVAL": 1, +/// }, +/// "HVRD": { +/// "VVAL": 0, +/// }, +/// "PSRS": { +/// }, +/// "RRDA": { +/// "RVAL": 0, +/// }, +/// "TSRS": { +/// "TMAX": 0, +/// "TMIN": 0, +/// }, +/// "UEDS": Map { +/// "ME3_characterSkill_Rule": { +/// "AMAX": 500, +/// "AMIN": 0, +/// "MUED": 500, +/// "NAME": "ME3_characterSkill_Rule", +/// } +/// }, +/// "VGRS": { +/// "VVAL": 0, +/// }, +/// } +/// ], +/// "MSID": 0x1, +/// "USID": 0x1, +///} +/// ``` +pub struct AsyncMatchmakingStatus { + pub player_id: PlayerID, +} + +impl TdfSerialize for AsyncMatchmakingStatus { + fn serialize(&self, w: &mut S) { + w.tag_list_start(b"ASIL", TdfType::Group, 1); + w.group_body(|w| { + // Create game status + w.group(b"CGS", |w| { + // Evaluate status + // PlayerCountSufficient = 1, + // AcceptableHostFound = 2, + // TeamSizesSufficient = 4 + w.tag_u8(b"EVST", 2 | 4); + // Number of matchmaking sessions + w.tag_u8(b"MMSN", 1); + // Number of matched players + w.tag_u8(b"NOMP", 0); + }); + + // Custom async status + w.tag_group_empty(b"CUST"); + + // DNF rule status + w.group(b"DNFS", |w| { + // My DNF value + w.tag_zero(b"MDNF"); + // Max DNF value + w.tag_zero(b"XDNF"); + }); + + // Find game status + w.group(b"FGS", |w| { + // Number of games + w.tag_zero(b"GNUM"); + }); + + // Geo location rule status + w.group(b"GEOS", |w| { + // Max distance + w.tag_zero(b"DIST"); + }); + + // Generic rule status dictionary (TODO: RULES HERE) + w.tag_map_start(b"GRDA", TdfType::String, TdfType::Group, 0); + + // Game size rule status + w.group(b"GSRD", |w| { + // Max player count accepted + w.tag_u8(b"PMAX", 4); + // Min player count accepted + w.tag_u8(b"PMIN", 2); + }); + + // Host balance rule status + w.group(b"HBRD", |w| { + // Host balance values + // HOSTS_STRICTLY_BALANCED = 0, + // HOSTS_BALANCED = 1, + // HOSTS_UNBALANCED = 2, + + w.tag_u8(b"BVAL", 1); + }); + + // Host viability rule status + w.group(b"HVRD", |w| { + // Host viability values + // CONNECTION_ASSURED = 0, + // CONNECTION_LIKELY = 1, + // CONNECTION_FEASIBLE = 2, + // CONNECTION_UNLIKELY = 3, + + w.tag_zero(b"VVAL"); + }); + + // Ping site rule status + w.group(b"PSRS", |_| {}); + + // Rank rule status + w.group(b"RRDA", |w| { + // Matched rank flags + w.tag_zero(b"RVAL"); + }); + + // Team size rule status + w.group(b"TSRS", |w| { + // Max team size accepted + w.tag_zero(b"TMAX"); + // Min team size accepted + w.tag_zero(b"TMIN"); + }); + + // UED rule status + w.tag_map_empty(b"GRDA", TdfType::String, TdfType::Group); + // Virtual game rule status + w.group(b"VGRS", |w| w.tag_zero(b"VVAL")); + }); + + w.tag_owned(b"MSID", self.player_id); + w.tag_owned(b"USID", self.player_id); + } +} + +#[derive( + Default, Debug, Serialize, Clone, Copy, PartialEq, Eq, TdfDeserialize, TdfSerialize, TdfTyped, +)] +#[repr(u8)] +pub enum PlayerState { + /// Link between the mesh points is not connected + #[default] + #[tdf(default)] + Reserved = 0x0, + Queued = 0x1, + /// Link is being formed between two mesh points + ActiveConnecting = 0x2, + ActiveMigrating = 0x3, + /// Link is connected between two mesh points + ActiveConnected = 0x4, + ActiveKickPending = 0x5, +} + +/// Message for a game state changing +#[derive(TdfSerialize)] +pub struct StateChange { + /// The ID of the game + #[tdf(tag = "GID")] + pub id: GameID, + /// The game state + #[tdf(tag = "GSTA")] + pub state: GameState, +} + +/// Message for a game setting changing +#[derive(TdfSerialize)] +pub struct SettingChange { + /// The game setting + #[tdf(tag = "ATTR", into = u16)] + pub settings: GameSettings, + /// The ID of the game + #[tdf(tag = "GID")] + pub id: GameID, +} + +/// Packet for game attribute changes +#[derive(TdfSerialize)] +pub struct AttributesChange<'a> { + /// Borrowed game attributes map + #[tdf(tag = "ATTR")] + pub attributes: &'a AttrMap, + /// The id of the game the attributes have changed for + #[tdf(tag = "GID")] + pub id: GameID, +} + +/// Message for a player joining notification +pub struct PlayerJoining<'a> { + /// The ID of the game + pub game_id: GameID, + /// The slot the player is joining into + pub slot: usize, + /// The player that is joining + pub player: &'a GamePlayer, +} + +impl TdfSerialize for PlayerJoining<'_> { + fn serialize(&self, w: &mut S) { + w.tag_u32(b"GID", self.game_id); + + w.tag_group(b"PDAT"); + self.player.encode(self.game_id, self.slot, w); + } +} + +/// Different states the game can be in +#[derive( + Default, Debug, Serialize, Clone, Copy, PartialEq, Eq, TdfSerialize, TdfDeserialize, TdfTyped, +)] +#[repr(u8)] +pub enum GameState { + NewState = 0x0, + #[tdf(default)] + #[default] + Initializing = 0x1, + Virtual = 0x2, + PreGame = 0x82, + InGame = 0x83, + PostGame = 0x4, + Migrating = 0x5, + Destructing = 0x6, + Resetable = 0x7, + ReplaySetup = 0x8, +} + +bitflags! { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + pub struct GameSettings: u16 { + const NONE = 0; + const OPEN_TO_BROWSING = 1; + const OPEN_TO_MATCHMAKING = 2; + const OPEN_TO_INVITES = 4; + const OPEN_TO_JOIN_BY_PLAYER = 8; + const HOST_MIGRATABLE = 0x10; + const RANKED = 0x20; + const ADMIN_ONLY_INVITES = 0x40; + const ENFORCE_SINGLE_GROUP_JOIN = 0x80; + const JOIN_IN_PROGRESS_SUPPORTED = 0x100; + const ADMIN_INVITE_ONLY_IGNORE_ENTRY_CHECKS = 0x200; + const IGNORE_ENTRY_CRITERIA_WITH_INVITE = 0x400; + const ENABLE_PERSISTED_GAME_ID = 0x800; + const ALLOW_SAME_TEAM_ID = 0x1000; + const VIRTUALIZED = 0x2000; + const SEND_ORPHANDED_GAME_REPORT_EVENT = 0x4000; + const ALLOW_ANY_REPUTATION = 0x8000; + } +} + +impl From for u16 { + fn from(value: GameSettings) -> Self { + value.bits() + } +} + +impl From for GameSettings { + fn from(value: u16) -> Self { + GameSettings::from_bits_retain(value) + } +} + +const VSTR: &str = "ME3-295976325-179181965240128"; + +#[derive(TdfSerialize, TdfTyped)] +pub enum GameSetupContext { + /// Context without additional data + #[tdf(key = 0x0, tag = "VALU")] + Dataless { + #[tdf(tag = "DCTX")] + context: DatalessContext, + }, + /// Context added from matchmaking + #[tdf(key = 0x3, tag = "VALU")] + Matchmaking { + #[tdf(tag = "FIT")] + fit_score: u16, + #[tdf(tag = "MAXF")] + max_fit_score: u16, + #[tdf(tag = "MSID")] + session_id: PlayerID, + #[tdf(tag = "RSLT")] + result: MatchmakingResult, + #[tdf(tag = "USID")] + player_id: PlayerID, + }, +} + +#[derive(Debug, Copy, Clone, TdfSerialize, TdfTyped)] +#[repr(u8)] +pub enum MatchmakingResult { + // CreatedGame = 0x0, + // JoinedNewGame = 0x1, + JoinedExistingGame = 0x2, + // TimedOut = 0x3, + // Canceled = 0x4, + // Terminated = 0x5, + // GameSetupFailed = 0x6, +} + +#[derive(Debug, Copy, Clone, TdfSerialize, TdfTyped)] +#[repr(u8)] +pub enum DatalessContext { + /// Session created the game + CreateGameSetup = 0x0, + /// Session joined by ID + JoinGameSetup = 0x1, + // IndirectJoinGameFromQueueSetup = 0x2, + // IndirectJoinGameFromReservationContext = 0x3, + // HostInjectionSetupContext = 0x4, +} + +pub struct GameSetupResponse<'a> { + pub game: &'a Game, + pub context: GameSetupContext, +} + +impl TdfSerialize for GameSetupResponse<'_> { + fn serialize(&self, w: &mut S) { + let game = self.game; + let host = game.players.first().expect("Missing game host for setup"); + + w.group(b"GAME", |w| { + w.tag_list_iter_owned(b"ADMN", game.players.iter().map(|player| player.player.id)); + w.tag_ref(b"ATTR", &game.attributes); + w.tag_list_slice::(b"CAP", &[4, 0]); + w.tag_u32(b"GID", game.id); + w.tag_str(b"GNAM", &host.player.display_name); + w.tag_u64(b"GPVH", 0x5a4f2b378b715c6); + w.tag_owned(b"GSET", game.settings.bits()); + w.tag_u64(b"GSID", 0x4000000a76b645); + w.tag_ref(b"GSTA", &game.state); + + w.tag_str_empty(b"GTYP"); + { + w.tag_list_start(b"HNET", TdfType::Group, 1); + w.write_byte(2); + if let NetworkAddress::AddressPair(pair) = &host.net.addr { + TdfSerialize::serialize(pair, w) + } + } + + w.tag_u32(b"HSES", host.player.id); + w.tag_zero(b"IGNO"); + w.tag_u8(b"MCAP", 4); + w.tag_ref(b"NQOS", &host.net.qos); + w.tag_zero(b"NRES"); + w.tag_zero(b"NTOP"); + w.tag_str_empty(b"PGID"); + w.tag_blob_empty(b"PGSR"); + + // Platform host info + w.group(b"PHST", |w| { + w.tag_u32(b"HPID", host.player.id); + w.tag_zero(b"HSLT"); + }); + + w.tag_u8(b"PRES", 0x1); + w.tag_str_empty(b"PSAS"); + // Queue capacity + w.tag_zero(b"QCAP"); + // Shared game randomness seed? + w.tag_u32(b"SEED", 0x4cbc8585); + // tEAM capacity + w.tag_zero(b"TCAP"); + + // Topology host info + w.group(b"THST", |w| { + w.tag_u32(b"HPID", host.player.id); + w.tag_zero(b"HSLT"); + }); + + w.tag_str(b"UUID", "286a2373-3e6e-46b9-8294-3ef05e479503"); + w.tag_u8(b"VOIP", 0x2); + w.tag_str(b"VSTR", VSTR); + w.tag_blob_empty(b"XNNC"); + w.tag_blob_empty(b"XSES"); + }); + + // Player list + w.tag_list_start(b"PROS", TdfType::Group, game.players.len()); + for (slot, player) in game.players.iter().enumerate() { + player.encode(game.id, slot, w); + } + + w.tag_ref(b"REAS", &self.context); + } +} + +pub struct GetGameDetails<'a> { + pub game: &'a Game, +} + +impl TdfSerialize for GetGameDetails<'_> { + fn serialize(&self, w: &mut S) { + let game = self.game; + let host = game.players.first().expect("Missing game host for details"); + + w.tag_list_start(b"GDAT", TdfType::Group, 1); + w.group_body(|w| { + w.tag_list_iter_owned(b"ADMN", game.players.iter().map(|player| player.player.id)); + w.tag_ref(b"ATTR", &game.attributes); + w.tag_list_slice(b"CAP", &[4u8, 0u8]); + + w.tag_u32(b"GID", game.id); + w.tag_str(b"GNAM", &host.player.display_name); + w.tag_u16(b"GSET", game.settings.bits()); + w.tag_ref(b"GSTA", &game.state); + { + w.tag_list_start(b"HNET", TdfType::Group, 1); + w.write_byte(2); + if let NetworkAddress::AddressPair(pair) = &host.net.addr { + TdfSerialize::serialize(pair, w) + } + } + w.tag_u32(b"HOST", host.player.id); + w.tag_zero(b"NTOP"); + + w.tag_list_slice(b"PCNT", &[1u8, 0u8]); + + w.tag_u8(b"PRES", 0x2); + w.tag_str(b"PSAS", "ea-sjc"); + w.tag_str_empty(b"PSID"); + w.tag_zero(b"QCAP"); + w.tag_zero(b"QCNT"); + w.tag_zero(b"SID"); + w.tag_zero(b"TCAP"); + w.tag_u8(b"VOIP", 0x2); + w.tag_str(b"VSTR", VSTR); + }); + } +} diff --git a/src/session/models/mod.rs b/src/session/models/mod.rs index b93e7b65..b4d261a0 100644 --- a/src/session/models/mod.rs +++ b/src/session/models/mod.rs @@ -9,3 +9,180 @@ pub mod other; pub mod stats; pub mod user_sessions; pub mod util; + +use serde::Serialize; +use std::{fmt::Debug, net::Ipv4Addr}; +use tdf::{GroupSlice, TdfDeserialize, TdfDeserializeOwned, TdfSerialize, TdfTyped}; + +/// Networking information for an instance. Contains the +/// host address and the port +#[derive(TdfTyped)] +#[tdf(group)] +pub struct InstanceAddress { + pub host: InstanceHost, + pub port: Port, +} + +impl TdfSerialize for InstanceAddress { + fn serialize(&self, w: &mut S) { + w.group_body(|w| { + self.host.serialize(w); + w.tag_u16(b"PORT", self.port); + }); + } +} + +impl TdfDeserializeOwned for InstanceAddress { + fn deserialize_owned(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult { + let host: InstanceHost = InstanceHost::deserialize_owned(r)?; + let port: u16 = r.tag(b"PORT")?; + GroupSlice::deserialize_content_skip(r)?; + Ok(Self { host, port }) + } +} + +/// Type of instance details provided either hostname +/// encoded as string or IP address encoded as NetAddress +pub enum InstanceHost { + Host(String), + Address(Ipv4Addr), +} + +/// Attempts to convert the provided value into a instance type. If +/// the provided value is an IPv4 value then Address is used otherwise +/// Host is used. +impl From for InstanceHost { + fn from(value: String) -> Self { + if let Ok(value) = value.parse::() { + Self::Address(value) + } else { + Self::Host(value) + } + } +} + +/// Function for converting an instance type into its address +/// string value for use in connections +impl From for String { + fn from(value: InstanceHost) -> Self { + match value { + InstanceHost::Address(value) => value.to_string(), + InstanceHost::Host(value) => value, + } + } +} + +impl TdfSerialize for InstanceHost { + fn serialize(&self, w: &mut S) { + match self { + InstanceHost::Host(value) => w.tag_str(b"HOST", value), + InstanceHost::Address(value) => w.tag_u32(b"IP", (*value).into()), + } + } +} + +impl TdfDeserializeOwned for InstanceHost { + fn deserialize_owned(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult { + let host: Option = r.try_tag(b"HOST")?; + if let Some(host) = host { + return Ok(Self::Host(host)); + } + let ip: u32 = r.tag(b"IP")?; + Ok(Self::Address(Ipv4Addr::from(ip))) + } +} + +/// Details about an instance. This is used for the redirector system +/// to both encode for redirections and decode for the retriever system +#[derive(TdfDeserialize)] +pub struct InstanceDetails { + /// The networking information for the instance + #[tdf(tag = "ADDR")] + pub net: InstanceNet, + /// Whether the host requires a secure connection (SSLv3) + #[tdf(tag = "SECU")] + pub secure: bool, + #[tdf(tag = "XDNS")] + pub xdns: bool, +} + +#[derive(Default, TdfSerialize, TdfDeserialize, TdfTyped)] +pub enum InstanceNet { + #[tdf(key = 0x0, tag = "VALU")] + InstanceAddress(InstanceAddress), + #[tdf(unset)] + Unset, + #[default] + #[tdf(default)] + Default, + // IpAddress = 0x0, + // XboxServer = 0x1, +} + +/// Structure for storing extended network data +#[derive(Debug, Copy, Clone, Default, Serialize, TdfSerialize, TdfDeserialize, TdfTyped)] +#[tdf(group)] +pub struct QosNetworkData { + /// Downstream bits per second + #[tdf(tag = "DBPS")] + pub dbps: u16, + /// Natt type + #[tdf(tag = "NATT")] + pub natt: NatType, + /// Upstream bits per second + #[tdf(tag = "UBPS")] + pub ubps: u16, +} + +// +#[derive(Debug, Default, Copy, Clone, Serialize, TdfDeserialize, TdfSerialize, TdfTyped)] +#[repr(u8)] +pub enum NatType { + #[default] + Open = 0x0, + Moderate = 0x1, + Sequential = 0x2, + Strict = 0x3, + #[tdf(default)] + Unknown = 0x4, +} + +#[derive(Default, Debug, Clone, TdfSerialize, TdfDeserialize, TdfTyped, Serialize)] +#[serde(untagged)] +pub enum NetworkAddress { + #[tdf(key = 0x2, tag = "VALU")] + AddressPair(IpPairAddress), + #[tdf(unset)] + Unset, + #[default] + #[tdf(default)] + Default, + // XboxClient = 0x0, + // XboxServer = 0x1, + // Pair = 0x2, + // IpAddress = 0x3, + // HostnameAddress = 0x4, +} + +/// Type alias for ports which are always u16 +pub type Port = u16; + +/// Pair of socket addresses +#[derive(Debug, Clone, TdfDeserialize, TdfSerialize, TdfTyped, Serialize)] +#[tdf(group)] +pub struct IpPairAddress { + #[tdf(tag = "EXIP")] + pub external: PairAddress, + #[tdf(tag = "INIP")] + pub internal: PairAddress, +} + +#[derive(Debug, Clone, TdfDeserialize, TdfSerialize, TdfTyped, Serialize)] +#[tdf(group)] +pub struct PairAddress { + #[tdf(tag = "IP", into = u32)] + #[serde(rename = "address")] + pub addr: Ipv4Addr, + #[tdf(tag = "PORT")] + pub port: u16, +} diff --git a/src/session/models/stats.rs b/src/session/models/stats.rs index d01ec2c6..671409aa 100644 --- a/src/session/models/stats.rs +++ b/src/session/models/stats.rs @@ -1,10 +1,7 @@ -use tdf::{ - types::var_int::skip_var_int, DecodeError, TdfDeserialize, TdfDeserializeOwned, TdfSerialize, - TdfType, TdfTyped, -}; +use tdf::{TdfDeserialize, TdfSerialize, TdfType, TdfTyped}; use crate::{ - services::leaderboard::models::LeaderboardEntry, + services::leaderboard::models::{LeaderboardEntry, LeaderboardType}, utils::{components::user_sessions::PLAYER_TYPE, types::PlayerID}, }; @@ -13,8 +10,8 @@ use crate::{ #[derive(TdfDeserialize)] pub struct EntityCountRequest { /// The leaderboard name - #[tdf(tag = "NAME")] - pub name: String, + #[tdf(tag = "NAME", into = &str)] + pub name: LeaderboardType, } /// Structure for the entity count response for finding the @@ -56,17 +53,13 @@ pub struct CenteredLeaderboardRequest { #[tdf(tag = "COUN")] pub count: usize, /// The leaderboard name - #[tdf(tag = "NAME")] - pub name: String, + #[tdf(tag = "NAME", into = &str)] + pub name: LeaderboardType, } pub enum LeaderboardResponse<'a> { - /// Empty response where there is no content - Empty, - /// Response with one entry - One(&'a LeaderboardEntry), - /// Response with many leaderboard entires - Many(&'a [LeaderboardEntry]), + Owned(Vec<&'a LeaderboardEntry>), + Borrowed(&'a [LeaderboardEntry]), } impl TdfSerialize for LeaderboardEntry { @@ -95,15 +88,11 @@ impl TdfTyped for LeaderboardEntry { impl TdfSerialize for LeaderboardResponse<'_> { fn serialize(&self, w: &mut S) { match self { - Self::Empty => { - w.tag_list_empty(b"LDLS", TdfType::Group); - } - Self::One(value) => { - w.tag_list_start(b"LDLS", TdfType::Group, 1); - value.serialize(w); + LeaderboardResponse::Owned(value) => { + w.tag_list_slice_ref(b"LDLS", value); } - Self::Many(values) => { - w.tag_list_slice(b"LDLS", values); + LeaderboardResponse::Borrowed(value) => { + w.tag_list_slice(b"LDLS", value); } } } @@ -135,8 +124,8 @@ pub struct LeaderboardRequest { #[tdf(tag = "COUN")] pub count: usize, /// The leaderboard name - #[tdf(tag = "NAME")] - pub name: String, + #[tdf(tag = "NAME", into = &str)] + pub name: LeaderboardType, /// The rank offset to start at #[tdf(tag = "STRT")] pub start: usize, @@ -162,26 +151,14 @@ pub struct LeaderboardRequest { /// "USET": (0, 0, 0) /// } /// ``` +#[derive(TdfDeserialize)] pub struct FilteredLeaderboardRequest { /// The player ID - pub id: PlayerID, + #[tdf(tag = "IDLS")] + pub ids: Vec, /// The leaderboard name - pub name: String, -} - -impl TdfDeserializeOwned for FilteredLeaderboardRequest { - fn deserialize_owned(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult { - let count: usize = r.until_list_typed(b"IDLS", TdfType::VarInt)?; - if count < 1 { - return Err(DecodeError::Other("Missing player ID for filter")); - } - let id: PlayerID = PlayerID::deserialize_owned(r)?; - for _ in 1..count { - skip_var_int(r)?; - } - let name: String = r.tag(b"NAME")?; - Ok(Self { id, name }) - } + #[tdf(tag = "NAME", into = &str)] + pub name: LeaderboardType, } /// Structure for a request for a leaderboard group diff --git a/src/session/models/user_sessions.rs b/src/session/models/user_sessions.rs index 958eebee..e33b0f49 100644 --- a/src/session/models/user_sessions.rs +++ b/src/session/models/user_sessions.rs @@ -1,8 +1,22 @@ -use crate::utils::{ - models::{NetworkAddress, QosNetworkData}, - types::PlayerID, +use std::sync::Arc; + +use crate::{ + database::entities::Player, + session::NetData, + utils::{components::game_manager::GAME_TYPE, types::PlayerID}, }; -use tdf::TdfDeserialize; +use bitflags::bitflags; +use serde::Serialize; +use tdf::{ObjectId, TdfDeserialize, TdfSerialize, TdfTyped}; + +use super::{util::PING_SITE_ALIAS, NetworkAddress, QosNetworkData}; + +#[derive(Debug, Clone)] +#[repr(u16)] +#[allow(unused)] +pub enum UserSessionsError { + UserNotFound = 0xb, +} /// Structure for a request to resume a session using a session token #[derive(TdfDeserialize)] @@ -12,8 +26,7 @@ pub struct ResumeSessionRequest { pub session_token: String, } -/// Structure for a request to update the network info of the -/// current session +/// Request to update the stored networking information for a session #[derive(TdfDeserialize)] pub struct UpdateNetworkRequest { /// The client address net groups @@ -24,17 +37,186 @@ pub struct UpdateNetworkRequest { pub qos: QosNetworkData, } -/// Structure for request to update the hardware flags of the -/// current session +/// Request to update the stored hardware flags for a session #[derive(TdfDeserialize)] -pub struct HardwareFlagRequest { +pub struct UpdateHardwareFlagsRequest { /// The hardware flag value - #[tdf(tag = "HWFG")] - pub hardware_flag: u16, + #[tdf(tag = "HWFG", into = u8)] + pub hardware_flags: HardwareFlags, +} + +bitflags! { + #[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize)] + pub struct HardwareFlags: u8 { + const NONE = 0; + const VOIP_HEADSET_STATUS = 1; + } +} + +impl From for u8 { + #[inline] + fn from(value: HardwareFlags) -> Self { + value.bits() + } +} + +impl From for HardwareFlags { + #[inline] + fn from(value: u8) -> Self { + HardwareFlags::from_bits_retain(value) + } +} + +#[derive(TdfSerialize)] +pub struct UserSessionExtendedDataUpdate { + #[tdf(tag = "DATA")] + pub data: UserSessionExtendedData, + #[tdf(tag = "USID")] + pub user_id: PlayerID, +} + +#[derive(TdfTyped)] +#[tdf(group)] +pub struct UserSessionExtendedData { + /// Networking data for the session + pub net: Arc, + /// ID of the game the player is in (if present) + pub game: Option, +} + +impl TdfSerialize for UserSessionExtendedData { + fn serialize(&self, w: &mut S) { + w.group_body(|w| { + // Network address + w.tag_ref(b"ADDR", &self.net.addr); + // Best ping site alias + w.tag_str(b"BPS", PING_SITE_ALIAS); + // Country + w.tag_str_empty(b"CTY"); + // Client data + w.tag_var_int_list_empty(b"CVAR"); + // Data map + w.tag_map_tuples(b"DMAP", &[(0x70001, 0x409a)]); + // Hardware flags + w.tag_owned(b"HWFG", self.net.hardware_flags.bits()); + // Ping server latency list + w.tag_list_slice(b"PSLM", &[0xfff0fff]); + // Quality of service data + w.tag_ref(b"QDAT", &self.net.qos); + // User info attributes + w.tag_owned(b"UATT", 0u8); + + if let Some(game) = self.game { + // Blaze object ID list + w.tag_list_slice(b"ULST", &[ObjectId::new(GAME_TYPE, game as u64)]); + } + }); + } } +#[derive(TdfTyped)] +#[tdf(group)] +pub struct UserIdentification<'a> { + pub id: PlayerID, + pub name: &'a str, +} + +impl<'a> UserIdentification<'a> { + pub fn from_player(player: &'a Player) -> Self { + Self { + id: player.id, + name: &player.display_name, + } + } +} + +impl TdfSerialize for UserIdentification<'_> { + fn serialize(&self, w: &mut S) { + w.group_body(|w| { + // Account ID + w.tag_owned(b"AID", self.id); + // Account locale + w.tag_owned(b"ALOC", 0x64654445u32); + // External blob + w.tag_blob_empty(b"EXBB"); + // External ID + w.tag_zero(b"EXID"); + // Blaze ID + w.tag_owned(b"ID", self.id); + // Account name + w.tag_str(b"NAME", self.name); + }); + } +} + +#[derive(TdfSerialize)] +pub struct NotifyUserAdded<'a> { + /// The user session data + #[tdf(tag = "DATA")] + pub session_data: UserSessionExtendedData, + /// The added user identification + #[tdf(tag = "USER")] + pub user: UserIdentification<'a>, +} + +#[derive(TdfSerialize)] +pub struct NotifyUserRemoved { + /// The ID of the removed user + #[tdf(tag = "BUID")] + pub player_id: PlayerID, +} + +#[derive(TdfSerialize)] +pub struct NotifyUserUpdated { + #[tdf(tag = "FLGS", into = u8)] + pub flags: UserDataFlags, + /// The ID of the updated user + #[tdf(tag = "ID")] + pub player_id: PlayerID, +} + +bitflags! { + #[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize)] + pub struct UserDataFlags: u8 { + const NONE = 0; + const SUBSCRIBED = 1; + const ONLINE = 2; + } +} + +impl From for u8 { + fn from(value: UserDataFlags) -> Self { + value.bits() + } +} + +impl From for UserDataFlags { + fn from(value: u8) -> Self { + UserDataFlags::from_bits_retain(value) + } +} + +/// Request to lookup the session details of a user, see [UserIdentification] +/// for the full structure that this uses #[derive(TdfDeserialize)] pub struct LookupRequest { #[tdf(tag = "ID")] pub player_id: PlayerID, } + +/// User lookup response +pub struct LookupResponse { + pub player: Arc, + pub extended_data: UserSessionExtendedData, +} + +impl TdfSerialize for LookupResponse { + fn serialize(&self, w: &mut S) { + // The user session extended data + w.tag_ref(b"EDAT", &self.extended_data); + w.tag_owned(b"FLGS", UserDataFlags::ONLINE.bits()); + + // The lookup user identification + w.tag_alt(b"USER", UserIdentification::from_player(&self.player)); + } +} diff --git a/src/session/models/util.rs b/src/session/models/util.rs index c15a0f9d..24d0c87d 100644 --- a/src/session/models/util.rs +++ b/src/session/models/util.rs @@ -1,9 +1,9 @@ +use super::Port; use crate::{ - session::SessionHostTarget, - utils::{models::Port, types::PlayerID}, + config::{QosServerConfig, RuntimeConfig}, + utils::types::PlayerID, }; - -use std::borrow::Cow; +use std::{borrow::Cow, sync::Arc}; use tdf::{TdfDeserialize, TdfMap, TdfSerialize, TdfType}; #[derive(Debug, Clone)] @@ -47,6 +47,7 @@ impl TdfSerialize for TelemetryServer { w.tag_zero(b"ANON"); w.tag_str(b"DISA", TELEMTRY_DISA); w.tag_str(b"FILT", "-UION/****"); + // Encoded locale actually BE encoded string bytes (enNZ) w.tag_u32(b"LOC", 1701727834); w.tag_str(b"NOOK", "US,CA,MX"); // Last known telemetry port: 9988 @@ -86,9 +87,12 @@ pub const SRC_VERSION: &str = "303107"; pub const BLAZE_VERSION: &str = "Blaze 3.15.08.0 (CL# 1629389)"; pub const PING_PERIOD: &str = "15s"; +/// Alias used for ping sites +pub const PING_SITE_ALIAS: &str = "ea-sjc"; + /// Structure for the response to a pre authentication request pub struct PreAuthResponse { - pub host_target: SessionHostTarget, + pub config: Arc, } impl TdfSerialize for PreAuthResponse { @@ -129,12 +133,17 @@ impl TdfSerialize for PreAuthResponse { // Quality Of Service Server details w.group(b"QOSS", |w| { - let (http_host, http_port) = if self.host_target.local_http { - ("127.0.0.1", LOCAL_HTTP_PORT) - } else { - (&self.host_target.host as &str, self.host_target.port) + let qos = &self.config.qos; + + let (http_host, http_port) = match qos { + QosServerConfig::Official => ("gossjcprod-qos01.ea.com", 17502), + QosServerConfig::Local => ("127.0.0.1", LOCAL_HTTP_PORT), + QosServerConfig::Custom { host, port } => (host.as_str(), *port), }; + // let http_host = "127.0.0.1"; + // let http_port = 17499; + // Bioware Primary Server w.group(b"BWPS", |w| { w.tag_str(b"PSA", http_host); @@ -150,7 +159,7 @@ impl TdfSerialize for PreAuthResponse { w.tag_map_start(b"LTPS", TdfType::String, TdfType::Group, 1); // Key for the server - "ea-sjc".serialize(w); + PING_SITE_ALIAS.serialize(w); w.group_body(|w| { // Same as the Bioware primary server @@ -232,9 +241,9 @@ pub struct FetchConfigResponse { /// Structure for the suspend user ping request #[derive(TdfDeserialize)] pub struct SuspendPingRequest { - /// The suspend ping value + /// The suspend ping value (Suspend time in microseconds) #[tdf(tag = "TVAL")] - pub value: u32, + pub time_value: u32, } /// Structure for the request to update the settings for diff --git a/src/session/packet.rs b/src/session/packet.rs index 0ee8b0f6..a7f8b8bf 100644 --- a/src/session/packet.rs +++ b/src/session/packet.rs @@ -1,154 +1,145 @@ -#![allow(unused)] - -use crate::utils::components::{get_command_name, get_component_name}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use std::{fmt::Debug, sync::Arc}; -use std::{io, ops::Deref}; -use tdf::{ - serialize_vec, DecodeResult, TdfDeserialize, TdfDeserializer, TdfSerialize, TdfStringifier, +use crate::utils::components::{ + component_key, get_command_name, get_component_name, OMIT_PACKET_CONTENTS, }; +use bitflags::bitflags; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use std::fmt::Debug; +use std::io; +use tdf::{prelude::*, serialize_vec}; use tokio_util::codec::{Decoder, Encoder}; -/// The different types of packets #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[repr(u8)] -pub enum PacketType { - /// ID counted request packets (0x00) - Request = 0x00, - /// Packets responding to requests (0x10) - Response = 0x10, - /// Unique packets coming from the server (0x20) - Notify = 0x20, - /// Error packets (0x30) - Error = 0x30, +pub enum FrameType { + /// Request to a server + Request = 0x0, + /// Response to a request + Response = 0x1, + /// Async notification from the server + Notify = 0x2, + /// Error response from the server + Error = 0x3, +} + +bitflags! { + #[derive(Debug, Copy, Clone, PartialEq, Eq)] + pub struct PacketOptions: u8 { + const NONE = 0x0; + /// Frame length is extended from 16bits to 32bits + const JUMBO_FRAME = 0x1; + const HAS_CONTEXT = 0x2; + const IMMEDIATE = 0x4; + const JUMBO_CONTEXT = 0x8; + } } -/// From u8 implementation to convert bytes back into -/// PacketTypes -impl From for PacketType { +impl From for FrameType { fn from(value: u8) -> Self { match value { - 0x00 => PacketType::Request, - 0x10 => PacketType::Response, - 0x20 => PacketType::Notify, - 0x30 => PacketType::Error, - // Default type fallback to request - _ => PacketType::Request, + 0x0 => FrameType::Request, + 0x1 => FrameType::Response, + 0x2 => FrameType::Notify, + 0x3 => FrameType::Error, + _ => FrameType::Request, } } } -/// Structure of packet header which comes before the -/// packet content and describes it. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct PacketHeader { - /// The component of this packet +/// Framing structure +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FireFrame { + /// The component that should handle this frame pub component: u16, - /// The command of this packet + /// The command this frame is for pub command: u16, - /// A possible error this packet contains (zero is none) + /// Error code if present, otherwise zero pub error: u16, - /// The type of this packet - pub ty: PacketType, - /// The unique ID of this packet (Notify packets this is just zero) - pub id: u16, + /// The type of frame + pub ty: FrameType, + /// Additional options assocaited with this frame + pub options: PacketOptions, + /// Sequence number for tracking request and response mappings + pub seq: u16, } -impl PacketHeader { - /// Creates a notify header for the provided component and command - /// - /// `component` The component to use - /// `command` The command to use +impl FireFrame { + const MIN_HEADER_SIZE: usize = 12; + const JUMBO_SIZE: usize = std::mem::size_of::(); + pub const fn notify(component: u16, command: u16) -> Self { Self { component, command, error: 0, - ty: PacketType::Notify, - id: 0, + ty: FrameType::Notify, + options: PacketOptions::NONE, + seq: 0, } } - /// Creates a request header for the provided id, component - /// and command - /// - /// `id` The packet ID - /// `component` The component to use - /// `command` The command to use - pub const fn request(id: u16, component: u16, command: u16) -> Self { + pub const fn request(seq: u16, component: u16, command: u16) -> Self { Self { component, command, error: 0, - ty: PacketType::Request, - id, + ty: FrameType::Request, + options: PacketOptions::NONE, + seq, } } - /// Creates a response to the provided packet header by - /// changing the type of the header pub const fn response(&self) -> Self { - self.with_type(PacketType::Response) + self.with_type(FrameType::Response) } - /// Copies the header contents changing its Packet Type - /// - /// `ty` The new packet type - pub const fn with_type(&self, ty: PacketType) -> Self { + pub const fn with_type(&self, ty: FrameType) -> Self { Self { component: self.component, command: self.command, error: self.error, ty, - id: self.id, + options: PacketOptions::NONE, + seq: self.seq, } } - /// Copies the header contents changing its Packet Type pub const fn with_error(&self, error: u16) -> Self { Self { component: self.component, command: self.command, error, - ty: PacketType::Error, - id: self.id, + ty: FrameType::Error, + options: PacketOptions::NONE, + seq: self.seq, } } - /// Checks if the component and command of this packet header matches - /// that of the other packet header - /// - /// `other` The packet header to compare to - pub fn path_matches(&self, other: &PacketHeader) -> bool { + pub fn path_matches(&self, other: &FireFrame) -> bool { self.component.eq(&other.component) && self.command.eq(&other.command) } - /// Encodes the contents of this header appending to the - /// output source - /// - /// `dst` The dst to append the bytes to - /// `length` The length of the content after the header pub fn write(&self, dst: &mut BytesMut, length: usize) { - let is_extended = length > 0xFFFF; + let mut options = self.options; + if length > 0xFFFF { + options |= PacketOptions::JUMBO_FRAME; + } + dst.put_u16(length as u16); dst.put_u16(self.component); dst.put_u16(self.command); dst.put_u16(self.error); - dst.put_u8(self.ty as u8); - dst.put_u8(if is_extended { 0x10 } else { 0x00 }); - dst.put_u16(self.id); - if is_extended { - dst.put_u8(((length & 0xFF000000) >> 24) as u8); - dst.put_u8(((length & 0x00FF0000) >> 16) as u8); + dst.put_u8((self.ty as u8) << 4); + dst.put_u8(options.bits() << 4); + dst.put_u16(self.seq); + + if options.contains(PacketOptions::JUMBO_FRAME) { + // Put the extended length (The next 16 bits of the value to make the 32bit length) + dst.put_u16((length >> 16) as u16); } } - /// Attempts to read the packet header from the provided - /// source bytes returning None if there aren't enough bytes - /// - /// `src` The bytes to read from - pub fn read(src: &mut BytesMut) -> Option<(PacketHeader, usize)> { - if src.len() < 12 { + pub fn read(src: &mut BytesMut) -> Option<(FireFrame, usize)> { + if src.len() < Self::MIN_HEADER_SIZE { return None; } @@ -156,42 +147,38 @@ impl PacketHeader { let component = src.get_u16(); let command = src.get_u16(); let error = src.get_u16(); - let ty = src.get_u8(); - // If we encounter 0x10 here then the packet contains extended length - // bytes so its longer than a u16::MAX length - let is_extended = src.get_u8() == 0x10; - let id = src.get_u16(); + let ty = src.get_u8() >> 4; + let options = src.get_u8() >> 4; + let options = PacketOptions::from_bits_retain(options); + let seq = src.get_u16(); - if is_extended { + if options.contains(PacketOptions::JUMBO_FRAME) { // We need another two bytes for the extended length - if src.len() < 2 { + if src.len() < Self::JUMBO_SIZE { return None; } - length += src.get_u16() as usize; + let ext_length = (src.get_u16() as usize) << 16; + length |= ext_length; } - let ty = PacketType::from(ty); - let header = PacketHeader { + let ty = FrameType::from(ty); + let header = FireFrame { component, command, error, ty, - id, + options, + seq, }; Some((header, length)) } } -/// Structure for Blaze packets contains the contents of the packet -/// and the header for identification. -/// -/// Packets can be cloned with little memory usage increase because -/// the content is stored as Bytes. #[derive(Debug, Clone)] pub struct Packet { - /// The packet header - pub header: PacketHeader, - /// The packet encoded byte contents + /// The frame preceeding this packet + pub frame: FireFrame, + /// The encoded contents of the packet pub contents: Bytes, } @@ -202,56 +189,60 @@ where Bytes::from(serialize_vec(value)) } +#[allow(unused)] impl Packet { /// Creates a new packet from the provided header and contents - pub const fn new(header: PacketHeader, contents: Bytes) -> Self { - Self { header, contents } + pub const fn new(header: FireFrame, contents: Bytes) -> Self { + Self { + frame: header, + contents, + } } /// Creates a new packet from the provided header with empty content #[inline] - pub const fn new_empty(header: PacketHeader) -> Self { + pub const fn new_empty(header: FireFrame) -> Self { Self::new(header, Bytes::new()) } #[inline] pub const fn new_request(id: u16, component: u16, command: u16, contents: Bytes) -> Packet { - Self::new(PacketHeader::request(id, component, command), contents) + Self::new(FireFrame::request(id, component, command), contents) } #[inline] pub const fn new_response(packet: &Packet, contents: Bytes) -> Self { - Self::new(packet.header.response(), contents) + Self::new(packet.frame.response(), contents) } #[inline] pub const fn new_error(packet: &Packet, error: u16, contents: Bytes) -> Self { - Self::new(packet.header.with_error(error), contents) + Self::new(packet.frame.with_error(error), contents) } #[inline] pub const fn new_notify(component: u16, command: u16, contents: Bytes) -> Packet { - Self::new(PacketHeader::notify(component, command), contents) + Self::new(FireFrame::notify(component, command), contents) } #[inline] pub const fn request_empty(id: u16, component: u16, command: u16) -> Packet { - Self::new_empty(PacketHeader::request(id, component, command)) + Self::new_empty(FireFrame::request(id, component, command)) } #[inline] pub const fn response_empty(packet: &Packet) -> Self { - Self::new_empty(packet.header.response()) + Self::new_empty(packet.frame.response()) } #[inline] pub const fn error_empty(packet: &Packet, error: u16) -> Packet { - Self::new_empty(packet.header.with_error(error)) + Self::new_empty(packet.frame.with_error(error)) } #[inline] pub const fn notify_empty(component: u16, command: u16) -> Packet { - Self::new_empty(PacketHeader::notify(component, command)) + Self::new_empty(FireFrame::notify(component, command)) } #[inline] @@ -296,7 +287,7 @@ impl Packet { } pub fn read(src: &mut BytesMut) -> Option { - let (header, length) = PacketHeader::read(src)?; + let (frame, length) = FireFrame::read(src)?; if src.len() < length { return None; @@ -304,14 +295,14 @@ impl Packet { let contents = src.split_to(length); Some(Self { - header, + frame, contents: contents.freeze(), }) } pub fn write(&self, dst: &mut BytesMut) { let contents = &self.contents; - self.header.write(dst, contents.len()); + self.frame.write(dst, contents.len()); dst.extend_from_slice(contents); } } @@ -349,74 +340,56 @@ impl Encoder for PacketCodec { pub struct PacketDebug<'a> { /// Reference to the packet itself pub packet: &'a Packet, - - /// Decide whether to display the contents of the packet - pub minified: bool, } impl<'a> Debug for PacketDebug<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Append basic header information - let header = &self.packet.header; - - let component_name = get_component_name(header.component); - let command_name = get_command_name( - header.component, - header.command, - matches!(&header.ty, PacketType::Notify), - ); - - match (component_name, command_name) { - (Some(component), Some(command)) => { - writeln!(f, "Component: {}({})", component, command)?; - } - (Some(component), None) => { - writeln!(f, "Component: {}({:#06x})", component, header.command)?; - } - _ => { - writeln!( - f, - "Component: {:#06x}({:#06x})", - header.component, header.command - )?; - } - } + let header = &self.packet.frame; - writeln!(f, "Type: {:?}", header.ty)?; + let key = component_key(header.component, header.command); - if !matches!(&header.ty, PacketType::Notify) { - writeln!(f, "ID: {}", &header.id)?; - } + let is_notify = matches!(&header.ty, FrameType::Notify); + let is_error = matches!(&header.ty, FrameType::Error); - if let PacketType::Error = &header.ty { - writeln!(f, "Error: {:#06x}", &header.error)?; + let component_name = get_component_name(header.component).unwrap_or("Unknown"); + let command_name = get_command_name(key, is_notify).unwrap_or("Unkown"); + + write!(f, "{:?}", header.ty)?; + + if is_error { + // Write sequence number and error for errors + write!(f, " ({}, E?{:#06x})", header.seq, header.error)?; + } else if !is_notify { + // Write sequence number of sequenced types + write!(f, " ({})", header.seq)?; } + writeln!( + f, + ": {}->{} ({:#06x}->{:#06x})", + component_name, command_name, header.component, header.command + )?; + + let omit_content = OMIT_PACKET_CONTENTS.contains(&key); + + writeln!(f, "Options: {:?}", header.options)?; + // Skip remaining if the message shouldn't contain its content - if self.minified { + if omit_content { return Ok(()); } - let mut r = TdfDeserializer::new(&self.packet.contents); - let mut out = String::new(); - out.push_str("{\n"); - let mut str = TdfStringifier::new(r, &mut out); + write!(f, "Content: ")?; - // Stringify the content or append error instead - if !str.stringify() { - writeln!(f, "Content Error: Content was malformed or not parsible")?; - writeln!(f, "Partial Content: {}", out)?; - writeln!(f, "Raw: {:?}", &self.packet.contents)?; - return Ok(()); - } + let r = TdfDeserializer::new(&self.packet.contents); + let mut str = TdfStringifier::new(r, f); - if out.len() == 2 { - // Remove new line if nothing else was appended - out.pop(); + if !str.stringify() { + // Write the raw content if stringify doesn't complete + writeln!(&mut str.w, "Raw: {:?}", &self.packet.contents)?; } - out.push('}'); - - write!(f, "Content: {}", out) + Ok(()) } } diff --git a/src/session/router.rs b/src/session/router.rs index 840b55b4..bbfeac1a 100644 --- a/src/session/router.rs +++ b/src/session/router.rs @@ -3,24 +3,26 @@ use super::{ models::errors::BlazeError, - packet::{Packet, PacketHeader}, + packet::{FireFrame, Packet}, SessionLink, }; use crate::{ + database::entities::Player, + services::game::GamePlayer, session::models::errors::GlobalError, utils::{ components::{component_key, ComponentKey}, - types::BoxFuture, + hashing::IntHashMap, }, }; use bytes::Bytes; +use futures_util::future::BoxFuture; use log::error; use std::{ any::{Any, TypeId}, - collections::HashMap, convert::Infallible, - future::{ready, Future}, - hash::{BuildHasherDefault, Hasher}, + future::ready, + future::Future, marker::PhantomData, sync::Arc, }; @@ -73,11 +75,12 @@ impl PacketRequest { } } -type AnyMap = HashMap, BuildHasherDefault>; +type AnyMap = IntHashMap>; +type RouteMap = IntHashMap>; pub struct BlazeRouterBuilder { /// Map for looking up a route based on the component key - routes: HashMap, BuildHasherDefault>, + routes: RouteMap, extensions: AnyMap, } @@ -124,7 +127,7 @@ impl BlazeRouterBuilder { pub struct BlazeRouter { /// Map for looking up a route based on the component key - routes: HashMap, BuildHasherDefault>, + routes: RouteMap, extensions: Arc, } @@ -134,10 +137,10 @@ impl BlazeRouter { state: SessionLink, packet: Packet, ) -> Result, Packet> { - let route = match self.routes.get(&component_key( - packet.header.component, - packet.header.command, - )) { + let route = match self + .routes + .get(&component_key(packet.frame.component, packet.frame.command)) + { Some(value) => value, None => return Err(packet), }; @@ -150,52 +153,11 @@ impl BlazeRouter { } } -/// "Hasher" used by the router map that just directly stores the integer value -/// from the component key as no hashing is required -#[derive(Default)] -pub struct ComponentKeyHasher(u32); - -impl Hasher for ComponentKeyHasher { - fn finish(&self) -> u64 { - self.0 as u64 - } - - fn write(&mut self, _bytes: &[u8]) { - panic!("Attempted to use component key hasher to hash bytes") - } - - fn write_u32(&mut self, i: u32) { - self.0 = i; - } -} - -// With TypeIds as keys, there's no need to hash them. They are already hashes -// themselves, coming from the compiler. The IdHasher just holds the u64 of -// the TypeId, and then returns it, instead of doing any bit fiddling. -#[derive(Default)] -pub struct IdHasher(u64); - -impl Hasher for IdHasher { - fn write(&mut self, _: &[u8]) { - panic!("Attempted to use id hasher to hash bytes") - } - - #[inline] - fn write_u64(&mut self, id: u64) { - self.0 = id; - } - - #[inline] - fn finish(&self) -> u64 { - self.0 - } -} - pub trait FromPacketRequest: Sized { type Rejection: IntoPacketResponse; fn from_packet_request<'a>( - req: &'a PacketRequest, + req: &'a mut PacketRequest, ) -> BoxFuture<'a, Result> where Self: 'a; @@ -212,13 +174,17 @@ pub struct Blaze(pub V); /// responses pub struct BlazeWithHeader { pub req: V, - pub header: PacketHeader, + pub frame: FireFrame, } /// [Blaze] tdf type for contents that have already been /// serialized ahead of time pub struct RawBlaze(Bytes); +/// Extracts the session authenticated player if one is present, +/// responds with [GlobalError::AuthenticationRequired] if there is none +pub struct SessionAuth(pub Arc); + pub struct Extension(pub T); impl FromPacketRequest for Extension @@ -228,7 +194,7 @@ where type Rejection = BlazeError; fn from_packet_request<'a>( - req: &'a PacketRequest, + req: &'a mut PacketRequest, ) -> BoxFuture<'a, Result> where Self: 'a, @@ -248,6 +214,45 @@ where } } +impl FromPacketRequest for GamePlayer { + type Rejection = BlazeError; + + fn from_packet_request<'a>( + req: &'a mut PacketRequest, + ) -> BoxFuture<'a, Result> + where + Self: 'a, + { + Box::pin(async move { + let data = &*req.state.data.read().await; + let data = data.as_ref().ok_or(GlobalError::AuthenticationRequired)?; + Ok(GamePlayer::new( + data.player.clone(), + data.net.clone(), + req.state.clone(), + )) + }) + } +} + +impl FromPacketRequest for SessionAuth { + type Rejection = BlazeError; + + fn from_packet_request<'a>( + req: &'a mut PacketRequest, + ) -> BoxFuture<'a, Result> + where + Self: 'a, + { + Box::pin(async move { + let data = &*req.state.data.read().await; + let data = data.as_ref().ok_or(GlobalError::AuthenticationRequired)?; + let player = data.player.clone(); + Ok(SessionAuth(player)) + }) + } +} + impl From for RawBlaze where T: TdfSerialize, @@ -266,7 +271,7 @@ where type Rejection = BlazeError; fn from_packet_request<'a>( - req: &'a PacketRequest, + req: &'a mut PacketRequest, ) -> BoxFuture<'a, Result> where Self: 'a, @@ -289,7 +294,7 @@ impl BlazeWithHeader { E: TdfSerialize, { Packet { - header: self.header.response(), + frame: self.frame.response(), contents: Bytes::from(serialize_vec(&res)), } } @@ -302,7 +307,7 @@ where type Rejection = BlazeError; fn from_packet_request<'a>( - req: &'a PacketRequest, + req: &'a mut PacketRequest, ) -> BoxFuture<'a, Result> where Self: 'a, @@ -313,7 +318,7 @@ where V::deserialize(&mut r) .map(|value| BlazeWithHeader { req: value, - header: req.packet.header, + frame: req.packet.frame.clone(), }) .map_err(|err| { error!("Error while decoding packet: {:?}", err); @@ -327,7 +332,7 @@ impl FromPacketRequest for SessionLink { type Rejection = Infallible; fn from_packet_request<'a>( - req: &'a PacketRequest, + req: &'a mut PacketRequest, ) -> BoxFuture<'a, Result> where Self: 'a, @@ -440,10 +445,10 @@ macro_rules! impl_handler { fn handle(&self, req: PacketRequest) -> BoxFuture<'_, Packet> { Box::pin(async move { - let req = req; + let mut req = req; $( - let $ty = match $ty::from_packet_request(&req).await { + let $ty = match $ty::from_packet_request(&mut req).await { Ok(value) => value, Err(rejection) => return rejection.into_response(&req.packet), }; diff --git a/src/session/routes/auth.rs b/src/session/routes/auth.rs index 70a3dc35..6ffaf690 100644 --- a/src/session/routes/auth.rs +++ b/src/session/routes/auth.rs @@ -2,21 +2,20 @@ use crate::{ config::RuntimeConfig, database::{entities::Player, DatabaseConnection}, services::{ - retriever::{GetOriginFlow, Retriever}, - sessions::{CreateTokenMessage, Sessions, VerifyError, VerifyTokenMessage}, + retriever::Retriever, + sessions::{Sessions, VerifyError}, }, session::{ models::{ auth::*, errors::{GlobalError, ServerResult}, }, - router::{Blaze, Extension}, - GetPlayerIdMessage, GetPlayerMessage, SessionLink, SetPlayerMessage, + router::{Blaze, Extension, SessionAuth}, + SessionLink, }, utils::hashing::{hash_password, verify_password}, }; use email_address::EmailAddress; -use interlink::prelude::Link; use log::{debug, error}; use std::{borrow::Cow, sync::Arc}; use tokio::fs::read_to_string; @@ -24,18 +23,16 @@ use tokio::fs::read_to_string; pub async fn handle_login( session: SessionLink, Extension(db): Extension, - Extension(sessions): Extension>, - Blaze(req): Blaze, + Extension(sessions): Extension>, + Blaze(LoginRequest { email, password }): Blaze, ) -> ServerResult> { - let LoginRequest { email, password } = &req; - // Ensure the email is actually valid - if !EmailAddress::is_valid(email) { + if !EmailAddress::is_valid(&email) { return Err(AuthenticationError::InvalidEmail.into()); } // Find a non origin player with that email - let player: Player = Player::by_email(&db, email) + let player: Player = Player::by_email(&db, &email) .await? .ok_or(AuthenticationError::InvalidUser)?; @@ -46,14 +43,16 @@ pub async fn handle_login( .ok_or(AuthenticationError::InvalidUser)?; // Ensure passwords match - if !verify_password(password, player_password) { + if !verify_password(&password, player_password) { return Err(AuthenticationError::InvalidPassword.into()); } // Update the session stored player - session.send(SetPlayerMessage(Some(player.clone()))).await?; - let session_token: String = sessions.send(CreateTokenMessage(player.id)).await?; + let player = session.set_player(player).await; + sessions.add_session(player.id, session).await; + + let session_token: String = sessions.create_token(player.id); Ok(Blaze(AuthResponse { player, @@ -65,24 +64,22 @@ pub async fn handle_login( pub async fn handle_silent_login( session: SessionLink, Extension(db): Extension, - Extension(sessions): Extension>, + Extension(sessions): Extension>, Blaze(req): Blaze, ) -> ServerResult> { // Verify the authentication token - let player_id = sessions - .send(VerifyTokenMessage(req.token.clone())) - .await? - .map_err(|err| match err { - VerifyError::Expired => AuthenticationError::ExpiredToken, - VerifyError::Invalid => AuthenticationError::InvalidToken, - })?; + let player_id = sessions.verify_token(&req.token).map_err(|err| match err { + VerifyError::Expired => AuthenticationError::ExpiredToken, + VerifyError::Invalid => AuthenticationError::InvalidToken, + })?; let player = Player::by_id(&db, player_id) .await? .ok_or(AuthenticationError::InvalidToken)?; // Update the session stored player - session.send(SetPlayerMessage(Some(player.clone()))).await?; + let player = session.set_player(player).await; + sessions.add_session(player.id, session).await; Ok(Blaze(AuthResponse { player, @@ -95,35 +92,26 @@ pub async fn handle_origin_login( session: SessionLink, Extension(db): Extension, Extension(config): Extension>, - Extension(sessions): Extension>, - Extension(retriever): Extension>, + Extension(sessions): Extension>, + Extension(retriever): Extension>, Blaze(req): Blaze, ) -> ServerResult> { // Obtain an origin flow - let mut flow = match retriever.send(GetOriginFlow).await { - Ok(Ok(value)) => value, - Ok(Err(err)) => { - error!("Failed to obtain origin flow: {}", err); - return Err(GlobalError::System.into()); - } - Err(err) => { - error!("Unable to access retriever service: {}", err); - return Err(GlobalError::System.into()); - } - }; - - let player: Player = match flow.login(&db, req.token, &config).await { - Ok(value) => value, - Err(err) => { - error!("Failed to login with origin: {}", err); - return Err(GlobalError::System.into()); - } - }; + let mut flow = retriever.origin_flow().await.map_err(|err| { + error!("Failed to obtain origin flow: {}", err); + GlobalError::System + })?; + + let player: Player = flow.login(&db, req.token, &config).await.map_err(|err| { + error!("Failed to login with origin: {}", err); + GlobalError::System + })?; // Update the session stored player - session.send(SetPlayerMessage(Some(player.clone()))).await?; + let player = session.set_player(player).await; + sessions.add_session(player.id, session).await; - let session_token: String = sessions.send(CreateTokenMessage(player.id)).await?; + let session_token: String = sessions.create_token(player.id); Ok(Blaze(AuthResponse { player, @@ -140,8 +128,13 @@ pub async fn handle_origin_login( /// ID: 8 /// Content: {} /// ``` -pub async fn handle_logout(session: SessionLink) { - let _ = session.send(SetPlayerMessage(None)).await; +pub async fn handle_logout( + session: SessionLink, + SessionAuth(player): SessionAuth, + Extension(sessions): Extension>, +) { + session.clear_player().await; + sessions.remove_session(player.id).await; } // Skip formatting these entitlement creations @@ -236,12 +229,8 @@ pub async fn handle_list_entitlements( /// "PMAM": "Jacobtread" /// } /// ``` -pub async fn handle_login_persona(session: SessionLink) -> ServerResult> { - let player: Player = session - .send(GetPlayerMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; - Ok(Blaze(PersonaResponse { player })) +pub async fn handle_login_persona(SessionAuth(player): SessionAuth) -> Blaze { + Blaze(PersonaResponse { player }) } /// Handles forgot password requests. This normally would send a forgot password @@ -255,9 +244,8 @@ pub async fn handle_login_persona(session: SessionLink) -> ServerResult) -> ServerResult<()> { +pub async fn handle_forgot_password(Blaze(req): Blaze) { debug!("Password reset request (Email: {})", req.email); - Ok(()) } /// Handles creating accounts @@ -274,8 +262,8 @@ pub async fn handle_forgot_password(Blaze(req): Blaze) -> /// "GEST": 0, /// "LANG": "en", // Language /// "MAIL": "ACCOUNT_EMAIL", -/// "OPT": 0, -/// "OPT": 0, +/// "OPT1": 0, +/// "OPT3": 0, /// "PASS": "ACCOUNT_PASSWORD", /// "PNAM": "", /// "PRIV": "webprivacy/au/en/pc/default/08202020/02042022", // Privacy policy path @@ -297,7 +285,7 @@ pub async fn handle_create_account( session: SessionLink, Extension(db): Extension, Extension(config): Extension>, - Extension(sessions): Extension>, + Extension(sessions): Extension>, Blaze(req): Blaze, ) -> ServerResult> { let email = req.email; @@ -305,21 +293,16 @@ pub async fn handle_create_account( return Err(AuthenticationError::InvalidEmail.into()); } - match Player::by_email(&db, &email).await? { - // Continue normally for non taken emails - None => {} + if Player::by_email(&db, &email).await?.is_some() { // Handle email address is already in use - Some(_) => return Err(AuthenticationError::Exists.into()), + return Err(AuthenticationError::Exists.into()); } // Hash the proivded plain text password using Argon2 - let hashed_password: String = match hash_password(&req.password) { - Ok(value) => value, - Err(err) => { - error!("Failed to hash password for creating account: {err:?}"); - return Err(GlobalError::System.into()); - } - }; + let hashed_password: String = hash_password(&req.password).map_err(|err| { + error!("Failed to hash password for creating account: {}", err); + GlobalError::System + })?; // Create a default display name from the first 99 chars of the email let display_name: String = email.chars().take(99).collect::(); @@ -328,11 +311,10 @@ pub async fn handle_create_account( let player: Player = Player::create(&db, email, display_name, Some(hashed_password), &config).await?; - // Failing to set the player likely the player disconnected or - // the server is shutting down - session.send(SetPlayerMessage(Some(player.clone()))).await?; + let player = session.set_player(player).await; + sessions.add_session(player.id, session).await; - let session_token = sessions.send(CreateTokenMessage(player.id)).await?; + let session_token = sessions.create_token(player.id); Ok(Blaze(AuthResponse { player, @@ -367,10 +349,12 @@ pub async fn handle_get_legal_docs_info() -> Blaze { /// } /// ``` pub async fn handle_tos() -> Blaze { - let content = match read_to_string("data/terms_of_service.html").await { - Ok(value) => Cow::Owned(value), - Err(_) => Cow::Borrowed("

This is a terms of service placeholder

"), - }; + let content = read_to_string("data/terms_of_service.html") + .await + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed( + "

This is a terms of service placeholder

", + )); Blaze(LegalContent { col: 0xdaed, @@ -390,10 +374,12 @@ pub async fn handle_tos() -> Blaze { /// } /// ``` pub async fn handle_privacy_policy() -> Blaze { - let content = match read_to_string("data/privacy_policy.html").await { - Ok(value) => Cow::Owned(value), - Err(_) => Cow::Borrowed("

This is a privacy policy placeholder

"), - }; + let content = read_to_string("data/privacy_policy.html") + .await + .map(Cow::Owned) + .unwrap_or(Cow::Borrowed( + "

This is a privacy policy placeholder

", + )); Blaze(LegalContent { col: 0xc99c, @@ -411,14 +397,10 @@ pub async fn handle_privacy_policy() -> Blaze { /// Content: {} /// ``` pub async fn handle_get_auth_token( - session: SessionLink, - Extension(sessions): Extension>, -) -> ServerResult> { - let player_id = session - .send(GetPlayerIdMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; + SessionAuth(player): SessionAuth, + Extension(sessions): Extension>, +) -> Blaze { // Create a new token claim for the player to use with the API - let token = sessions.send(CreateTokenMessage(player_id)).await?; - Ok(Blaze(GetTokenResponse { token })) + let token = sessions.create_token(player.id); + Blaze(GetTokenResponse { token }) } diff --git a/src/session/routes/game_manager.rs b/src/session/routes/game_manager.rs index 5a632acd..f6ce6785 100644 --- a/src/session/routes/game_manager.rs +++ b/src/session/routes/game_manager.rs @@ -1,79 +1,58 @@ use crate::{ services::{ - game::{ - manager::{ - CreateMessage, GameManager, GetGameMessage, ProcessQueueMessage, - QueuePlayerMessage, TryAddMessage, TryAddResult, - }, - models::{DatalessContext, GameSetupContext}, - AddPlayerMessage, CheckJoinableMessage, GameJoinableState, GamePlayer, - GetGameDataMessage, RemovePlayerMessage, SetAttributesMessage, SetSettingMessage, - SetStateMessage, UpdateMeshMessage, - }, - sessions::{LookupMessage, Sessions}, + game::{manager::GameManager, GameJoinableState, GamePlayer}, + sessions::Sessions, }, session::{ models::{ errors::{GlobalError, ServerResult}, game_manager::*, }, - router::{Blaze, Extension, RawBlaze}, - GetGamePlayerMessage, GetPlayerGameMessage, GetPlayerIdMessage, SessionLink, + router::{Blaze, Extension, RawBlaze, SessionAuth}, + SessionLink, }, }; -use interlink::prelude::Link; use log::{debug, info}; use std::sync::Arc; pub async fn handle_join_game( - session: SessionLink, - Extension(sessions): Extension>, - Extension(game_manager): Extension>, + player: GamePlayer, + Extension(sessions): Extension>, + Extension(game_manager): Extension>, Blaze(req): Blaze, ) -> ServerResult> { - // Load the session - let player: GamePlayer = session - .send(GetGamePlayerMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; - // Lookup the session join target let session = sessions - .send(LookupMessage { - player_id: req.user.id, - }) - .await; - - // Ensure there wasn't an error - let session = match session { - Ok(Some(value)) => value, - _ => return Err(GlobalError::System.into()), - }; + .lookup_session(req.user.id) + .await + .ok_or(GameManagerError::JoinPlayerFailed)?; // Find the game ID for the target session - let game_id = session.send(GetPlayerGameMessage {}).await; - let game_id = match game_id { - Ok(Some(value)) => value, - _ => return Err(GlobalError::System.into()), - }; - - let game = game_manager.send(GetGameMessage { game_id }).await?; - - let game = match game { - Some(value) => value, - None => return Err(GameManagerError::InvalidGameId.into()), - }; + let (game_id, game_ref) = session + .get_game() + .await + .ok_or(GameManagerError::InvalidGameId)?; // Check the game is joinable - let join_state = game.send(CheckJoinableMessage { rule_set: None }).await?; + let join_state = { + let game = &*game_ref.read().await; + game.joinable_state(None) + }; // Join the game if let GameJoinableState::Joinable = join_state { debug!("Joining game from invite (GID: {})", game_id); - let _ = game.do_send(AddPlayerMessage { - player, - context: GameSetupContext::Dataless(DatalessContext::JoinGameSetup), - }); + + game_manager + .add_to_game( + game_ref, + player, + GameSetupContext::Dataless { + context: DatalessContext::JoinGameSetup, + }, + ) + .await; + Ok(Blaze(JoinGameResponse { game_id, state: JoinGameState::JoinedGame, @@ -85,7 +64,7 @@ pub async fn handle_join_game( pub async fn handle_get_game_data( Blaze(mut req): Blaze, - Extension(game_manager): Extension>, + Extension(game_manager): Extension>, ) -> ServerResult { if req.game_list.is_empty() { return Err(GlobalError::System.into()); @@ -94,11 +73,13 @@ pub async fn handle_get_game_data( let game_id = req.game_list.remove(0); let game = game_manager - .send(GetGameMessage { game_id }) - .await? + .get_game(game_id) + .await .ok_or(GameManagerError::InvalidGameId)?; - let body = game.send(GetGameDataMessage).await?; + let game = &*game.read().await; + + let body = game.game_data().await; Ok(body) } @@ -155,25 +136,27 @@ pub async fn handle_get_game_data( /// } /// ``` pub async fn handle_create_game( - session: SessionLink, - Extension(game_manager): Extension>, + player: GamePlayer, + Extension(game_manager): Extension>, Blaze(req): Blaze, ) -> ServerResult> { - let player: GamePlayer = session - .send(GetGamePlayerMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; - - let (link, game_id) = game_manager - .send(CreateMessage { - attributes: req.attributes, - setting: req.setting, - host: player, - }) - .await?; + let (link, game_id) = game_manager.create_game(req.attributes, req.setting).await; // Notify matchmaking of the new game - let _ = game_manager.do_send(ProcessQueueMessage { link, game_id }); + tokio::spawn(async move { + game_manager + .add_to_game( + link.clone(), + player, + GameSetupContext::Dataless { + context: DatalessContext::CreateGameSetup, + }, + ) + .await; + + // Update matchmaking with the new game + game_manager.process_queue(link, game_id).await; + }); Ok(Blaze(CreateGameResponse { game_id })) } @@ -199,22 +182,30 @@ pub async fn handle_create_game( /// } /// ``` pub async fn handle_set_attributes( - Extension(game_manager): Extension>, + Extension(game_manager): Extension>, Blaze(req): Blaze, ) -> ServerResult<()> { let link = game_manager - .send(GetGameMessage { - game_id: req.game_id, - }) - .await?; - - if let Some(link) = link { - link.send(SetAttributesMessage { - attributes: req.attributes, - }) - .await?; + .get_game(req.game_id) + .await + .ok_or(GameManagerError::InvalidGameId)?; + + { + let game = &mut *link.write().await; + game.set_attributes(req.attributes); } + // Update matchmaking for the changed game + tokio::spawn(async move { + let join_state = { + let game = &*link.read().await; + game.joinable_state(None) + }; + if let GameJoinableState::Joinable = join_state { + game_manager.process_queue(link, req.game_id).await; + } + }); + Ok(()) } @@ -229,18 +220,16 @@ pub async fn handle_set_attributes( /// } /// ``` pub async fn handle_set_state( - Extension(game_manager): Extension>, + Extension(game_manager): Extension>, Blaze(req): Blaze, ) -> ServerResult<()> { let link = game_manager - .send(GetGameMessage { - game_id: req.game_id, - }) - .await?; + .get_game(req.game_id) + .await + .ok_or(GameManagerError::InvalidGameId)?; - if let Some(link) = link { - link.send(SetStateMessage { state: req.state }).await?; - } + let game = &mut *link.write().await; + game.set_state(req.state); Ok(()) } @@ -256,21 +245,16 @@ pub async fn handle_set_state( /// } /// ``` pub async fn handle_set_setting( - Extension(game_manager): Extension>, + Extension(game_manager): Extension>, Blaze(req): Blaze, ) -> ServerResult<()> { let link = game_manager - .send(GetGameMessage { - game_id: req.game_id, - }) - .await?; - - if let Some(link) = link { - link.send(SetSettingMessage { - setting: req.setting, - }) - .await?; - } + .get_game(req.game_id) + .await + .ok_or(GameManagerError::InvalidGameId)?; + + let game = &mut *link.write().await; + game.set_settings(req.setting); Ok(()) } @@ -289,25 +273,18 @@ pub async fn handle_set_setting( /// } /// ``` pub async fn handle_remove_player( - Extension(game_manager): Extension>, + Extension(game_manager): Extension>, Blaze(req): Blaze, -) { - let game = match game_manager - .send(GetGameMessage { - game_id: req.game_id, - }) +) -> ServerResult<()> { + let link = game_manager + .get_game(req.game_id) .await - { - Ok(Some(value)) => value, - _ => return, - }; + .ok_or(GameManagerError::InvalidGameId)?; - let _ = game - .send(RemovePlayerMessage { - reason: req.reason, - id: req.player_id, - }) - .await; + let game = &mut *link.write().await; + game.remove_player(req.player_id, req.reason); + + Ok(()) } /// Handles updating mesh connections @@ -327,38 +304,22 @@ pub async fn handle_remove_player( /// } /// ``` pub async fn handle_update_mesh_connection( - session: SessionLink, - Extension(game_manager): Extension>, + SessionAuth(player): SessionAuth, + Extension(game_manager): Extension>, Blaze(mut req): Blaze, ) -> ServerResult<()> { - let id = match session.send(GetPlayerIdMessage).await? { - Some(value) => value, - None => return Err(GlobalError::AuthenticationRequired.into()), - }; - let target = match req.targets.pop() { Some(value) => value, None => return Ok(()), }; let link = game_manager - .send(GetGameMessage { - game_id: req.game_id, - }) - .await?; - - let link = match link { - Some(value) => value, - None => return Ok(()), - }; + .get_game(req.game_id) + .await + .ok_or(GameManagerError::InvalidGameId)?; - let _ = link - .send(UpdateMeshMessage { - id, - target: target.player_id, - state: target.state, - }) - .await; + let game = &mut *link.write().await; + game.update_mesh(player.id, target.player_id, target.state); Ok(()) } @@ -485,34 +446,21 @@ pub async fn handle_update_mesh_connection( /// } /// ``` pub async fn handle_start_matchmaking( - session: SessionLink, - Extension(game_manager): Extension>, + player: GamePlayer, + Extension(game_manager): Extension>, Blaze(req): Blaze, ) -> ServerResult> { - let player: GamePlayer = session - .send(GetGamePlayerMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; - let session_id = player.player.id; info!("Player {} started matchmaking", player.player.display_name); - let rule_set = Arc::new(req.rules); - - let result = game_manager - .send(TryAddMessage { - player, - rule_set: rule_set.clone(), - }) - .await?; - - // If adding failed attempt to queue instead - if let TryAddResult::Failure(player) = result { - game_manager - .send(QueuePlayerMessage { player, rule_set }) - .await?; - } + tokio::spawn(async move { + let rule_set = Arc::new(req.rules); + // If adding failed attempt to queue instead + if let Err(player) = game_manager.try_add(player, &rule_set).await { + game_manager.queue(player, rule_set).await; + } + }); Ok(Blaze(MatchmakingResponse { id: session_id })) } @@ -527,11 +475,11 @@ pub async fn handle_start_matchmaking( /// "MSID": 1 /// } /// ``` -pub async fn handle_cancel_matchmaking(session: SessionLink) { - session - .exec(|session, _| { - session.remove_games(); - }) - .await - .ok(); +pub async fn handle_cancel_matchmaking( + session: SessionLink, + SessionAuth(player): SessionAuth, + Extension(game_manager): Extension>, +) { + session.remove_from_game().await; + game_manager.remove_queue(player.id).await; } diff --git a/src/session/routes/messaging.rs b/src/session/routes/messaging.rs index 4b7e7795..0297d4f3 100644 --- a/src/session/routes/messaging.rs +++ b/src/session/routes/messaging.rs @@ -3,8 +3,8 @@ use crate::{ session::{ models::messaging::*, packet::Packet, - router::{Blaze, Extension}, - GetPlayerMessage, PushExt, SessionLink, + router::{Blaze, Extension, SessionAuth}, + SessionLink, }, utils::components::messaging, }; @@ -23,23 +23,17 @@ use std::sync::Arc; /// "PSIZ": 0, /// "SMSK": 0, /// "SORT": 0, -/// " (0, 0, 0), +/// "SRCE": (0, 0, 0), /// "STAT": 0, /// "TARG": (0, 0, 0), /// "TYPE": 0 /// } /// ``` -/// pub async fn handle_fetch_messages( session: SessionLink, + SessionAuth(player): SessionAuth, Extension(config): Extension>, ) -> Blaze { - // Request a copy of the player data - let Ok(Some(player)) = session.send(GetPlayerMessage).await else { - // Not authenticated return empty count - return Blaze(FetchMessageResponse { count: 0 }); - }; - // Message with player name replaced let mut message: String = config .menu_message diff --git a/src/session/routes/other.rs b/src/session/routes/other.rs index 21b2df46..474b9385 100644 --- a/src/session/routes/other.rs +++ b/src/session/routes/other.rs @@ -1,5 +1,5 @@ use crate::{ - session::{models::other::*, packet::Packet, router::Blaze, PushExt, SessionLink}, + session::{models::other::*, packet::Packet, router::Blaze, SessionLink}, utils::components::game_reporting, }; diff --git a/src/session/routes/stats.rs b/src/session/routes/stats.rs index 4e3b5998..19c63686 100644 --- a/src/session/routes/stats.rs +++ b/src/session/routes/stats.rs @@ -1,55 +1,49 @@ use crate::{ - services::leaderboard::{models::*, Leaderboard, QueryMessage}, + services::leaderboard::Leaderboard, session::{ - models::{errors::ServerResult, stats::*}, + models::stats::*, packet::Packet, router::{Blaze, BlazeWithHeader, Extension}, }, }; -use interlink::prelude::Link; use sea_orm::DatabaseConnection; use std::sync::Arc; pub async fn handle_normal_leaderboard( - Extension(leaderboard): Extension>, + Extension(leaderboard): Extension>, Extension(db): Extension, req: BlazeWithHeader, -) -> ServerResult { +) -> Packet { let query = &req.req; - let group = get_group(db, leaderboard, &query.name).await?; - let response = match group.get_normal(query.start, query.count) { - Some((values, _)) => LeaderboardResponse::Many(values), - None => LeaderboardResponse::Empty, - }; - Ok(req.response(response)) + let group = leaderboard.query(query.name, &db).await; + let slice = group + .get_normal(query.start, query.count) + .unwrap_or_default(); + req.response(LeaderboardResponse::Borrowed(slice)) } pub async fn handle_centered_leaderboard( - Extension(leaderboard): Extension>, + Extension(leaderboard): Extension>, Extension(db): Extension, req: BlazeWithHeader, -) -> ServerResult { +) -> Packet { let query = &req.req; - let group = get_group(db, leaderboard, &query.name).await?; - let response = match group.get_centered(query.center, query.count) { - Some(values) => LeaderboardResponse::Many(values), - None => LeaderboardResponse::Empty, - }; - Ok(req.response(response)) + let group = leaderboard.query(query.name, &db).await; + let slice = group + .get_centered(query.center, query.count) + .unwrap_or_default(); + req.response(LeaderboardResponse::Borrowed(slice)) } pub async fn handle_filtered_leaderboard( - Extension(leaderboard): Extension>, + Extension(leaderboard): Extension>, Extension(db): Extension, req: BlazeWithHeader, -) -> ServerResult { +) -> Packet { let query = &req.req; - let group = get_group(db, leaderboard, &query.name).await?; - let response = match group.get_entry(query.id) { - Some(value) => LeaderboardResponse::One(value), - None => LeaderboardResponse::Empty, - }; - Ok(req.response(response)) + let group = leaderboard.query(query.name, &db).await; + let response = group.get_filtered(&query.ids); + req.response(LeaderboardResponse::Owned(response)) } /// Handles returning the number of leaderboard objects present. @@ -68,23 +62,13 @@ pub async fn handle_filtered_leaderboard( /// } /// ``` pub async fn handle_leaderboard_entity_count( - Extension(leaderboard): Extension>, + Extension(leaderboard): Extension>, Extension(db): Extension, Blaze(req): Blaze, -) -> ServerResult> { - let group = get_group(db, leaderboard, &req.name).await?; +) -> Blaze { + let group = leaderboard.query(req.name, &db).await; let count = group.values.len(); - Ok(Blaze(EntityCountResponse { count })) -} - -async fn get_group( - db: DatabaseConnection, - leaderboard: Link, - name: &str, -) -> ServerResult> { - let ty = LeaderboardType::from_value(name); - let result = leaderboard.send(QueryMessage(ty, db)).await?; - Ok(result) + Blaze(EntityCountResponse { count }) } fn get_locale_name(code: &str) -> &str { diff --git a/src/session/routes/user_sessions.rs b/src/session/routes/user_sessions.rs index 66b21564..7305d27b 100644 --- a/src/session/routes/user_sessions.rs +++ b/src/session/routes/user_sessions.rs @@ -1,21 +1,19 @@ -use interlink::prelude::Link; -use sea_orm::DatabaseConnection; - use crate::{ database::entities::Player, - services::sessions::{LookupMessage, Sessions, VerifyError, VerifyTokenMessage}, + services::sessions::{Sessions, VerifyError}, session::{ models::{ auth::{AuthResponse, AuthenticationError}, - errors::{GlobalError, ServerResult}, + errors::ServerResult, user_sessions::*, + NetworkAddress, }, router::{Blaze, Extension}, - GetLookupMessage, GetSocketAddrMessage, HardwareFlagMessage, LookupResponse, - NetworkInfoMessage, SessionLink, SetPlayerMessage, + LookupResponse, SessionLink, }, - utils::models::NetworkAddress, }; +use sea_orm::DatabaseConnection; +use std::sync::Arc; /// Attempts to lookup another authenticated session details /// @@ -34,27 +32,20 @@ use crate::{ /// ``` pub async fn handle_lookup_user( Blaze(req): Blaze, - Extension(sessions): Extension>, + Extension(sessions): Extension>, ) -> ServerResult> { // Lookup the session - let session = sessions - .send(LookupMessage { - player_id: req.player_id, - }) - .await; - // Ensure there wasn't an error - let session = match session { - Ok(Some(value)) => value, - _ => return Err(GlobalError::System.into()), - }; + let session = sessions + .lookup_session(req.player_id) + .await + .ok_or(UserSessionsError::UserNotFound)?; // Get the lookup response from the session - let response = session.send(GetLookupMessage {}).await; - let response = match response { - Ok(Some(value)) => value, - _ => return Err(GlobalError::System.into()), - }; + let response = session + .get_lookup() + .await + .ok_or(UserSessionsError::UserNotFound)?; Ok(Blaze(response)) } @@ -72,15 +63,14 @@ pub async fn handle_lookup_user( pub async fn handle_resume_session( session: SessionLink, Extension(db): Extension, - Extension(sessions): Extension>, + Extension(sessions): Extension>, Blaze(req): Blaze, ) -> ServerResult> { let session_token = req.session_token; // Verify the authentication token let player_id = sessions - .send(VerifyTokenMessage(session_token.clone())) - .await? + .verify_token(&session_token) .map_err(|err| match err { VerifyError::Expired => AuthenticationError::ExpiredToken, VerifyError::Invalid => AuthenticationError::InvalidToken, @@ -90,9 +80,8 @@ pub async fn handle_resume_session( .await? .ok_or(AuthenticationError::InvalidToken)?; - // Failing to set the player likely the player disconnected or - // the server is shutting down - session.send(SetPlayerMessage(Some(player.clone()))).await?; + let player = session.set_player(player).await; + sessions.add_session(player.id, session).await; Ok(Blaze(AuthResponse { player, @@ -139,21 +128,15 @@ pub async fn handle_update_network( // If address is missing if ext.addr.is_unspecified() { - // Obtain socket address from session - if let Ok(addr) = session.send(GetSocketAddrMessage).await { - // Replace address with new address and port with same as local port - ext.addr = addr; - ext.port = pair.internal.port; - } + // Replace address with new address and port with same as local port + ext.addr = session.addr; + ext.port = pair.internal.port; } } - let _ = session - .send(NetworkInfoMessage { - address: req.address, - qos: req.qos, - }) - .await; + tokio::spawn(async move { + session.set_network_info(req.address, req.qos).await; + }); } /// Handles updating the stored hardware flag with the client provided hardware flag @@ -167,11 +150,9 @@ pub async fn handle_update_network( /// ``` pub async fn handle_update_hardware_flag( session: SessionLink, - Blaze(req): Blaze, + Blaze(req): Blaze, ) { - let _ = session - .send(HardwareFlagMessage { - value: req.hardware_flag, - }) - .await; + tokio::spawn(async move { + session.set_hardware_flags(req.hardware_flags).await; + }); } diff --git a/src/session/routes/util.rs b/src/session/routes/util.rs index f4e535ee..e80e5c0b 100644 --- a/src/session/routes/util.rs +++ b/src/session/routes/util.rs @@ -1,24 +1,25 @@ use crate::{ - config::VERSION, + config::{RuntimeConfig, VERSION}, database::entities::PlayerData, session::{ models::{ - errors::{GlobalError, ServerResult}, + errors::{BlazeError, GlobalError, ServerResult}, util::*, }, - router::{Blaze, Extension}, - DetailsMessage, GetHostTarget, GetPlayerIdMessage, SessionLink, + router::{Blaze, Extension, SessionAuth}, + SessionLink, }, }; use base64ct::{Base64, Encoding}; use embeddy::Embedded; use flate2::{write::ZlibEncoder, Compression}; -use interlink::prelude::Link; use log::error; use sea_orm::DatabaseConnection; use std::{ + cmp::Ordering, io::Write, path::Path, + sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, }; use tdf::TdfMap; @@ -79,9 +80,10 @@ pub async fn handle_get_ticker_server() -> Blaze { /// } /// } /// ``` -pub async fn handle_pre_auth(session: SessionLink) -> ServerResult> { - let host_target = session.send(GetHostTarget {}).await?; - Ok(Blaze(PreAuthResponse { host_target })) +pub async fn handle_pre_auth( + Extension(config): Extension>, +) -> ServerResult> { + Ok(Blaze(PreAuthResponse { config })) } /// Handles post authentication requests. This provides information about other @@ -92,21 +94,17 @@ pub async fn handle_pre_auth(session: SessionLink) -> ServerResult ServerResult> { - let player_id = session - .send(GetPlayerIdMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; - - // Queue the session details to be sent to this client - let _ = session.do_send(DetailsMessage { - link: Link::clone(&session), - }); +pub async fn handle_post_auth( + session: SessionLink, + SessionAuth(player): SessionAuth, +) -> ServerResult> { + // Subscribe to the session with itself + session.add_subscriber(player.id, session.clone()).await; Ok(Blaze(PostAuthResponse { telemetry: TelemetryServer, ticker: TickerServer, - player_id, + player_id: player.id, })) } @@ -121,8 +119,7 @@ pub async fn handle_post_auth(session: SessionLink) -> ServerResult Blaze { - let now = SystemTime::now(); - let server_time = now + let server_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or(Duration::ZERO) .as_secs(); @@ -151,11 +148,10 @@ const ME3_DIME: &str = include_str!("../../resources/data/dime.xml"); /// } /// ``` pub async fn handle_fetch_client_config( - session: SessionLink, Blaze(req): Blaze, ) -> ServerResult> { - let config = match req.id.as_ref() { - "ME3_DATA" => data_config(&session).await, + let config = match req.id.as_str() { + "ME3_DATA" => data_config(), "ME3_MSG" => messages(), "ME3_ENT" => load_entitlements(), "ME3_DIME" => { @@ -169,10 +165,16 @@ pub async fn handle_fetch_client_config( map.insert("VERSION".to_string(), "40128".to_string()); map } - "ME3_BINI_PC_COMPRESSED" => load_coalesced().await?, + "ME3_BINI_PC_COMPRESSED" => match load_coalesced().await { + Ok(map) => map, + Err(err) => { + error!("Failed to load server coalesced: {}", err); + return Err(GlobalError::System.into()); + } + }, id => { if let Some(lang) = id.strip_prefix("ME3_LIVE_TLK_PC_") { - talk_file(lang).await? + talk_file(lang).await } else { TdfMap::default() } @@ -185,22 +187,24 @@ pub async fn handle_fetch_client_config( /// Loads the entitlements from the entitlements file and parses /// it as a fn load_entitlements() -> TdfMap { - let mut map = TdfMap::::new(); - for (key, value) in ME3_ENT.lines().filter_map(|line| line.split_once('=')) { - map.insert(key.to_string(), value.to_string()); - } - map + ME3_ENT + .lines() + .filter_map(|line| line.split_once('=')) + .map(|(key, value)| (key.to_string(), value.to_string())) + .collect() } /// Loads the local coalesced if one is present falling back /// to the default one on error or if its missing -async fn load_coalesced() -> ServerResult { +async fn load_coalesced() -> std::io::Result { let local_path = Path::new("data/coalesced.bin"); + if local_path.is_file() { - if let Ok(bytes) = read(local_path).await { - if let Ok(map) = generate_coalesced(&bytes) { - return Ok(map); - } + if let Ok(map) = read(local_path) + .await + .and_then(|bytes| generate_coalesced(&bytes)) + { + return Ok(map); } error!( @@ -216,17 +220,11 @@ async fn load_coalesced() -> ServerResult { /// Generates a compressed caolesced from the provided bytes /// /// `bytes` The coalesced bytes -fn generate_coalesced(bytes: &[u8]) -> ServerResult { +fn generate_coalesced(bytes: &[u8]) -> std::io::Result { let compressed: Vec = { let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(6)); - encoder.write_all(bytes).map_err(|_| { - error!("Failed to encode coalesced with ZLib (write stage)"); - GlobalError::System - })?; - encoder.finish().map_err(|_| { - error!("Failed to encode coalesced with ZLib (finish stage)"); - GlobalError::System - })? + encoder.write_all(bytes)?; + encoder.finish()? }; let mut encoded = Vec::with_capacity(16 + compressed.len()); @@ -278,35 +276,35 @@ fn create_base64_map(bytes: &[u8]) -> ChunkMap { output } +/// Default talk file values +#[derive(Embedded)] +#[folder = "src/resources/data/tlk"] +struct DefaultTlkFiles; + /// Retrieves a talk file for the specified language code falling back /// to the `ME3_TLK_DEFAULT` default talk file if it could not be found /// /// `lang` The talk file language -async fn talk_file(lang: &str) -> ServerResult { - let file_name = format!("data/{}.tlk", lang); - let local_path = Path::new(&file_name); +async fn talk_file(lang: &str) -> ChunkMap { + let file_name = format!("{}.tlk", lang); + let local_path = format!("data/{}", file_name); + let local_path = Path::new(&local_path); if local_path.is_file() { - if let Ok(bytes) = read(local_path).await { - return Ok(create_base64_map(&bytes)); + if let Ok(map) = read(local_path) + .await + .map(|bytes| create_base64_map(&bytes)) + { + return map; } error!("Unable to load local talk file falling back to default."); } - // Load default talk file - let file_name = format!("{}.tlk", lang); - Ok(if let Some(file) = DefaultTlkFiles::get(&file_name) { - create_base64_map(file) - } else { - let bytes: &[u8] = include_bytes!("../../resources/data/tlk/default.tlk"); - create_base64_map(bytes) - }) -} + let bytes = DefaultTlkFiles::get(&file_name) + .unwrap_or(include_bytes!("../../resources/data/tlk/default.tlk")); -/// Default talk file values -#[derive(Embedded)] -#[folder = "src/resources/data/tlk"] -struct DefaultTlkFiles; + create_base64_map(bytes) +} /// Loads the messages that should be displayed to the client and /// returns them in a list. @@ -426,22 +424,8 @@ impl Message { /// Image Server: http://eaassets-a.akamaihd.net/gameplayservices/prod/MassEffect/3/ /// Telemetry Server: 159.153.235.32:9988 /// -async fn data_config(session: &SessionLink) -> TdfMap { - let host_target = match session.send(GetHostTarget).await { - Ok(value) => value, - Err(_) => return TdfMap::with_capacity(0), - }; - - let prefix = if host_target.local_http { - format!("http://127.0.0.1:{}", LOCAL_HTTP_PORT) - } else { - format!( - "{}{}:{}", - host_target.scheme.value(), - host_target.host, - host_target.port - ) - }; +fn data_config() -> TdfMap { + let prefix = format!("http://127.0.0.1:{}", LOCAL_HTTP_PORT); let tele_port = TELEMETRY_PORT; @@ -474,6 +458,13 @@ async fn data_config(session: &SessionLink) -> TdfMap { /// Handles suspend user ping packets. The usage of this is unknown and needs /// further research /// +/// Handles suspending user ping timeout for a specific period of time. The client +/// provides a time in microseconds and the server responds with whether it will +/// allow the time +/// +/// [UtilError::] +/// +/// /// ``` /// Route: Util(SuspendUserPing) /// ID: 31 @@ -481,12 +472,13 @@ async fn data_config(session: &SessionLink) -> TdfMap { /// "TVAL": 90000000 /// } /// ``` -pub async fn handle_suspend_user_ping(Blaze(req): Blaze) -> ServerResult<()> { - match req.value { - 20000000 => Err(UtilError::SuspendPingTimeTooSmall.into()), - 90000000 => Err(UtilError::PingSuspended.into()), - _ => Ok(()), - } +pub async fn handle_suspend_user_ping(Blaze(req): Blaze) -> BlazeError { + let res = match req.time_value.cmp(&90000000) { + Ordering::Less => UtilError::SuspendPingTimeTooSmall, + Ordering::Greater => UtilError::SuspendPingTimeTooLarge, + Ordering::Equal => UtilError::PingSuspended, + }; + res.into() } /// Handles updating the stored data for this account @@ -501,16 +493,11 @@ pub async fn handle_suspend_user_ping(Blaze(req): Blaze) -> /// } /// ``` pub async fn handle_user_settings_save( - session: SessionLink, + SessionAuth(player): SessionAuth, Extension(db): Extension, Blaze(req): Blaze, ) -> ServerResult<()> { - let player = session - .send(GetPlayerIdMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; - - PlayerData::set(&db, player, req.key, req.value).await?; + PlayerData::set(&db, player.id, req.key, req.value).await?; Ok(()) } @@ -523,21 +510,15 @@ pub async fn handle_user_settings_save( /// Content: {} /// ``` pub async fn handle_load_settings( - session: SessionLink, + SessionAuth(player): SessionAuth, Extension(db): Extension, ) -> ServerResult> { - let player = session - .send(GetPlayerIdMessage) - .await? - .ok_or(GlobalError::AuthenticationRequired)?; - // Load the player data from the database - let data: Vec = PlayerData::all(&db, player).await?; + let settings: TdfMap = PlayerData::all(&db, player.id) + .await? + .into_iter() + .map(|entry| (entry.key, entry.value)) + .collect(); - // Encode the player data into a settings map and order it - let mut settings = TdfMap::::with_capacity(data.len()); - for value in data { - settings.insert(value.key, value.value); - } Ok(Blaze(SettingsResponse { settings })) } diff --git a/src/utils/components.rs b/src/utils/components.rs index 4e2e36a4..c17139b6 100644 --- a/src/utils/components.rs +++ b/src/utils/components.rs @@ -18,6 +18,27 @@ static COMPONENT_NAMES: &[(u16, &str)] = &[ static mut COMMANDS: Option> = None; static mut NOTIFICATIONS: Option> = None; +// Packets that will have their content omitted for debug logging +#[rustfmt::skip] +pub static OMIT_PACKET_CONTENTS: &[ComponentKey] = &[ + // Hide authentication packets for user privacy + component_key(authentication::COMPONENT, authentication::ORIGIN_LOGIN), + component_key(authentication::COMPONENT, authentication::LOGIN), + component_key(authentication::COMPONENT, authentication::SILENT_LOGIN), + // Hide large data packets that will clog up logs + component_key(authentication::COMPONENT, authentication::LIST_USER_ENTITLEMENTS_2), + component_key(util::COMPONENT, util::FETCH_CLIENT_CONFIG), + component_key(util::COMPONENT, util::USER_SETTINGS_LOAD_ALL), +]; + +// Packets that wont show up in debug logging +#[rustfmt::skip] +pub static DEBUG_IGNORED_PACKETS: &[ComponentKey] = &[ + // Ping messages occur very frequently and contain nothing important + component_key(util::COMPONENT, util::PING), + component_key(util::COMPONENT, util::SUSPEND_USER_PING), +]; + /// Initializes the stored component state. Should only be /// called on initial startup pub fn initialize() { @@ -34,8 +55,7 @@ pub fn get_component_name(component: u16) -> Option<&'static str> { .copied() } -pub fn get_command_name(component: u16, command: u16, notify: bool) -> Option<&'static str> { - let key = component_key(component, command); +pub fn get_command_name(key: ComponentKey, notify: bool) -> Option<&'static str> { let map = if notify { unsafe { NOTIFICATIONS.as_ref() } } else { @@ -370,10 +390,10 @@ pub mod user_sessions { pub const RESUME_SESSION: u16 = 0x23; // Notifications - pub const SET_SESSION: u16 = 0x1; - pub const SESSION_DETAILS: u16 = 0x2; - pub const FETCH_EXTENDED_DATA: u16 = 0x3; - pub const UPDATE_EXTENDED_DATA_ATTRIBUTE: u16 = 0x5; + pub const USER_SESSION_EXTENDED_DATA_UPDATE: u16 = 0x1; + pub const USER_ADDED: u16 = 0x2; + pub const USER_REMOVED: u16 = 0x3; + pub const USER_UPDATED: u16 = 0x5; // Object Types pub const PLAYER_TYPE: ObjectType = ObjectType::new(COMPONENT, 1); @@ -647,10 +667,10 @@ fn notifications() -> HashMap { (component_key(gr::COMPONENT, gr::GAME_REPORT_SUBMITTED), "GameReportSubmitted"), // User Sessions - (component_key(us::COMPONENT, us::SET_SESSION), "SetSession"), - (component_key(us::COMPONENT, us::SESSION_DETAILS), "SessionDetails"), - (component_key(us::COMPONENT, us::UPDATE_EXTENDED_DATA_ATTRIBUTE), "UpdateExtendedDataAttribute"), - (component_key(us::COMPONENT, us::FETCH_EXTENDED_DATA), "FetchExtendedData"), + (component_key(us::COMPONENT, us::USER_SESSION_EXTENDED_DATA_UPDATE), "UserSessionExtendedDataUpdate"), + (component_key(us::COMPONENT, us::USER_ADDED), "UserAdded"), + (component_key(us::COMPONENT, us::USER_UPDATED), "UserUpdated"), + (component_key(us::COMPONENT, us::USER_REMOVED), "UserRemoved"), ] .into_iter() .collect() diff --git a/src/utils/hashing.rs b/src/utils/hashing.rs index fcd5f294..35a53072 100644 --- a/src/utils/hashing.rs +++ b/src/utils/hashing.rs @@ -1,8 +1,13 @@ //! Hashing utility for hashing and verifying passwords +use std::{ + collections::HashMap, + hash::{BuildHasherDefault, Hasher}, +}; + use argon2::{ - Argon2, - password_hash::{self, PasswordVerifier, rand_core::OsRng, SaltString}, PasswordHash, PasswordHasher, + password_hash::{self, rand_core::OsRng, PasswordVerifier, SaltString}, + Argon2, PasswordHash, PasswordHasher, }; /// Hashes the provided password using the Argon2 algorithm returning @@ -30,3 +35,35 @@ pub fn verify_password(password: &str, hash: &str) -> bool { let argon2 = Argon2::default(); argon2.verify_password(password.as_bytes(), &hash).is_ok() } + +pub type IntHashMap = HashMap>; + +/// Hasher implementation that directly uses an integer value +/// instead of any specific hashing algorithm +/// +/// Only implements hashing for [u32] and [u64] +/// +/// Used for hashing packet component paths and type Ids +#[derive(Default)] +pub struct IntHasher(u64); + +impl Hasher for IntHasher { + fn write(&mut self, _: &[u8]) { + panic!("Attempted to use int hasher to hash bytes") + } + + #[inline] + fn write_u64(&mut self, id: u64) { + self.0 = id; + } + + #[inline] + fn write_u32(&mut self, i: u32) { + self.0 = i as u64; + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} diff --git a/src/utils/logging.rs b/src/utils/logging.rs index dd193fc2..8b5d44a7 100644 --- a/src/utils/logging.rs +++ b/src/utils/logging.rs @@ -1,4 +1,3 @@ -use super::models::Port; use log::{info, LevelFilter}; use log4rs::{ append::{console::ConsoleAppender, file::FileAppender}, @@ -58,7 +57,7 @@ pub fn setup(logging_level: LevelFilter) { /// Prints a list of possible urls that can be used to connect to /// this Pocket relay server -pub async fn log_connection_urls(http_port: Port) { +pub async fn log_connection_urls(http_port: u16) { let mut output = String::new(); if let Ok(local_address) = local_ip_address::local_ip() { output.push_str("LAN: "); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index c7c9e849..aabd5855 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,6 @@ pub mod components; pub mod hashing; pub mod logging; -pub mod models; pub mod parsing; +pub mod signing; pub mod types; diff --git a/src/utils/models.rs b/src/utils/models.rs deleted file mode 100644 index d5071b62..00000000 --- a/src/utils/models.rs +++ /dev/null @@ -1,213 +0,0 @@ -use crate::utils::types::PlayerID; -use serde::{ser::SerializeStruct, Serialize}; -use std::{fmt::Debug, net::Ipv4Addr}; -use tdf::{GroupSlice, TdfDeserialize, TdfDeserializeOwned, TdfSerialize, TdfType, TdfTyped}; - -/// Networking information for an instance. Contains the -/// host address and the port -pub struct InstanceAddress { - pub host: InstanceHost, - pub port: Port, -} - -impl From<(String, Port)> for InstanceAddress { - fn from((host, port): (String, Port)) -> Self { - let host = InstanceHost::from(host); - Self { host, port } - } -} - -impl TdfSerialize for InstanceAddress { - fn serialize(&self, w: &mut S) { - w.group_body(|w| { - self.host.serialize(w); - w.tag_u16(b"PORT", self.port); - }); - } -} - -impl TdfDeserializeOwned for InstanceAddress { - fn deserialize_owned(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult { - let host: InstanceHost = InstanceHost::deserialize_owned(r)?; - let port: u16 = r.tag(b"PORT")?; - GroupSlice::deserialize_content_skip(r)?; - Ok(Self { host, port }) - } -} - -impl TdfTyped for InstanceAddress { - const TYPE: TdfType = TdfType::Group; -} - -/// Type of instance details provided either hostname -/// encoded as string or IP address encoded as NetAddress -pub enum InstanceHost { - Host(String), - Address(Ipv4Addr), -} - -/// Attempts to convert the provided value into a instance type. If -/// the provided value is an IPv4 value then Address is used otherwise -/// Host is used. -impl From for InstanceHost { - fn from(value: String) -> Self { - if let Ok(value) = value.parse::() { - Self::Address(value) - } else { - Self::Host(value) - } - } -} - -/// Function for converting an instance type into its address -/// string value for use in connections -impl From for String { - fn from(value: InstanceHost) -> Self { - match value { - InstanceHost::Address(value) => value.to_string(), - InstanceHost::Host(value) => value, - } - } -} - -impl TdfSerialize for InstanceHost { - fn serialize(&self, w: &mut S) { - match self { - InstanceHost::Host(value) => w.tag_str(b"HOST", value), - InstanceHost::Address(value) => w.tag_u32(b"IP", (*value).into()), - } - } -} - -impl TdfDeserializeOwned for InstanceHost { - fn deserialize_owned(r: &mut tdf::TdfDeserializer<'_>) -> tdf::DecodeResult { - let host: Option = r.try_tag(b"HOST")?; - if let Some(host) = host { - return Ok(Self::Host(host)); - } - let ip: u32 = r.tag(b"IP")?; - Ok(Self::Address(Ipv4Addr::from(ip))) - } -} - -/// Details about an instance. This is used for the redirector system -/// to both encode for redirections and decode for the retriever system -#[derive(TdfDeserialize)] -pub struct InstanceDetails { - /// The networking information for the instance - #[tdf(tag = "ADDR")] - pub net: InstanceNet, - /// Whether the host requires a secure connection (SSLv3) - #[tdf(tag = "SECU")] - pub secure: bool, - #[tdf(tag = "XDNS")] - pub xdns: bool, -} - -#[derive(Default, TdfSerialize, TdfDeserialize, TdfTyped)] -pub enum InstanceNet { - #[tdf(key = 0x0, tag = "VALU")] - InstanceAddress(InstanceAddress), - #[tdf(unset)] - Unset, - #[default] - #[tdf(default)] - Default, - // IpAddress = 0x0, - // XboxServer = 0x1, -} - -#[derive(TdfSerialize)] -pub struct UpdateExtDataAttr { - #[tdf(tag = "FLGS")] - pub flags: u8, - #[tdf(tag = "ID")] - pub player_id: PlayerID, -} - -/// Structure for storing extended network data -#[derive(Debug, Copy, Clone, Default, Serialize, TdfSerialize, TdfDeserialize, TdfTyped)] -#[tdf(group)] -pub struct QosNetworkData { - /// Downstream bits per second - #[tdf(tag = "DBPS")] - pub dbps: u16, - /// Natt type - #[tdf(tag = "NATT")] - pub natt: NatType, - /// Upstream bits per second - #[tdf(tag = "UBPS")] - pub ubps: u16, -} - -// -#[derive(Debug, Default, Copy, Clone, Serialize, TdfDeserialize, TdfSerialize, TdfTyped)] -#[repr(u8)] -pub enum NatType { - Open = 0x0, - Moderate = 0x1, - Sequential = 0x2, - #[default] - Strict = 0x3, - #[tdf(default)] - Unknown = 0x4, -} - -#[derive(Default, Debug, Clone, TdfSerialize, TdfDeserialize, TdfTyped, Serialize)] -#[serde(untagged)] -pub enum NetworkAddress { - #[tdf(key = 0x2, tag = "VALU")] - AddressPair(IpPairAddress), - #[tdf(unset)] - Unset, - #[default] - #[tdf(default)] - Default, - // XboxClient = 0x0, - // XboxServer = 0x1, - // Pair = 0x2, - // IpAddress = 0x3, - // HostnameAddress = 0x4, -} - -/// Type alias for ports which are always u16 -pub type Port = u16; - -#[derive(Debug, Default, Clone, Serialize)] -pub struct NetData { - pub addr: NetworkAddress, - pub qos: QosNetworkData, - pub hardware_flags: u16, -} - -/// Pair of socket addresses -#[derive(Debug, Clone, TdfDeserialize, TdfSerialize, TdfTyped, Serialize)] -#[tdf(group)] -pub struct IpPairAddress { - #[tdf(tag = "EXIP")] - pub external: PairAddress, - #[tdf(tag = "INIP")] - pub internal: PairAddress, -} - -#[derive(Debug, Clone, TdfDeserialize, TdfSerialize, TdfTyped)] -#[tdf(group)] -pub struct PairAddress { - #[tdf(tag = "IP", into = u32)] - pub addr: Ipv4Addr, - #[tdf(tag = "PORT")] - pub port: u16, -} - -impl Serialize for PairAddress { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - // TODO: Dashboard likely incompatible now due to serialize change - let mut s = serializer.serialize_struct("PairAddress", 2)?; - s.serialize_field("address", &self.addr)?; - s.serialize_field("port", &self.port)?; - s.end() - } -} diff --git a/src/utils/signing.rs b/src/utils/signing.rs new file mode 100644 index 00000000..b66a80f1 --- /dev/null +++ b/src/utils/signing.rs @@ -0,0 +1,77 @@ +use argon2::password_hash::rand_core::{OsRng, RngCore}; +use log::{debug, error}; +use ring::hmac::{self, Key, Tag, HMAC_SHA256}; +use std::{io, path::Path}; +use tokio::{ + fs::{write, File}, + io::AsyncReadExt, +}; + +pub struct SigningKey(Key); + +impl AsRef for SigningKey { + fn as_ref(&self) -> &Key { + &self.0 + } +} + +impl SigningKey { + const KEY_LENGTH: usize = 64; + + /// Obtains the global signing key by reading it from a file + /// or generating a new one and saving that to a file + /// + /// Should only be used by the actual app, tests should + /// generate a new signing key + pub async fn global() -> Self { + // Path to the file containing the server secret value + let secret_path = Path::new("data/secret.bin"); + + if secret_path.exists() { + match Self::from_file(secret_path).await { + Ok(value) => return value, + Err(err) => { + error!("Failed to load existing secrets file: {}", err); + } + } + } + + debug!("Generating server secret key..."); + let (key, secret) = Self::generate(); + if let Err(err) = write(secret_path, &secret).await { + error!("Failed to save secrets file: {}", err); + } + + key + } + + #[inline] + fn new(secret: &[u8; Self::KEY_LENGTH]) -> Self { + Self(Key::new(HMAC_SHA256, secret)) + } + + #[inline] + pub fn sign(&self, data: &[u8]) -> Tag { + hmac::sign(&self.0, data) + } + + #[inline] + pub fn verify(&self, data: &[u8], tag: &[u8]) -> bool { + hmac::verify(&self.0, data, tag).is_ok() + } + + /// Generates a new signing key + pub fn generate() -> (Self, [u8; Self::KEY_LENGTH]) { + let mut secret = [0; Self::KEY_LENGTH]; + OsRng.fill_bytes(&mut secret); + (Self::new(&secret), secret) + } + + // Attempts to read a signing key from the provided file + async fn from_file(file: &Path) -> io::Result { + let mut secret = [0; Self::KEY_LENGTH]; + let mut file = File::open(file).await?; + file.read_exact(&mut secret).await?; + Ok(Self::new(&secret)) + } +} diff --git a/src/utils/types.rs b/src/utils/types.rs index d837d234..96551ff8 100644 --- a/src/utils/types.rs +++ b/src/utils/types.rs @@ -1,10 +1,4 @@ -use std::{future::Future, pin::Pin}; - /// Types for differentiating between fields pub type PlayerID = u32; pub type SessionID = u32; pub type GameID = u32; -pub type GameSlot = usize; - -/// Type for boxed futures -pub type BoxFuture<'a, T> = Pin + Send + 'a>>;