diff --git a/Cargo.lock b/Cargo.lock index a2f2a3ab4a..80e4d666b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1941,6 +1941,7 @@ dependencies = [ "si-frontend-types", "si-hash", "si-id", + "si-jwt-public-key", "si-layer-cache", "si-pkg", "si-runtime", @@ -2005,6 +2006,7 @@ dependencies = [ "si-data-nats", "si-data-pg", "si-events", + "si-jwt-public-key", "si-layer-cache", "si-pkg", "si-runtime", @@ -4305,6 +4307,7 @@ dependencies = [ "serde_json", "si-data-pg", "si-hash", + "si-jwt-public-key", "si-pkg", "si-posthog", "si-settings", @@ -6232,6 +6235,7 @@ dependencies = [ "si-data-spicedb", "si-events", "si-frontend-types", + "si-jwt-public-key", "si-layer-cache", "si-pkg", "si-posthog", @@ -6787,6 +6791,21 @@ dependencies = [ "ulid", ] +[[package]] +name = "si-jwt-public-key" +version = "0.1.0" +dependencies = [ + "base64 0.22.1", + "jwt-simple", + "remain", + "serde", + "si-events", + "si-std", + "telemetry", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "si-layer-cache" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e2fdbda080..f88b841f3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ members = [ "lib/si-frontend-types-rs", "lib/si-hash", "lib/si-id", + "lib/si-jwt-public-key", "lib/si-layer-cache", "lib/si-pkg", "lib/si-pool-noodle", @@ -93,7 +94,11 @@ async-recursion = "1.0.5" async-trait = "0.1.79" aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-sdk-firehose = "1.46.0" -axum = { version = "0.6.20", features = ["macros", "multipart", "ws"] } # todo: upgrade this alongside hyper/http/tokio-tungstenite +axum = { version = "0.6.20", features = [ + "macros", + "multipart", + "ws", +] } # todo: upgrade this alongside hyper/http/tokio-tungstenite base64 = "0.22.0" blake3 = "1.5.1" bollard = "0.16.1" @@ -103,13 +108,17 @@ ciborium = "0.2.2" clap = { version = "4.5.4", features = ["derive", "color", "env", "wrap_help"] } color-eyre = "0.6.3" colored = "2.1.0" -comfy-table = { version = "7.1.0", features = ["crossterm", "tty", "custom_styling"] } +comfy-table = { version = "7.1.0", features = [ + "crossterm", + "tty", + "custom_styling", +] } config = { version = "0.14.0", default-features = false, features = ["toml"] } console = "0.15.8" convert_case = "0.6.0" criterion = { version = "0.5.1", features = ["async_tokio"] } crossbeam-channel = "0.5.12" -crossbeam-queue = {version = "0.3.10"} +crossbeam-queue = { version = "0.3.10" } darling = "0.20.10" deadpool = { version = "0.10.0", features = ["rt_tokio_1"] } deadpool-postgres = "0.12.1" @@ -128,8 +137,15 @@ fs4 = "0.11.0" glob = "0.3.1" hex = "0.4.3" http = "0.2.12" # todo: upgrade this alongside hyper/axum/tokio-tungstenite/tower-http -hyper = { version = "0.14.28", features = ["client", "http1", "runtime", "server"] } # todo: upgrade this alongside http/axum/tokio-tungstenite/tower-http -hyperlocal = { version = "0.8.0", default-features = false, features = ["client"] } # todo: using the very latest of hyper client 1.x, we _may_ be able to phase this crate +hyper = { version = "0.14.28", features = [ + "client", + "http1", + "runtime", + "server", +] } # todo: upgrade this alongside http/axum/tokio-tungstenite/tower-http +hyperlocal = { version = "0.8.0", default-features = false, features = [ + "client", +] } # todo: using the very latest of hyper client 1.x, we _may_ be able to phase this crate iftree = "1.0.5" include_dir = "0.7.4" indexmap = { version = "2.2.6", features = ["serde", "std"] } @@ -137,14 +153,22 @@ indicatif = "0.17.8" indoc = "2.0.5" inquire = "0.7.4" itertools = "0.12.1" -jwt-simple = { version = "0.12.9", default-features = false, features = ["pure-rust"] } +jwt-simple = { version = "0.12.9", default-features = false, features = [ + "pure-rust", +] } krata-loopdev = "0.0.12" lazy_static = "1.4.0" manyhow = { version = "0.11.4", features = ["darling"] } mime_guess = { version = "=2.0.4" } # TODO(fnichol): 2.0.5 sets an env var in build.rs which needs to be tracked, required by reqwest miniz_oxide = { version = "0.7.2", features = ["simd"] } names = { version = "0.14.0", default-features = false } -nix = { version = "0.26.0", features = ["fs", "mount", "process", "signal", "user"] } +nix = { version = "0.26.0", features = [ + "fs", + "mount", + "process", + "signal", + "user", +] } nkeys = "0.4.0" num_cpus = "1.16.0" once_cell = "1.19.0" @@ -171,13 +195,25 @@ rand = "0.8.5" refinery = { version = "= 0.8.12", features = ["tokio-postgres"] } regex = "1.10.4" remain = "0.2.13" -reqwest = { version = "0.12.2", default-features = false, features = ["rustls-tls", "json", "multipart"] } +reqwest = { version = "0.12.2", default-features = false, features = [ + "rustls-tls", + "json", + "multipart", +] } ring = "=0.17.5" # Upgrading this is possible, but a pain, so we don't want to pick up every new minor version (see: https://github.com/facebook/buck2/commit/91af40b66960d003067c3d241595fb53d1e636c8) -rust-s3 = { version = "0.34.0-rc4", default-features = false, features = ["tokio-rustls-tls"] } +rust-s3 = { version = "0.34.0-rc4", default-features = false, features = [ + "tokio-rustls-tls", +] } rustls = { version = "0.23.18", default-features = false } # NOTE(nick,fletcher): rustls switched to "aws-lc-rs" as its default crypto provider, but we want ring (i.e. we disable the default feature for "aws-lc-rs") rustls-native-certs = "0.7.0" rustls-pemfile = { version = "2.1.1" } -sea-orm = { version = "1.1.0", features = ["sqlx-postgres", "runtime-tokio-rustls", "macros", "with-chrono", "debug-print"] } +sea-orm = { version = "1.1.0", features = [ + "sqlx-postgres", + "runtime-tokio-rustls", + "macros", + "with-chrono", + "debug-print", +] } self-replace = "1.3.7" serde = { version = "1.0.197", features = ["derive", "rc"] } serde-aux = "4.5.0" @@ -196,12 +232,18 @@ syn = { version = "2.0.55", features = ["full", "extra-traits"] } sysinfo = "0.32.0" tar = "0.4.40" tempfile = "3.10.1" -test-log = { version = "0.2.15", default-features = false, features = ["trace"] } +test-log = { version = "0.2.15", default-features = false, features = [ + "trace", +] } thiserror = "1.0.58" thread-priority = "1.1.0" time = "0.3.36" tokio = { version = "1.37.0", features = ["full"] } -tokio-postgres = { version = "0.7.10", features = ["runtime", "with-chrono-0_4", "with-serde_json-1"] } +tokio-postgres = { version = "0.7.10", features = [ + "runtime", + "with-chrono-0_4", + "with-serde_json-1", +] } tokio-postgres-rustls = { version = "0.12.0" } tokio-serde = { version = "0.9.0", features = ["json"] } tokio-stream = { version = "0.1.15", features = ["sync", "time"] } @@ -211,10 +253,20 @@ tokio-util = { version = "0.7.10", features = ["codec", "rt"] } tokio-vsock = { version = "0.4.0" } toml = { version = "0.8.12" } tower = { version = "0.4.13", features = ["full"] } -tower-http = { version = "0.4.4", features = ["compression-br", "compression-deflate", "compression-gzip", "cors", "trace"] } # todo: pinning back to 0.4.4, upgrade this alongside hyper/http/axum/tokio-tungstenite +tower-http = { version = "0.4.4", features = [ + "compression-br", + "compression-deflate", + "compression-gzip", + "cors", + "trace", +] } # todo: pinning back to 0.4.4, upgrade this alongside hyper/http/axum/tokio-tungstenite tracing = { version = "0.1.40" } tracing-opentelemetry = "0.27.0" -tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "std"] } +tracing-subscriber = { version = "0.3.18", features = [ + "env-filter", + "json", + "std", +] } tracing-tunnel = "0.1.0" trybuild = { version = "1.0.99", features = ["diff"] } tryhard = "0.5.1" diff --git a/bin/auth-api/README.md b/bin/auth-api/README.md index eb05f0155c..44658051e7 100644 --- a/bin/auth-api/README.md +++ b/bin/auth-api/README.md @@ -10,6 +10,13 @@ Use `pnpx prisma` to run prisma commands locally. For example ### JWT Signing Key +### ES256 + +- `ssh-keygen -t ecdsa -b 256 -m PEM -f jwtES256.key` +- `openssl ec -in jwtES256.key -pubout -outform PEM -out jwtES256.key.pub` + +### RS256 (deprecated) + - `ssh-keygen -t rsa -b 4096 -m PEM -f jwtRS256.key` - `openssl rsa -in jwtRS256.key -pubout -outform PEM -out jwtRS256.key.pub` diff --git a/bin/auth-api/src/lib/jwt.ts b/bin/auth-api/src/lib/jwt.ts index 91c2f1d877..6e7997e5f0 100644 --- a/bin/auth-api/src/lib/jwt.ts +++ b/bin/auth-api/src/lib/jwt.ts @@ -1,5 +1,11 @@ /* instructions to generate JWT signing key + run `ssh-keygen -t ecdsa -b 256 -m PEM -f jwtES256.key` +- run `openssl ec -in jwtES256.key -pubout -outform PEM -out jwtES256.key.pub` +- `cat jwtES256.key` +- `cat jwtES256.key.pub` + +For RS256: (deprecated) - run `ssh-keygen -t rsa -b 4096 -m PEM -f jwtRS256.key` # Don't add passphrase - run `openssl rsa -in jwtRS256.key -pubout -outform PEM -out jwtRS256.key.pub` - `cat jwtRS256.key` @@ -9,34 +15,92 @@ instructions to generate JWT signing key import fs from "fs"; import JWT from "jsonwebtoken"; -// load private and public key from either env var or paths set in config -// keys in the repo are also used by SDF to verify jwt is signed correctly and in tests to create/sign jwts -let _JWT_PRIVATE_KEY = process.env.JWT_PRIVATE_KEY; -if (!_JWT_PRIVATE_KEY && process.env.JWT_PRIVATE_KEY_PATH) { - // path is relative to .env file - _JWT_PRIVATE_KEY = fs.readFileSync(`${process.env.JWT_PRIVATE_KEY_PATH}`, 'utf-8'); -} -let _JWT_PUBLIC_KEY = process.env.JWT_PUBLIC_KEY; -if (!_JWT_PUBLIC_KEY && process.env.JWT_PUBLIC_KEY_PATH) { - // path is relative to .env file - _JWT_PUBLIC_KEY = fs.readFileSync(`${process.env.JWT_PUBLIC_KEY_PATH}`, 'utf-8'); -} -if (!_JWT_PRIVATE_KEY) throw new Error('Missing JWT signing private key'); -if (!_JWT_PUBLIC_KEY) throw new Error('Missing JWT signing public key'); +const DEFAULT_ALGO = "RS256"; + +type Algo = "RS256" | "ES256"; + +const jwtAlgo = (algo?: string): Algo => { + switch (algo) { + case "RS256": + case "ES256": + return algo; + default: + return DEFAULT_ALGO; + } +}; + +const keyEnvPaths = { + primary: { + private: "JWT_PRIVATE_KEY", + privatePath: "JWT_PRIVATE_KEY_PATH", + public: "JWT_PUBLIC_KEY", + publicPath: "JWT_PUBLIC_KEY_PATH", + algo: "JWT_ALGO", + }, + secondary: { + private: "JWT_2ND_PRIVATE_KEY", + privatePath: "JWT_2ND_PRIVATE_KEY_PATH", + public: "JWT_2ND_PUBLIC_KEY", + publicPath: "JWT_2ND_PUBLIC_KEY_PATH", + algo: "JWT_2ND_ALGO", + }, +}; + +// load private and public keys from either env var or paths set in config keys +// in the repo are also used by SDF to verify jwt is signed correctly and in +// tests to create/sign jwts + +const prepareKeys = (which: "primary" | "secondary"): { privKey?: string, pubKey?: string, algo: Algo } => { + const privateLiteral = process.env[keyEnvPaths[which].private]; + const privatePath = process.env[keyEnvPaths[which].privatePath]; + + let privKey = privateLiteral ?? (privatePath ? fs.readFileSync(privatePath, 'utf-8') : undefined); + if (privKey) { + privKey = privKey.replace(/\\n/g, '\n'); + } + + const publicLiteral = process.env[keyEnvPaths[which].public]; + const publicPath = process.env[keyEnvPaths[which].publicPath]; + + let pubKey = publicLiteral ?? (publicPath ? fs.readFileSync(publicPath, 'utf-8') : undefined); + if (pubKey) { + pubKey = pubKey.replace(/\\n/g, '\n'); + } + + const algo = jwtAlgo(process.env[keyEnvPaths[which].algo]); + + return { + privKey, + pubKey, + algo, + }; +}; + +const { privKey: primaryPrivKey, pubKey: primaryPubKey, algo } = prepareKeys("primary"); +const { pubKey: secondaryPubKey } = prepareKeys("secondary"); -_JWT_PRIVATE_KEY = _JWT_PRIVATE_KEY.replace(/\\n/g, '\n'); -_JWT_PUBLIC_KEY = _JWT_PUBLIC_KEY.replace(/\\n/g, '\n'); +if (!primaryPrivKey) throw new Error('Missing JWT signing private key'); +if (!primaryPubKey) throw new Error('Missing JWT signing public key'); -export const JWT_PUBLIC_KEY = _JWT_PUBLIC_KEY; +export const JWT_PUBLIC_KEY = primaryPubKey; +export const JWT_2ND_PUBLIC_KEY = secondaryPubKey; export function createJWT( payload: Record, options?: Omit, ) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return JWT.sign(payload, _JWT_PRIVATE_KEY!, { algorithm: "RS256", ...options }); + return JWT.sign(payload, primaryPrivKey!, { algorithm: algo, ...options }); } export function verifyJWT(token: string) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return JWT.verify(token, _JWT_PUBLIC_KEY!); + try { + return JWT.verify(token, primaryPubKey!); + } catch (err) { + if (secondaryPubKey) { + return JWT.verify(token, secondaryPubKey); + } else { + throw err; + } + } } diff --git a/bin/module-index/BUCK b/bin/module-index/BUCK index 3f25618d7e..02ea393f81 100644 --- a/bin/module-index/BUCK +++ b/bin/module-index/BUCK @@ -10,6 +10,7 @@ rust_binary( deps = [ "//lib/module-index-server:module-index-server", "//lib/si-std:si-std", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/telemetry-application-rs:telemetry-application", "//third-party/rust:clap", "//third-party/rust:color-eyre", diff --git a/bin/module-index/src/args.rs b/bin/module-index/src/args.rs index 6a60d6e6eb..57cd78e1b9 100644 --- a/bin/module-index/src/args.rs +++ b/bin/module-index/src/args.rs @@ -118,6 +118,15 @@ pub(crate) struct Args { /// The path to the JWT public signing key #[arg(long, env)] pub(crate) jwt_public_key: Option, + + #[arg(long, env)] + pub(crate) jwt_public_key_algo: Option, + + #[arg(long, env)] + pub(crate) jwt_secondary_public_key: Option, + + #[arg(long, env)] + pub(crate) jwt_secondary_public_key_algo: Option, // /// Database migration mode on startup // #[arg(long, value_parser = PossibleValuesParser::new(MigrationMode::variants()))] } diff --git a/bin/module-index/src/main.rs b/bin/module-index/src/main.rs index 0f4d1f7ff1..bdf00cae97 100644 --- a/bin/module-index/src/main.rs +++ b/bin/module-index/src/main.rs @@ -54,8 +54,7 @@ async fn async_main() -> Result<()> { let config = Config::try_from(args)?; - let jwt_public_signing_key = - Server::load_jwt_public_signing_key(config.jwt_signing_public_key_path()).await?; + let jwt_public_signing_key = Server::load_jwt_public_signing_key(&config).await?; // our pg pool works for migrations (refinery) but doesnt work for SeaORM :( // so we set up both connections for now... Would like to clean this up diff --git a/bin/sdf/src/args.rs b/bin/sdf/src/args.rs index 131e32901a..d9562de471 100644 --- a/bin/sdf/src/args.rs +++ b/bin/sdf/src/args.rs @@ -125,7 +125,19 @@ pub(crate) struct Args { /// jwt public signing key as a base64 string #[arg(long)] - pub(crate) jwt_public_signing_key_base64: Option, + pub(crate) jwt_public_signing_key_base64: Option, + + /// jwt public signing key algorithm (ES256 or RS256) + #[arg(long)] + pub(crate) jwt_public_signing_key_algo: Option, + + /// jwt secondary public signing key as a base64 string + #[arg(long)] + pub(crate) jwt_secondary_public_signing_key_base64: Option, + + /// jwt secondary public signing key algorithm (ES256 or RS256) + #[arg(long)] + pub(crate) jwt_secondary_public_signing_key_algo: Option, /// The path at which the layer db cache is created/used on disk [e.g. /banana/] #[arg(long)] @@ -338,9 +350,21 @@ impl TryFrom for Config { base64.to_string(), ); } + if let Some(jwt) = args.jwt_public_signing_key_base64 { - config_map.set("jwt_signing_public_key.key_base64", jwt.to_string()); + config_map.set("jwt_signing_public_key.key_base64", jwt); + } + if let Some(algo) = args.jwt_public_signing_key_algo { + config_map.set("jwt_signing_public_key.algo", algo); + } + + if let Some(jwt) = args.jwt_secondary_public_signing_key_base64 { + config_map.set("jwt_secondary_signing_public_key.key_base64", jwt); + } + if let Some(algo) = args.jwt_secondary_public_signing_key_algo { + config_map.set("jwt_secondary_signing_public_key.algo", algo); } + if let Some(layer_cache_disk_path) = args.layer_db_disk_path { config_map.set("layer_db_config.disk_path", layer_cache_disk_path); } diff --git a/component/init/configs/service.toml b/component/init/configs/service.toml index 93214277f1..2dc64d417c 100644 --- a/component/init/configs/service.toml +++ b/component/init/configs/service.toml @@ -18,6 +18,7 @@ runtime_strategy = "LocalFirecracker" [jwt_signing_public_key] key_base64 = "$SI_JWT_KEY_BASE64" +algo = "RS256" [nats] creds = """ diff --git a/lib/dal-test/BUCK b/lib/dal-test/BUCK index cddc24e614..dc76191df2 100644 --- a/lib/dal-test/BUCK +++ b/lib/dal-test/BUCK @@ -14,6 +14,7 @@ rust_library( "//lib/si-data-nats:si-data-nats", "//lib/si-data-pg:si-data-pg", "//lib/si-events-rs:si-events", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-layer-cache:si-layer-cache", "//lib/si-pkg:si-pkg", "//lib/si-runtime-rs:si-runtime", diff --git a/lib/dal-test/Cargo.toml b/lib/dal-test/Cargo.toml index b4ea484ad5..0ec5ea5c40 100644 --- a/lib/dal-test/Cargo.toml +++ b/lib/dal-test/Cargo.toml @@ -20,6 +20,7 @@ si-crypto = { path = "../../lib/si-crypto" } si-data-nats = { path = "../../lib/si-data-nats" } si-data-pg = { path = "../../lib/si-data-pg" } si-events = { path = "../../lib/si-events-rs" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-layer-cache = { path = "../../lib/si-layer-cache" } si-pkg = { path = "../../lib/si-pkg" } si-runtime = { path = "../../lib/si-runtime-rs" } @@ -50,5 +51,5 @@ tokio = { workspace = true } tokio-util = { workspace = true } tracing-opentelemetry = { workspace = true } tracing-subscriber = { workspace = true } -ulid = { workspace = true } +ulid = { workspace = true } uuid = { workspace = true } diff --git a/lib/dal-test/src/lib.rs b/lib/dal-test/src/lib.rs index a1338d116b..862ed09e23 100644 --- a/lib/dal-test/src/lib.rs +++ b/lib/dal-test/src/lib.rs @@ -30,6 +30,7 @@ use std::{ env, fmt, future::IntoFuture, path::{Path, PathBuf}, + str::FromStr, sync::{Arc, Once}, }; @@ -39,8 +40,7 @@ use dal::{ builtins::func, feature_flags::FeatureFlagService, job::processor::{JobQueueProcessor, NatsProcessor}, - DalContext, DalLayerDb, JetstreamStreams, JwtPublicSigningKey, ModelResult, ServicesContext, - Workspace, + DalContext, DalLayerDb, JetstreamStreams, ModelResult, ServicesContext, Workspace, }; use derive_builder::Builder; use jwt_simple::prelude::RS256KeyPair; @@ -51,9 +51,10 @@ use si_crypto::{ }; use si_data_nats::{jetstream, NatsClient, NatsConfig}; use si_data_pg::{PgPool, PgPoolConfig}; +use si_jwt_public_key::{JwtAlgo, JwtConfig, JwtPublicSigningKeyChain}; use si_layer_cache::hybrid_cache::CacheConfig; use si_runtime::DedicatedExecutor; -use si_std::ResultExt; +use si_std::{CanonicalFile, ResultExt}; use telemetry::prelude::*; use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; @@ -170,6 +171,7 @@ pub struct Config { module_index_url: String, veritech_encryption_key_path: String, jwt_signing_public_key_path: String, + jwt_signing_public_key_algo: JwtAlgo, jwt_signing_private_key_path: String, postgres_key_path: String, #[builder(default)] @@ -619,13 +621,22 @@ pub fn random_identifier_string() -> String { } /// Returns a JWT public signing key, which is used to verify claims. -pub async fn jwt_public_signing_key() -> Result { - let jwt_signing_public_key_path = { +pub async fn jwt_public_signing_key() -> Result { + let jwt_config = { let context_builder = TEST_CONTEXT_BUILDER.lock().await; let config = context_builder.config()?; - config.jwt_signing_public_key_path.clone() + let key_file = Some(CanonicalFile::from_str( + &config.jwt_signing_public_key_path, + )?); + + JwtConfig { + key_file, + key_base64: None, + algo: config.jwt_signing_public_key_algo, + } }; - let key = JwtPublicSigningKey::load(&jwt_signing_public_key_path).await?; + + let key = JwtPublicSigningKeyChain::from_config(jwt_config, None).await?; Ok(key) } diff --git a/lib/dal/BUCK b/lib/dal/BUCK index 6ed5269456..23dc900301 100644 --- a/lib/dal/BUCK +++ b/lib/dal/BUCK @@ -24,6 +24,7 @@ rust_library( "//lib/si-frontend-types-rs:si-frontend-types", "//lib/si-hash:si-hash", "//lib/si-id:si-id", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-layer-cache:si-layer-cache", "//lib/si-pkg:si-pkg", "//lib/si-runtime-rs:si-runtime", @@ -122,7 +123,7 @@ rust_test( ], crate_root = "tests/integration.rs", srcs = glob([ - "tests/**/*.rs", + "tests/**/*.rs", "tests/integration_test/external/ignition/*.ign", ]), env = { diff --git a/lib/dal/Cargo.toml b/lib/dal/Cargo.toml index 3b0141d22d..ab975feb1e 100644 --- a/lib/dal/Cargo.toml +++ b/lib/dal/Cargo.toml @@ -25,6 +25,7 @@ si-data-pg = { path = "../../lib/si-data-pg" } si-events = { path = "../../lib/si-events-rs" } si-frontend-types = { path = "../../lib/si-frontend-types-rs" } si-hash = { path = "../../lib/si-hash" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-id = { path = "../../lib/si-id" } si-layer-cache = { path = "../../lib/si-layer-cache" } si-pkg = { path = "../../lib/si-pkg" } diff --git a/lib/dal/src/jwt_key.rs b/lib/dal/src/jwt_key.rs deleted file mode 100644 index e5e4f3ea68..0000000000 --- a/lib/dal/src/jwt_key.rs +++ /dev/null @@ -1,153 +0,0 @@ -use si_std::CanonicalFile; -use std::{io::Cursor, path::Path, pin::Pin, sync::Arc}; - -use base64::{engine::general_purpose, Engine}; -use jwt_simple::{ - algorithms::RS256PublicKey, - prelude::{JWTClaims, RSAPublicKeyLike}, -}; -use serde::{Deserialize, Serialize}; -use si_data_pg::{PgError, PgPoolError}; -use telemetry::prelude::*; -use thiserror::Error; -use tokio::{ - fs, - io::{AsyncRead, AsyncReadExt}, - task::JoinError, -}; - -use crate::{TransactionsError, UserClaim, UserPk, WorkspacePk}; - -#[remain::sorted] -#[derive(Error, Debug)] -pub enum JwtKeyError { - #[error("bad nonce bytes")] - BadNonce, - #[error("failed to decode base64 string: {0}")] - Base64Decode(#[from] base64::DecodeError), - #[error("invalid bearer token")] - BearerToken, - #[error("failed to decrypt secret data")] - Decrypt, - #[error("error creating jwt from config")] - FromConfig, - #[error("error generating new keypair")] - GenerateKeyPair, - #[error("io error: {0}")] - Io(#[from] std::io::Error), - #[error("failed to load jwt encryption key from bytes")] - JwtEncryptionKeyParse, - #[error("failure to build signing key from pem: {0}")] - KeyFromPem(String), - #[error("failure to extract metadata from bearer token: {0}")] - Metadata(String), - #[error("no signing keys - bad news for you!")] - NoKeys, - #[error("pg error: {0}")] - Pg(#[from] PgError), - #[error("pg pool error: {0}")] - PgPool(#[from] PgPoolError), - #[error("{0}")] - TaskJoin(#[from] JoinError), - #[error("failed to convert into PEM format")] - ToPem, - #[error("transactions error: {0}")] - Transactions(#[from] TransactionsError), - #[error(transparent)] - UlidDecode(#[from] ulid::DecodeError), - #[error("failed to build string from utf8: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - #[error("failure to verify token: {0}")] - Verify(String), -} - -pub type JwtKeyResult = Result; - -#[derive(Clone, Debug, Default, Deserialize, Serialize)] -pub struct JwtConfig { - pub key_file: Option, - pub key_base64: Option, -} - -#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] -pub struct SiClaims { - pub user_pk: UserPk, - pub workspace_pk: WorkspacePk, -} - -#[derive(Clone, Debug)] -pub struct JwtPublicSigningKey { - inner: Arc, -} - -impl JwtPublicSigningKey { - pub async fn from_config(config: JwtConfig) -> JwtKeyResult { - match (config.key_file, config.key_base64) { - (Some(path), None) => Self::load(path).await, - (None, Some(b64_string)) => Self::decode(b64_string).await, - _ => Err(JwtKeyError::FromConfig), - } - } - - #[instrument(level = "debug", skip_all)] - pub async fn load(path: impl AsRef) -> JwtKeyResult { - trace!( - path = path.as_ref().to_string_lossy().as_ref(), - "loading jwt public signing key" - ); - let mut file = fs::File::open(path).await?; - Self::from_reader(Pin::new(&mut file)).await - } - - #[instrument(level = "debug", skip_all)] - pub async fn decode(key_string: String) -> JwtKeyResult { - let buf = general_purpose::STANDARD.decode(key_string)?; - - Self::from_reader(Pin::new(&mut Cursor::new(&buf))).await - } - - async fn from_reader(mut reader: Pin<&mut impl AsyncRead>) -> JwtKeyResult { - let mut public_key_string = String::new(); - reader.read_to_string(&mut public_key_string).await?; - - let inner = tokio::task::spawn_blocking(move || { - RS256PublicKey::from_pem(&public_key_string) - .map_err(|err| JwtKeyError::KeyFromPem(format!("{err}"))) - }) - .instrument(trace_span!( - "from_pem", - code.namespace = "jwt_simple::algorithms::RS256PublicKey" - )) - .await??; - - Ok(Self { - inner: Arc::new(inner), - }) - } -} - -#[instrument(level = "debug", skip_all)] -pub async fn validate_bearer_token( - public_key: JwtPublicSigningKey, - bearer_token: impl AsRef, -) -> JwtKeyResult> { - let bearer_token = bearer_token.as_ref(); - let token = if let Some(token) = bearer_token.strip_prefix("Bearer ") { - token.to_string() - } else { - return Err(JwtKeyError::BearerToken); - }; - - let claims = tokio::task::spawn_blocking(move || { - public_key - .inner - .verify_token::(&token, None) - .map_err(|err| JwtKeyError::Verify(format!("{err}"))) - }) - .instrument(trace_span!( - "verfy_token", - code.namespace = "jwt_simple::algorithms::RSAPublicKeyLike" - )) - .await??; - Ok(claims) -} diff --git a/lib/dal/src/lib.rs b/lib/dal/src/lib.rs index dd3e37cec2..d206e6f703 100644 --- a/lib/dal/src/lib.rs +++ b/lib/dal/src/lib.rs @@ -40,7 +40,6 @@ pub mod history_event; pub mod input_sources; pub mod jetstream_streams; pub mod job; -pub mod jwt_key; pub mod key_pair; pub mod label_list; pub mod layer_db_types; @@ -94,7 +93,6 @@ pub use func::{ pub use history_event::{HistoryActor, HistoryEvent, HistoryEventError}; pub use jetstream_streams::{JetstreamStreams, JetstreamStreamsError}; pub use job::processor::{JobQueueProcessor, NatsProcessor}; -pub use jwt_key::JwtPublicSigningKey; pub use key_pair::{KeyPair, KeyPairError, KeyPairResult, PublicKey}; pub use label_list::{LabelEntry, LabelList, LabelListError}; pub use prop::{Prop, PropId, PropKind}; diff --git a/lib/dal/src/user.rs b/lib/dal/src/user.rs index 793b70d656..da2c99c689 100644 --- a/lib/dal/src/user.rs +++ b/lib/dal/src/user.rs @@ -2,14 +2,17 @@ use serde::{Deserialize, Serialize}; use si_data_nats::NatsError; use si_data_pg::PgError; use si_events::ViewId; +use si_jwt_public_key::{ + validate_bearer_token, JwtPublicSigningKeyChain, JwtPublicSigningKeyError, SiJwtClaims, +}; use telemetry::prelude::*; use thiserror::Error; use tokio::task::JoinError; use crate::ws_event::{WsEvent, WsEventResult, WsPayload}; use crate::{ - jwt_key::JwtKeyError, standard_model_accessor_ro, ChangeSetId, DalContext, HistoryEvent, - HistoryEventError, JwtPublicSigningKey, Tenancy, Timestamp, TransactionsError, WorkspacePk, + standard_model_accessor_ro, ChangeSetId, DalContext, HistoryEvent, HistoryEventError, Tenancy, + Timestamp, TransactionsError, WorkspacePk, }; const USER_GET_BY_PK: &str = include_str!("queries/user/get_by_pk.sql"); @@ -23,7 +26,7 @@ pub enum UserError { #[error("failed to join long lived async task; bug!")] Join(#[from] JoinError), #[error(transparent)] - JwtKey(#[from] JwtKeyError), + JwtKey(#[from] JwtPublicSigningKeyError), #[error("nats txn error: {0}")] Nats(#[from] NatsError), #[error("user not found in tenancy: {0} {1:?}")] @@ -212,6 +215,15 @@ pub struct UserClaim { pub workspace_pk: WorkspacePk, } +impl From for UserClaim { + fn from(value: SiJwtClaims) -> Self { + Self { + user_pk: value.user_pk.into_raw_id().into(), + workspace_pk: value.workspace_pk.into_raw_id().into(), + } + } +} + impl UserClaim { pub fn new(user_pk: UserPk, workspace_pk: WorkspacePk) -> Self { UserClaim { @@ -221,11 +233,11 @@ impl UserClaim { } pub async fn from_bearer_token( - public_key: JwtPublicSigningKey, + public_key: JwtPublicSigningKeyChain, token: impl AsRef, ) -> UserResult { - let claims = crate::jwt_key::validate_bearer_token(public_key, &token).await?; - Ok(claims.custom) + let claims = validate_bearer_token(public_key, token).await?; + Ok(claims.custom.into()) } } diff --git a/lib/module-index-server/BUCK b/lib/module-index-server/BUCK index 6cbbb2b794..ea2c99e0f0 100644 --- a/lib/module-index-server/BUCK +++ b/lib/module-index-server/BUCK @@ -9,6 +9,7 @@ rust_library( "//lib/si-data-pg:si-data-pg", "//lib/si-pkg:si-pkg", "//lib/si-hash:si-hash", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-posthog-rs:si-posthog", "//lib/si-settings:si-settings", "//lib/si-std:si-std", diff --git a/lib/module-index-server/Cargo.toml b/lib/module-index-server/Cargo.toml index b91ab3d00f..4d58c3931b 100644 --- a/lib/module-index-server/Cargo.toml +++ b/lib/module-index-server/Cargo.toml @@ -14,6 +14,7 @@ buck2-resources = { path = "../../lib/buck2-resources" } module-index-types = { path = "../../lib/module-index-types" } si-data-pg = { path = "../../lib/si-data-pg" } si-hash = { path = "../../lib/si-hash" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-pkg = { path = "../../lib/si-pkg" } si-posthog = { path = "../../lib/si-posthog-rs" } si-settings = { path = "../../lib/si-settings" } diff --git a/lib/module-index-server/src/app_state.rs b/lib/module-index-server/src/app_state.rs index 53aeb60262..fc7a0b4ac6 100644 --- a/lib/module-index-server/src/app_state.rs +++ b/lib/module-index-server/src/app_state.rs @@ -4,11 +4,12 @@ use std::sync::Arc; use axum::extract::FromRef; use s3::creds::Credentials as AwsCredentials; use sea_orm::DatabaseConnection; +use si_jwt_public_key::JwtPublicSigningKeyChain; pub use si_posthog::PosthogClient; use tokio::sync::{mpsc, Mutex}; -use crate::{jwt_key::JwtPublicSigningKey, s3::S3Config}; +use crate::s3::S3Config; #[remain::sorted] #[derive(Debug, Eq, PartialEq)] @@ -18,7 +19,7 @@ pub enum ShutdownSource {} pub struct AppState { /// A PostgreSQL connection pool. pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, aws_creds: AwsCredentials, s3_config: S3Config, @@ -34,7 +35,7 @@ impl AppState { #[allow(clippy::too_many_arguments)] pub fn new( pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, aws_creds: AwsCredentials, s3_config: S3Config, @@ -42,7 +43,7 @@ impl AppState { ) -> Self { Self { pg_pool, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, aws_creds, s3_config, @@ -57,8 +58,8 @@ impl AppState { } /// Gets a reference to the public key used to sign the JWT - pub fn jwt_public_signing_key(&self) -> &JwtPublicSigningKey { - &self.jwt_public_signing_key + pub fn jwt_public_signing_key(&self) -> &JwtPublicSigningKeyChain { + &self.jwt_public_signing_key_chain } /// Gets a reference to the Posthog client. diff --git a/lib/module-index-server/src/config.rs b/lib/module-index-server/src/config.rs index f6cc1cd53a..f228aab15b 100644 --- a/lib/module-index-server/src/config.rs +++ b/lib/module-index-server/src/config.rs @@ -4,6 +4,7 @@ use buck2_resources::Buck2Resources; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use si_data_pg::PgPoolConfig; +use si_jwt_public_key::JwtAlgo; use si_posthog::PosthogConfig; use si_std::{CanonicalFile, CanonicalFileError}; use telemetry::prelude::*; @@ -51,6 +52,12 @@ pub struct Config { instance_id: String, jwt_signing_public_key_path: CanonicalFile, + jwt_signing_public_key_algo: JwtAlgo, + + #[builder(default)] + jwt_secondary_signing_public_key_path: Option, + #[builder(default)] + jwt_secondary_signing_public_key_algo: Option, #[builder(default = "PosthogConfig::default()")] posthog: PosthogConfig, @@ -81,8 +88,25 @@ impl Config { /// Gets a reference to the config's jwt signing public key path. #[must_use] - pub fn jwt_signing_public_key_path(&self) -> &Path { - self.jwt_signing_public_key_path.as_path() + pub fn jwt_signing_public_key_path(&self) -> &CanonicalFile { + &self.jwt_signing_public_key_path + } + + /// Gets a reference to the config's jwt signing public key path. + #[must_use] + pub fn jwt_signing_public_key_algo(&self) -> JwtAlgo { + self.jwt_signing_public_key_algo + } + + /// Gets a reference to the config's jwt secondary signing public key path. + #[must_use] + pub fn jwt_secondary_signing_public_key_path(&self) -> Option<&CanonicalFile> { + self.jwt_secondary_signing_public_key_path.as_ref() + } + + #[must_use] + pub fn jwt_secondary_signing_public_key_algo(&self) -> Option { + self.jwt_secondary_signing_public_key_algo } /// Gets a reference to the config's posthog config. @@ -108,6 +132,12 @@ pub struct ConfigFile { instance_id: String, #[serde(default = "default_jwt_signing_public_key_path")] pub jwt_signing_public_key_path: String, + #[serde(default = "default_jwt_signing_public_key_algo")] + pub jwt_signing_public_key_algo: JwtAlgo, + #[serde(default)] + pub jwt_secondary_signing_public_key_path: Option, + #[serde(default)] + pub jwt_secondary_signing_public_key_algo: Option, #[serde(default)] pub posthog: PosthogConfig, #[serde(default)] @@ -127,6 +157,9 @@ impl Default for ConfigFile { socket_addr: get_default_socket_addr(), instance_id: random_instance_id(), jwt_signing_public_key_path: default_jwt_signing_public_key_path(), + jwt_signing_public_key_algo: default_jwt_signing_public_key_algo(), + jwt_secondary_signing_public_key_path: None, + jwt_secondary_signing_public_key_algo: None, posthog: Default::default(), s3: Default::default(), } @@ -161,6 +194,10 @@ fn default_jwt_signing_public_key_path() -> String { "/run/sdf/jwt_signing_public_key.pem".to_string() } +fn default_jwt_signing_public_key_algo() -> JwtAlgo { + JwtAlgo::RS256 +} + #[allow(clippy::disallowed_methods)] // Used to determine if running in development pub fn detect_and_configure_development(config: &mut ConfigFile) -> Result<()> { if env::var("BUCK_RUN_BUILD_ID").is_ok() || env::var("BUCK_BUILD_ID").is_ok() { diff --git a/lib/module-index-server/src/extract.rs b/lib/module-index-server/src/extract.rs index 5e753e6e93..0cc766f2e1 100644 --- a/lib/module-index-server/src/extract.rs +++ b/lib/module-index-server/src/extract.rs @@ -4,12 +4,9 @@ use axum::{async_trait, extract::FromRequestParts, http::request::Parts, Json}; use hyper::StatusCode; use s3::{Bucket as S3Bucket, Region as AwsRegion}; use sea_orm::{DatabaseTransaction, TransactionTrait}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use ulid::Ulid; +use si_jwt_public_key::SiJwtClaims; use super::app_state::AppState; -use crate::jwt_key::{JwtKeyError, JwtPublicSigningKey}; pub struct ExtractedS3Bucket(pub S3Bucket); @@ -80,36 +77,8 @@ impl FromRequestParts for DbConnection { } } -pub type UserPk = Ulid; -pub type WorkspacePk = Ulid; - -#[derive(Deserialize, Serialize, Debug, Clone, Copy)] -pub struct UserClaim { - pub user_pk: UserPk, - pub workspace_pk: WorkspacePk, -} - -#[remain::sorted] -#[derive(Error, Debug)] -pub enum AuthError { - #[error(transparent)] - JwtKey(#[from] JwtKeyError), -} - -pub type AuthResult = Result; - -impl UserClaim { - pub async fn from_bearer_token( - public_key: JwtPublicSigningKey, - token: impl AsRef, - ) -> AuthResult { - let claims = crate::jwt_key::validate_bearer_token(public_key, &token).await?; - Ok(claims.custom) - } -} - pub struct Authorization { - pub user_claim: UserClaim, + pub user_claim: SiJwtClaims, pub auth_token: String, } @@ -121,7 +90,7 @@ impl FromRequestParts for Authorization { parts: &mut Parts, state: &AppState, ) -> Result { - let jwt_public_signing_key = state.jwt_public_signing_key().clone(); + let jwt_public_signing_key = state.jwt_public_signing_key(); let headers = &parts.headers; let authorization_header_value = headers @@ -130,9 +99,11 @@ impl FromRequestParts for Authorization { let auth_token = authorization_header_value .to_str() .map_err(internal_error)?; - let user_claim = UserClaim::from_bearer_token(jwt_public_signing_key, auth_token) - .await - .map_err(|_| unauthorized_error())?; + let user_claim = + si_jwt_public_key::validate_bearer_token(jwt_public_signing_key.clone(), &auth_token) + .await + .map_err(|_| unauthorized_error())? + .custom; Ok(Self { user_claim, diff --git a/lib/module-index-server/src/jwt_key.rs b/lib/module-index-server/src/jwt_key.rs deleted file mode 100644 index d2816136de..0000000000 --- a/lib/module-index-server/src/jwt_key.rs +++ /dev/null @@ -1,110 +0,0 @@ -use jwt_simple::prelude::{JWTClaims, RS256PublicKey, RSAPublicKeyLike}; -use std::{path::Path, pin::Pin, sync::Arc}; -use telemetry::prelude::*; -use thiserror::Error; - -use tokio::{ - fs, - io::{AsyncRead, AsyncReadExt}, - task::JoinError, -}; - -use crate::extract::UserClaim; - -#[derive(Clone, Debug)] -pub struct JwtPublicSigningKey { - inner: Arc, -} -impl JwtPublicSigningKey { - #[instrument(level = "debug", skip_all)] - pub async fn load(path: impl AsRef) -> JwtKeyResult { - trace!( - path = path.as_ref().to_string_lossy().as_ref(), - "loading jwt public signing key" - ); - let mut file = fs::File::open(path).await?; - Self::from_reader(Pin::new(&mut file)).await - } - - async fn from_reader(mut reader: Pin<&mut impl AsyncRead>) -> JwtKeyResult { - let mut public_key_string = String::new(); - reader.read_to_string(&mut public_key_string).await?; - - let inner = tokio::task::spawn_blocking(move || { - RS256PublicKey::from_pem(&public_key_string) - .map_err(|err| JwtKeyError::KeyFromPem(format!("{err}"))) - }) - .instrument(trace_span!( - "from_pem", - code.namespace = "jwt_simple::algorithms::RS256PublicKey" - )) - .await??; - - Ok(Self { - inner: Arc::new(inner), - }) - } -} - -pub type JwtKeyResult = Result; - -#[remain::sorted] -#[derive(Error, Debug)] -pub enum JwtKeyError { - #[error("bad nonce bytes")] - BadNonce, - #[error("failed to decode base64 string: {0}")] - Base64Decode(#[from] base64::DecodeError), - #[error("invalid bearer token")] - BearerToken, - #[error("failed to decrypt secret data")] - Decrypt, - #[error("error generating new keypair")] - GenerateKeyPair, - #[error("io error: {0}")] - Io(#[from] std::io::Error), - #[error("failed to load jwt encryption key from bytes")] - JwtEncryptionKeyParse, - #[error("failure to build signing key from pem: {0}")] - KeyFromPem(String), - #[error("failure to extract metadata from bearer token: {0}")] - Metadata(String), - #[error("no signing keys - bad news for you!")] - NoKeys, - #[error("{0}")] - TaskJoin(#[from] JoinError), - #[error("failed to convert into PEM format")] - ToPem, - #[error(transparent)] - UlidDecode(#[from] ulid::DecodeError), - #[error("failed to build string from utf8: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - #[error("failure to verify token: {0}")] - Verify(String), -} - -#[instrument(level = "debug", skip_all)] -pub async fn validate_bearer_token( - public_key: JwtPublicSigningKey, - bearer_token: impl AsRef, -) -> JwtKeyResult> { - let bearer_token = bearer_token.as_ref(); - let token = if let Some(token) = bearer_token.strip_prefix("Bearer ") { - token.to_string() - } else { - return Err(JwtKeyError::BearerToken); - }; - - let claims = tokio::task::spawn_blocking(move || { - public_key - .inner - .verify_token::(&token, None) - .map_err(|err| JwtKeyError::Verify(format!("{err}"))) - }) - .instrument(trace_span!( - "verfy_token", - code.namespace = "jwt_simple::algorithms::RSAPublicKeyLike" - )) - .await??; - Ok(claims) -} diff --git a/lib/module-index-server/src/lib.rs b/lib/module-index-server/src/lib.rs index 451d5429b5..a6077ab43a 100644 --- a/lib/module-index-server/src/lib.rs +++ b/lib/module-index-server/src/lib.rs @@ -1,7 +1,6 @@ mod app_state; mod config; mod extract; -mod jwt_key; mod models; mod routes; mod s3; diff --git a/lib/module-index-server/src/server.rs b/lib/module-index-server/src/server.rs index 8e208b029c..e2d357d913 100644 --- a/lib/module-index-server/src/server.rs +++ b/lib/module-index-server/src/server.rs @@ -1,4 +1,4 @@ -use std::{io, net::SocketAddr, path::Path, time::Duration}; +use std::{io, net::SocketAddr, time::Duration}; use super::routes; @@ -8,6 +8,7 @@ use hyper::server::{accept::Accept, conn::AddrIncoming}; use s3::creds::{error::CredentialsError, Credentials as AwsCredentials}; use sea_orm::{ConnectOptions, Database, DatabaseConnection, DbErr}; use si_data_pg::{PgPool, PgPoolConfig, PgPoolError}; +use si_jwt_public_key::{JwtConfig, JwtPublicSigningKeyChain, JwtPublicSigningKeyError}; use si_posthog::{PosthogClient, PosthogConfig}; use telemetry::prelude::*; use thiserror::Error; @@ -21,7 +22,6 @@ use tower_http::trace::{DefaultMakeSpan, TraceLayer}; use crate::{ app_state::{AppState, ShutdownSource}, - jwt_key::{JwtKeyError, JwtPublicSigningKey}, s3::S3Config, Config, }; @@ -43,8 +43,8 @@ pub enum ServerError { DbErr(#[from] DbErr), #[error("hyper server error")] Hyper(#[from] hyper::Error), - #[error("jwt secret key error")] - JwtSecretKey(#[from] JwtKeyError), + #[error("jwt public key error: {0}")] + JwtPublicKey(#[from] JwtPublicSigningKeyError), #[error(transparent)] PgPool(#[from] Box), #[error(transparent)] @@ -61,7 +61,7 @@ impl From for ServerError { } } -type Result = std::result::Result; +type ServerResult = std::result::Result; pub struct Server { config: Config, @@ -75,9 +75,9 @@ impl Server<(), ()> { pub fn http( config: Config, pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key: JwtPublicSigningKeyChain, posthog_client: PosthogClient, - ) -> Result<(Server, broadcast::Receiver<()>)> { + ) -> ServerResult<(Server, broadcast::Receiver<()>)> { // socket_addr // try to load aws creds from a few different places @@ -141,7 +141,7 @@ impl Server<(), ()> { // this creates our si_data_pg::PgPool, which wont work with SeaORM #[instrument(name = "module-index.init.create_pg_pool", level = "info", skip_all)] - pub async fn create_pg_pool(pg_pool_config: &PgPoolConfig) -> Result { + pub async fn create_pg_pool(pg_pool_config: &PgPoolConfig) -> ServerResult { let pool = PgPool::new(pg_pool_config).await?; debug!("successfully started pg pool (note that not all connections may be healthy)"); Ok(pool) @@ -153,7 +153,9 @@ impl Server<(), ()> { level = "info", skip_all )] - pub async fn create_db_connection(pg_pool_config: &PgPoolConfig) -> Result { + pub async fn create_db_connection( + pg_pool_config: &PgPoolConfig, + ) -> ServerResult { let mut opt = ConnectOptions::new(format!( "{protocol}://{username}:{password}@{host}:{port}/{database}", protocol = "postgres", @@ -181,7 +183,7 @@ impl Server<(), ()> { Ok(db) } - pub async fn run_migrations(pg_pool: &PgPool) -> Result<()> { + pub async fn run_migrations(pg_pool: &PgPool) -> ServerResult<()> { Ok(pg_pool .migrate(embedded_migrations::migrations::runner()) .await?) @@ -193,12 +195,27 @@ impl Server<(), ()> { skip_all )] pub async fn load_jwt_public_signing_key( - path: impl AsRef, - ) -> Result { - Ok(JwtPublicSigningKey::load(path).await?) + config: &Config, + ) -> ServerResult { + let primary = JwtConfig { + key_file: Some(config.jwt_signing_public_key_path().to_owned()), + key_base64: None, + algo: config.jwt_signing_public_key_algo(), + }; + + let secondary = config + .jwt_secondary_signing_public_key_path() + .zip(config.jwt_secondary_signing_public_key_algo()) + .map(|(path, algo)| JwtConfig { + key_file: Some(path.to_owned()), + key_base64: None, + algo, + }); + + Ok(JwtPublicSigningKeyChain::from_config(primary, secondary).await?) } - pub async fn start_posthog(config: &PosthogConfig) -> Result { + pub async fn start_posthog(config: &PosthogConfig) -> ServerResult { // TODO(fnichol): this should be threaded through let token = CancellationToken::new(); @@ -216,7 +233,7 @@ where IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, IE: Into>, { - pub async fn run(self) -> Result<()> { + pub async fn run(self) -> ServerResult<()> { let shutdown_rx = self.shutdown_rx; self.inner @@ -240,17 +257,17 @@ where pub fn build_service( pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, aws_creds: AwsCredentials, s3_config: S3Config, -) -> Result<(Router, oneshot::Receiver<()>, broadcast::Receiver<()>)> { +) -> ServerResult<(Router, oneshot::Receiver<()>, broadcast::Receiver<()>)> { let (shutdown_tx, shutdown_rx) = mpsc::channel(1); let (shutdown_broadcast_tx, shutdown_broadcast_rx) = broadcast::channel(1); let state = AppState::new( pg_pool, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, aws_creds, s3_config, @@ -273,7 +290,7 @@ pub fn build_service( fn prepare_graceful_shutdown( mut shutdown_rx: mpsc::Receiver, shutdown_broadcast_tx: broadcast::Sender<()>, -) -> Result> { +) -> ServerResult> { let (graceful_shutdown_tx, graceful_shutdown_rx) = oneshot::channel::<()>(); let mut sigterm_watcher = signal::unix::signal(signal::unix::SignalKind::terminate()).map_err(ServerError::Signal)?; diff --git a/lib/sdf-server/BUCK b/lib/sdf-server/BUCK index 5f58ec72ca..9f807d5554 100644 --- a/lib/sdf-server/BUCK +++ b/lib/sdf-server/BUCK @@ -19,6 +19,7 @@ rust_library( "//lib/si-data-spicedb:si-data-spicedb", "//lib/si-events-rs:si-events", "//lib/si-frontend-types-rs:si-frontend-types", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-layer-cache:si-layer-cache", "//lib/si-pkg:si-pkg", "//lib/si-posthog-rs:si-posthog", diff --git a/lib/sdf-server/Cargo.toml b/lib/sdf-server/Cargo.toml index ac66a9b653..afd7e8c1e1 100644 --- a/lib/sdf-server/Cargo.toml +++ b/lib/sdf-server/Cargo.toml @@ -25,6 +25,7 @@ si-data-pg = { path = "../../lib/si-data-pg" } si-data-spicedb = { path = "../../lib/si-data-spicedb" } si-events = { path = "../../lib/si-events-rs" } si-frontend-types = { path = "../../lib/si-frontend-types-rs" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-layer-cache = { path = "../../lib/si-layer-cache" } si-pkg = { path = "../../lib/si-pkg" } si-posthog = { path = "../../lib/si-posthog-rs" } diff --git a/lib/sdf-server/src/app.rs b/lib/sdf-server/src/app.rs index 293cf54470..5ffeed835a 100644 --- a/lib/sdf-server/src/app.rs +++ b/lib/sdf-server/src/app.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use asset_sprayer::AssetSprayer; use audit_database::AuditDatabaseContext; use axum::Router; -use dal::{JwtPublicSigningKey, ServicesContext}; +use dal::ServicesContext; use nats_multiplexer_client::MultiplexerClient; use si_data_spicedb::SpiceDbClient; +use si_jwt_public_key::JwtPublicSigningKeyChain; use si_posthog::PosthogClient; use telemetry::prelude::*; use tokio::sync::RwLock; @@ -25,7 +26,7 @@ impl AxumApp { #[inline] pub fn from_services( services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -40,7 +41,7 @@ impl AxumApp { ) -> Self { Self::inner_from_services( services_context, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, @@ -65,7 +66,7 @@ impl AxumApp { #[inline] pub fn from_services_for_tests( services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -80,7 +81,7 @@ impl AxumApp { ) -> Self { Self::inner_from_services( services_context, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, @@ -103,7 +104,7 @@ impl AxumApp { #[allow(clippy::too_many_arguments)] fn inner_from_services( services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -119,7 +120,7 @@ impl AxumApp { ) -> Self { let state = AppState::new( services_context, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, diff --git a/lib/sdf-server/src/app_state.rs b/lib/sdf-server/src/app_state.rs index 4cc29053e5..c731cad42c 100644 --- a/lib/sdf-server/src/app_state.rs +++ b/lib/sdf-server/src/app_state.rs @@ -3,9 +3,9 @@ use std::{ops::Deref, sync::Arc}; use asset_sprayer::AssetSprayer; use audit_database::AuditDatabaseContext; use axum::extract::FromRef; -use dal::JwtPublicSigningKey; use nats_multiplexer_client::MultiplexerClient; use si_data_spicedb::SpiceDbClient; +use si_jwt_public_key::JwtPublicSigningKeyChain; use std::fmt; use tokio::sync::{Mutex, RwLock}; use tokio_util::sync::CancellationToken; @@ -26,7 +26,7 @@ pub enum ApplicationRuntimeMode { pub struct AppState { services_context: ServicesContext, broadcast_groups: BroadcastGroups, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: String, // TODO(victor) store the auth client on state instead of just the URL asset_sprayer: Option, @@ -44,7 +44,7 @@ impl AppState { #[allow(clippy::too_many_arguments)] pub fn new( services_context: impl Into, - jwt_public_signing_key: impl Into, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: impl Into, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -65,7 +65,7 @@ impl AppState { Self { services_context: services_context.into(), - jwt_public_signing_key: jwt_public_signing_key.into(), + jwt_public_signing_key_chain, broadcast_groups: Default::default(), posthog_client: posthog_client.into(), auth_api_url: auth_api_url.as_ref().to_string(), @@ -97,8 +97,8 @@ impl AppState { self.asset_sprayer.as_ref() } - pub fn jwt_public_signing_key(&self) -> &JwtPublicSigningKey { - &self.jwt_public_signing_key + pub fn jwt_public_signing_key_chain(&self) -> &JwtPublicSigningKeyChain { + &self.jwt_public_signing_key_chain } pub fn for_tests(&self) -> bool { diff --git a/lib/sdf-server/src/config.rs b/lib/sdf-server/src/config.rs index 2c00ad6243..1f68fcd110 100644 --- a/lib/sdf-server/src/config.rs +++ b/lib/sdf-server/src/config.rs @@ -1,9 +1,9 @@ use asset_sprayer::config::{AssetSprayerConfig, SIOpenAIConfig}; use audit_database::AuditDatabaseConfig; -use dal::jwt_key::JwtConfig; use serde_with::{DeserializeFromStr, SerializeDisplay}; use si_crypto::VeritechCryptoConfig; use si_data_spicedb::SpiceDbConfig; +use si_jwt_public_key::{JwtAlgo, JwtConfig}; use si_layer_cache::{db::LayerDbConfig, error::LayerDbError}; use std::collections::HashSet; use std::{ @@ -130,6 +130,9 @@ pub struct Config { #[builder(default)] jwt_signing_public_key: JwtConfig, + #[builder(default)] + jwt_secondary_signing_public_key: Option, + #[builder(default = "default_layer_db_config()")] layer_db_config: LayerDbConfig, @@ -191,6 +194,10 @@ impl Config { &self.jwt_signing_public_key } + pub fn jwt_secondary_signing_public_key(&self) -> Option<&JwtConfig> { + self.jwt_secondary_signing_public_key.as_ref() + } + /// Gets a reference to the config's cyclone public key path. #[must_use] pub fn crypto(&self) -> &VeritechCryptoConfig { @@ -298,6 +305,8 @@ pub struct ConfigFile { #[serde(default)] pub jwt_signing_public_key: JwtConfig, #[serde(default)] + pub jwt_secondary_signing_public_key: Option, + #[serde(default)] pub crypto: VeritechCryptoConfig, #[serde(default = "default_pkgs_path")] pub pkgs_path: String, @@ -335,6 +344,7 @@ impl Default for ConfigFile { nats: Default::default(), migration_mode: Default::default(), jwt_signing_public_key: Default::default(), + jwt_secondary_signing_public_key: Default::default(), crypto: Default::default(), pkgs_path: default_pkgs_path(), posthog: Default::default(), @@ -371,6 +381,7 @@ impl TryFrom for Config { incoming_stream: IncomingStream::default(), // TODO this OK? migration_mode: value.migration_mode, jwt_signing_public_key: value.jwt_signing_public_key, + jwt_secondary_signing_public_key: value.jwt_secondary_signing_public_key, crypto: value.crypto, pkgs_path: value.pkgs_path.try_into()?, posthog: value.posthog, @@ -463,7 +474,7 @@ fn buck2_development(config: &mut ConfigFile) -> Result<()> { #[allow(clippy::disallowed_methods)] // Used in development with a local auth services // Note(victor): If the user has set a custom auth ip url via env variable we assume dev mode - let jwt_signing_public_key_path = if env::var("SI_AUTH_API_URL").is_ok() { + let jwt_primary_signing_public_key_path = if env::var("SI_AUTH_API_URL").is_ok() { resources .get_ends_with("dev.jwt_signing_public_key.pem") .map_err(ConfigError::development)? @@ -498,7 +509,7 @@ fn buck2_development(config: &mut ConfigFile) -> Result<()> { .to_string(); warn!( - jwt_signing_public_key_path = jwt_signing_public_key_path.as_str(), + jwt_signing_public_key_path = jwt_primary_signing_public_key_path.as_str(), veritech_encryption_key_path = veritech_encryption_key_path.as_str(), symmetric_crypto_service_key = symmetric_crypto_service_key.as_str(), postgres_cert = postgres_cert.as_str(), @@ -507,8 +518,9 @@ fn buck2_development(config: &mut ConfigFile) -> Result<()> { ); config.jwt_signing_public_key = JwtConfig { - key_file: Some(jwt_signing_public_key_path.try_into()?), + key_file: Some(jwt_primary_signing_public_key_path.try_into()?), key_base64: None, + algo: JwtAlgo::RS256, }; config.crypto.encryption_key_file = veritech_encryption_key_path.parse().ok(); config.symmetric_crypto_service = SymmetricCryptoServiceConfigFile { @@ -572,6 +584,7 @@ fn cargo_development(dir: String, config: &mut ConfigFile) -> Result<()> { config.jwt_signing_public_key = JwtConfig { key_file: Some(jwt_signing_public_key_path.try_into()?), key_base64: None, + algo: JwtAlgo::RS256, }; config.crypto.encryption_key_file = veritech_encryption_key_path.parse().ok(); config.symmetric_crypto_service = SymmetricCryptoServiceConfigFile { diff --git a/lib/sdf-server/src/extract.rs b/lib/sdf-server/src/extract.rs index f387b0ab91..66e00c105f 100644 --- a/lib/sdf-server/src/extract.rs +++ b/lib/sdf-server/src/extract.rs @@ -211,7 +211,7 @@ impl FromRequestParts for Authorization { let HandlerContext(builder) = HandlerContext::from_request_parts(parts, state).await?; let mut ctx = builder.build_default().await.map_err(internal_error)?; - let jwt_public_signing_key = state.jwt_public_signing_key().clone(); + let jwt_public_signing_key = state.jwt_public_signing_key_chain().clone(); let headers = &parts.headers; let authorization_header_value = headers @@ -251,7 +251,7 @@ impl FromRequestParts for WsAuthorization { ) -> Result { let HandlerContext(builder) = HandlerContext::from_request_parts(parts, state).await?; let mut ctx = builder.build_default().await.map_err(internal_error)?; - let jwt_public_signing_key = state.jwt_public_signing_key().clone(); + let jwt_public_signing_key = state.jwt_public_signing_key_chain().clone(); let query: Query> = Query::from_request_parts(parts, state) .await diff --git a/lib/sdf-server/src/init.rs b/lib/sdf-server/src/init.rs index 442298fd77..81555258ee 100644 --- a/lib/sdf-server/src/init.rs +++ b/lib/sdf-server/src/init.rs @@ -1,8 +1,8 @@ use std::{path::PathBuf, sync::Arc}; use dal::{ - feature_flags::FeatureFlagService, jwt_key::JwtConfig, DalLayerDb, DedicatedExecutor, - JetstreamStreams, JobQueueProcessor, JwtPublicSigningKey, NatsProcessor, ServicesContext, + feature_flags::FeatureFlagService, DalLayerDb, DedicatedExecutor, JetstreamStreams, + JobQueueProcessor, NatsProcessor, ServicesContext, }; use rebaser_client::RebaserClient; use si_crypto::{ @@ -11,6 +11,7 @@ use si_crypto::{ }; use si_data_nats::{NatsClient, NatsConfig}; use si_data_pg::{PgPool, PgPoolConfig}; +use si_jwt_public_key::{JwtConfig, JwtPublicSigningKeyChain, JwtPublicSigningKeyError}; use si_layer_cache::{ db::{LayerDbConfig, LayerDbGracefulShutdown}, LayerDb, @@ -32,7 +33,7 @@ pub enum InitError { #[error("failed to initialize a dal jetstream streams: {0}")] DalJetstreamStreams(#[source] dal::JetstreamStreamsError), #[error("jwt key error")] - JwtKey(#[from] dal::jwt_key::JwtKeyError), + JwtKey(#[from] JwtPublicSigningKeyError), #[error("layer cache error: {0}")] LayerCache(#[from] si_layer_cache::LayerDbError), #[error("failed to initialize a nats client: {0}")] @@ -195,9 +196,10 @@ pub(crate) async fn initialize_layer_db( skip_all )] pub(crate) async fn load_jwt_public_signing_key( - config: JwtConfig, -) -> InitResult { - Ok(JwtPublicSigningKey::from_config(config).await?) + primary: JwtConfig, + secondary: Option, +) -> InitResult { + Ok(JwtPublicSigningKeyChain::from_config(primary, secondary).await?) } pub(crate) fn initialize_posthog( diff --git a/lib/sdf-server/src/server.rs b/lib/sdf-server/src/server.rs index 7cbd92667e..8c1e8d0f9a 100644 --- a/lib/sdf-server/src/server.rs +++ b/lib/sdf-server/src/server.rs @@ -3,11 +3,12 @@ use std::{fmt, future::IntoFuture as _, net::SocketAddr, path::PathBuf, sync::Ar use asset_sprayer::AssetSprayer; use audit_database::AuditDatabaseContext; use axum::{async_trait, routing::IntoMakeService, Router}; -use dal::{JwtPublicSigningKey, ServicesContext}; +use dal::ServicesContext; use hyper::server::accept::Accept; use nats_multiplexer::Multiplexer; use nats_multiplexer_client::MultiplexerClient; use si_data_spicedb::SpiceDbClient; +use si_jwt_public_key::JwtPublicSigningKeyChain; use si_posthog::PosthogClient; use telemetry::prelude::*; use tokio::{ @@ -83,8 +84,11 @@ impl Server { let (services_context, layer_db_graceful_shutdown) = init::services_context_from_config(&config, helping_tasks_token.clone()).await?; - let jwt_public_signing_key = - init::load_jwt_public_signing_key(config.jwt_signing_public_key().clone()).await?; + let jwt_public_signing_key = init::load_jwt_public_signing_key( + config.jwt_signing_public_key().clone(), + config.jwt_secondary_signing_public_key().cloned(), + ) + .await?; let (posthog_sender, posthog_client) = init::initialize_posthog(config.posthog(), helping_tasks_token.clone())?; @@ -155,7 +159,7 @@ impl Server { instance_id: impl Into, incoming_stream: IncomingStream, services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -170,7 +174,7 @@ impl Server { ) -> ServerResult { let app = AxumApp::from_services( services_context.clone(), - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, diff --git a/lib/si-jwt-public-key/BUCK b/lib/si-jwt-public-key/BUCK new file mode 100644 index 0000000000..2c8f75f0a1 --- /dev/null +++ b/lib/si-jwt-public-key/BUCK @@ -0,0 +1,23 @@ +load( + "@prelude-si//:macros.bzl", + "rust_library", + "rust_test", +) + +rust_library( + name = "si-jwt-public-key", + deps = [ + "//lib/si-events-rs:si-events", + "//lib/si-std:si-std", + "//lib/telemetry-rs:telemetry", + "//third-party/rust:base64", + "//third-party/rust:jwt-simple", + "//third-party/rust:remain", + "//third-party/rust:serde", + "//third-party/rust:thiserror", + "//third-party/rust:tokio", + ], + srcs = glob([ + "src/**/*.rs", + ]), +) diff --git a/lib/si-jwt-public-key/Cargo.toml b/lib/si-jwt-public-key/Cargo.toml new file mode 100644 index 0000000000..b7978ee520 --- /dev/null +++ b/lib/si-jwt-public-key/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "si-jwt-public-key" +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true +rust-version.workspace = true +publish.workspace = true + +[dependencies] +si-events = { path = "../../lib/si-events-rs" } +si-std = { path = "../../lib/si-std" } +telemetry = { path = "../../lib/telemetry-rs" } + +base64 = { workspace = true } +jwt-simple = { workspace = true } +remain = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } diff --git a/lib/si-jwt-public-key/src/lib.rs b/lib/si-jwt-public-key/src/lib.rs new file mode 100644 index 0000000000..f8aa70316b --- /dev/null +++ b/lib/si-jwt-public-key/src/lib.rs @@ -0,0 +1,355 @@ +use core::str; +use si_events::{UserPk, WorkspacePk}; +use si_std::CanonicalFile; +use std::sync::Arc; + +use base64::{engine::general_purpose, Engine}; +use jwt_simple::{common::VerificationOptions, prelude::*}; +use serde::{Deserialize, Serialize}; +use telemetry::prelude::*; +use thiserror::Error; +use tokio::{fs, io::AsyncReadExt, task::JoinError}; + +#[remain::sorted] +#[derive(Error, Debug)] +pub enum JwtPublicSigningKeyError { + #[error("failed to decode base64 string: {0}")] + Base64Decode(#[from] base64::DecodeError), + #[error("invalid bearer token")] + BearerToken, + #[error("error creating jwt from config")] + FromConfig, + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("JWT error: {0}")] + Jwt(#[from] jwt_simple::Error), + #[error("{0}")] + TaskJoin(#[from] JoinError), + #[error("Unsupported JWT signing algorithm: {0}")] + UnsupportedAlgo(String), + #[error("failed to build string from utf8: {0}")] + Utf8(#[from] std::str::Utf8Error), + #[error("failure to verify token: {0}")] + Verify(String), + #[error("failure to verify against secondary token: first error: {0}, second error: {1}")] + VerifySecondaryFail(String, String), +} + +pub type JwtKeyResult = Result; + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct JwtConfig { + pub key_file: Option, + pub key_base64: Option, + pub algo: JwtAlgo, +} + +impl JwtConfig { + pub async fn to_pem(self) -> JwtKeyResult { + Ok(match (self.key_file.as_ref(), self.key_base64.as_deref()) { + (None, Some(key_base64)) => { + let buf = general_purpose::STANDARD.decode(key_base64)?; + str::from_utf8(&buf)?.to_string() + } + (Some(key_file), None) => { + let mut file = fs::File::open(key_file).await?; + let mut buf = String::new(); + file.read_to_string(&mut buf).await?; + + buf + } + _ => Err(JwtPublicSigningKeyError::FromConfig)?, + }) + } + + pub async fn into_verify(self) -> JwtKeyResult> { + let algo = self.algo; + let pem = self.to_pem().await?; + + Ok(match algo { + JwtAlgo::ES256 => { + Arc::new(ES256PublicKey::from_pem(&pem)?) as Arc + } + JwtAlgo::RS256 => { + Arc::new(RS256PublicKey::from_pem(&pem)?) as Arc + } + }) + } +} + +#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] +pub struct SiJwtClaims { + pub user_pk: UserPk, + pub workspace_pk: WorkspacePk, +} + +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub enum JwtAlgo { + #[default] + ES256, + RS256, +} + +pub trait JwtPublicKeyVerify: std::fmt::Debug + Send + Sync { + fn algo(&self) -> JwtAlgo; + fn verify( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult>; +} + +impl JwtPublicKeyVerify for RS256PublicKey { + fn algo(&self) -> JwtAlgo { + JwtAlgo::RS256 + } + + fn verify( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult> { + self.verify_token(token, options) + .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) + } +} + +impl JwtPublicKeyVerify for ES256PublicKey { + fn algo(&self) -> JwtAlgo { + JwtAlgo::ES256 + } + + fn verify( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult> { + self.verify_token(token, options) + .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) + } +} + +#[derive(Clone, Debug)] +pub struct JwtPublicSigningKeyChain { + primary: Arc, + secondary: Option>, +} + +impl JwtPublicSigningKeyChain { + pub async fn from_config( + primary: JwtConfig, + secondary: Option, + ) -> JwtKeyResult { + Ok(Self { + primary: primary.into_verify().await?, + secondary: match secondary { + Some(jwt_cfg) => Some(jwt_cfg.into_verify().await?), + None => None, + }, + }) + } + + /// Attempt to verify that this token was signed by either the primary or + /// secondary key(s) + pub fn verify_token( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult> { + match self.primary.verify(token, options.clone()) { + Ok(claims) => Ok(claims), + Err(err) => match self.secondary.as_ref() { + Some(secondary) => match secondary.verify(token, options) { + Ok(claims) => Ok(claims), + Err(second_err) => Err(JwtPublicSigningKeyError::VerifySecondaryFail( + err.to_string(), + second_err.to_string(), + )), + }, + None => Err(err), + }, + } + } +} + +#[instrument(level = "debug", skip_all)] +pub async fn validate_bearer_token( + public_key: JwtPublicSigningKeyChain, + bearer_token: impl AsRef, +) -> JwtKeyResult> { + let bearer_token = bearer_token.as_ref(); + let token = if let Some(token) = bearer_token.strip_prefix("Bearer ") { + token.to_string() + } else { + return Err(JwtPublicSigningKeyError::BearerToken); + }; + + let claims = + tokio::task::spawn_blocking(move || public_key.verify_token(&token, None)).await??; + + Ok(claims) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn validate_with_primary_rs256() { + println!("generating key..."); + let key_pair = RS256KeyPair::generate(2048).expect("generate key pair"); + println!("done"); + + let pub_key = key_pair.public_key(); + let pub_key_pem = pub_key.to_pem().expect("get pub key pem"); + let pub_key_base64 = general_purpose::STANDARD.encode(pub_key_pem); + + let si_claim = SiJwtClaims { + user_pk: UserPk::generate(), + workspace_pk: WorkspacePk::generate(), + }; + + let claims = JWTClaims { + issued_at: None, + expires_at: None, + invalid_before: None, + issuer: None, + subject: None, + audiences: None, + jwt_id: None, + nonce: None, + custom: si_claim.clone(), + }; + + let signed = key_pair.sign(claims).expect("sign the key"); + let bearer_token = format!("Bearer {signed}"); + + let primary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64), + algo: JwtAlgo::RS256, + }; + + let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, None) + .await + .expect("make key chain"); + + let claims = validate_bearer_token(key_chain, &bearer_token) + .await + .expect("shoudl validate"); + + assert_eq!(si_claim, claims.custom); + } + + #[tokio::test] + async fn validate_with_primary_es256() { + println!("generating key..."); + let key_pair = ES256KeyPair::generate(); + let key_pair_2 = ES256KeyPair::generate(); + println!("done"); + + let pub_key = key_pair.public_key(); + let pub_key_pem = pub_key.to_pem().expect("get pub key pem"); + let pub_key_base64 = general_purpose::STANDARD.encode(pub_key_pem); + + let si_claim = SiJwtClaims { + user_pk: UserPk::generate(), + workspace_pk: WorkspacePk::generate(), + }; + + let claims = JWTClaims { + issued_at: None, + expires_at: None, + invalid_before: None, + issuer: None, + subject: None, + audiences: None, + jwt_id: None, + nonce: None, + custom: si_claim.clone(), + }; + + let signed = key_pair.sign(claims.clone()).expect("sign the key"); + let bearer_token = format!("Bearer {signed}"); + + let primary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64), + algo: JwtAlgo::ES256, + }; + + let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, None) + .await + .expect("make key chain"); + + let claims = validate_bearer_token(key_chain.clone(), &bearer_token) + .await + .expect("shoudl validate"); + + assert_eq!(si_claim, claims.custom); + + // Just confirm it fails with the wrong key + let signed_bad = key_pair_2.sign(claims).expect("sign the key"); + let bearer_bad = format!("Bearer {signed_bad}"); + let result = validate_bearer_token(key_chain, &bearer_bad).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn validate_with_secondary_rs256() { + println!("generating keys..."); + let key_pair_es256 = ES256KeyPair::generate(); + let key_pair_rs256 = RS256KeyPair::generate(2048).expect("generate rs256 key"); + println!("done"); + + let pub_key_es256 = key_pair_es256.public_key(); + let pub_key_pem = pub_key_es256.to_pem().expect("get pub key pem"); + let pub_key_base64_es256 = general_purpose::STANDARD.encode(pub_key_pem); + + let pub_key_rs256 = key_pair_rs256.public_key(); + let pub_key_pem = pub_key_rs256.to_pem().expect("get pub key pem"); + let pub_key_base64_rs256 = general_purpose::STANDARD.encode(pub_key_pem); + + let si_claim = SiJwtClaims { + user_pk: UserPk::generate(), + workspace_pk: WorkspacePk::generate(), + }; + + let claims = JWTClaims { + issued_at: None, + expires_at: None, + invalid_before: None, + issuer: None, + subject: None, + audiences: None, + jwt_id: None, + nonce: None, + custom: si_claim.clone(), + }; + + let signed = key_pair_rs256.sign(claims.clone()).expect("sign the key"); + let bearer_token = format!("Bearer {signed}"); + + let primary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64_es256), + algo: JwtAlgo::ES256, + }; + + let secondary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64_rs256), + algo: JwtAlgo::RS256, + }; + + let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, Some(secondary_cfg)) + .await + .expect("make key chain"); + + let claims = validate_bearer_token(key_chain.clone(), &bearer_token) + .await + .expect("shoudl validate"); + + assert_eq!(si_claim, claims.custom); + } +}