From 27d25c397bc6d86b309ada825edb255397509bff Mon Sep 17 00:00:00 2001 From: Zachary Hamm Date: Fri, 6 Dec 2024 16:09:55 -0600 Subject: [PATCH] feat: support multiple jwt public keys Support two JWT keys. One used for signing (the primary private key), and an extra one for verification, so that we can rotate keys without downtime, and then remove the "secondary" if desired. It also now possible to specify the key algorithm. ES256 and RS256 are supported but RS256 will be removed. Adds a new crate, `si-jwt-public-key` (since this logic needs to be shared between sdf and the module-index). --- Cargo.lock | 19 ++ Cargo.toml | 80 ++++- bin/auth-api/README.md | 7 + bin/auth-api/src/lib/jwt.ts | 102 +++++-- bin/module-index/BUCK | 1 + bin/module-index/src/args.rs | 9 + bin/module-index/src/main.rs | 3 +- bin/sdf/src/args.rs | 28 +- component/init/configs/service.toml | 1 + lib/dal-test/BUCK | 1 + lib/dal-test/Cargo.toml | 3 +- lib/dal-test/src/lib.rs | 25 +- lib/dal/BUCK | 3 +- lib/dal/Cargo.toml | 1 + lib/dal/src/jwt_key.rs | 153 ---------- lib/dal/src/lib.rs | 2 - lib/dal/src/user.rs | 24 +- lib/module-index-server/BUCK | 1 + lib/module-index-server/Cargo.toml | 1 + lib/module-index-server/src/app_state.rs | 13 +- lib/module-index-server/src/config.rs | 41 ++- lib/module-index-server/src/extract.rs | 45 +-- lib/module-index-server/src/jwt_key.rs | 110 ------- lib/module-index-server/src/lib.rs | 1 - lib/module-index-server/src/server.rs | 55 ++-- lib/sdf-server/BUCK | 1 + lib/sdf-server/Cargo.toml | 1 + lib/sdf-server/src/app.rs | 15 +- lib/sdf-server/src/app_state.rs | 12 +- lib/sdf-server/src/config.rs | 21 +- lib/sdf-server/src/extract.rs | 4 +- lib/sdf-server/src/init.rs | 14 +- lib/sdf-server/src/server.rs | 14 +- lib/si-jwt-public-key/BUCK | 23 ++ lib/si-jwt-public-key/Cargo.toml | 21 ++ lib/si-jwt-public-key/src/lib.rs | 355 +++++++++++++++++++++++ 36 files changed, 798 insertions(+), 412 deletions(-) delete mode 100644 lib/dal/src/jwt_key.rs delete mode 100644 lib/module-index-server/src/jwt_key.rs create mode 100644 lib/si-jwt-public-key/BUCK create mode 100644 lib/si-jwt-public-key/Cargo.toml create mode 100644 lib/si-jwt-public-key/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index a2f2a3ab4a..80e4d666b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1941,6 +1941,7 @@ dependencies = [ "si-frontend-types", "si-hash", "si-id", + "si-jwt-public-key", "si-layer-cache", "si-pkg", "si-runtime", @@ -2005,6 +2006,7 @@ dependencies = [ "si-data-nats", "si-data-pg", "si-events", + "si-jwt-public-key", "si-layer-cache", "si-pkg", "si-runtime", @@ -4305,6 +4307,7 @@ dependencies = [ "serde_json", "si-data-pg", "si-hash", + "si-jwt-public-key", "si-pkg", "si-posthog", "si-settings", @@ -6232,6 +6235,7 @@ dependencies = [ "si-data-spicedb", "si-events", "si-frontend-types", + "si-jwt-public-key", "si-layer-cache", "si-pkg", "si-posthog", @@ -6787,6 +6791,21 @@ dependencies = [ "ulid", ] +[[package]] +name = "si-jwt-public-key" +version = "0.1.0" +dependencies = [ + "base64 0.22.1", + "jwt-simple", + "remain", + "serde", + "si-events", + "si-std", + "telemetry", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "si-layer-cache" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e2fdbda080..f88b841f3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ members = [ "lib/si-frontend-types-rs", "lib/si-hash", "lib/si-id", + "lib/si-jwt-public-key", "lib/si-layer-cache", "lib/si-pkg", "lib/si-pool-noodle", @@ -93,7 +94,11 @@ async-recursion = "1.0.5" async-trait = "0.1.79" aws-config = { version = "1.1.7", features = ["behavior-version-latest"] } aws-sdk-firehose = "1.46.0" -axum = { version = "0.6.20", features = ["macros", "multipart", "ws"] } # todo: upgrade this alongside hyper/http/tokio-tungstenite +axum = { version = "0.6.20", features = [ + "macros", + "multipart", + "ws", +] } # todo: upgrade this alongside hyper/http/tokio-tungstenite base64 = "0.22.0" blake3 = "1.5.1" bollard = "0.16.1" @@ -103,13 +108,17 @@ ciborium = "0.2.2" clap = { version = "4.5.4", features = ["derive", "color", "env", "wrap_help"] } color-eyre = "0.6.3" colored = "2.1.0" -comfy-table = { version = "7.1.0", features = ["crossterm", "tty", "custom_styling"] } +comfy-table = { version = "7.1.0", features = [ + "crossterm", + "tty", + "custom_styling", +] } config = { version = "0.14.0", default-features = false, features = ["toml"] } console = "0.15.8" convert_case = "0.6.0" criterion = { version = "0.5.1", features = ["async_tokio"] } crossbeam-channel = "0.5.12" -crossbeam-queue = {version = "0.3.10"} +crossbeam-queue = { version = "0.3.10" } darling = "0.20.10" deadpool = { version = "0.10.0", features = ["rt_tokio_1"] } deadpool-postgres = "0.12.1" @@ -128,8 +137,15 @@ fs4 = "0.11.0" glob = "0.3.1" hex = "0.4.3" http = "0.2.12" # todo: upgrade this alongside hyper/axum/tokio-tungstenite/tower-http -hyper = { version = "0.14.28", features = ["client", "http1", "runtime", "server"] } # todo: upgrade this alongside http/axum/tokio-tungstenite/tower-http -hyperlocal = { version = "0.8.0", default-features = false, features = ["client"] } # todo: using the very latest of hyper client 1.x, we _may_ be able to phase this crate +hyper = { version = "0.14.28", features = [ + "client", + "http1", + "runtime", + "server", +] } # todo: upgrade this alongside http/axum/tokio-tungstenite/tower-http +hyperlocal = { version = "0.8.0", default-features = false, features = [ + "client", +] } # todo: using the very latest of hyper client 1.x, we _may_ be able to phase this crate iftree = "1.0.5" include_dir = "0.7.4" indexmap = { version = "2.2.6", features = ["serde", "std"] } @@ -137,14 +153,22 @@ indicatif = "0.17.8" indoc = "2.0.5" inquire = "0.7.4" itertools = "0.12.1" -jwt-simple = { version = "0.12.9", default-features = false, features = ["pure-rust"] } +jwt-simple = { version = "0.12.9", default-features = false, features = [ + "pure-rust", +] } krata-loopdev = "0.0.12" lazy_static = "1.4.0" manyhow = { version = "0.11.4", features = ["darling"] } mime_guess = { version = "=2.0.4" } # TODO(fnichol): 2.0.5 sets an env var in build.rs which needs to be tracked, required by reqwest miniz_oxide = { version = "0.7.2", features = ["simd"] } names = { version = "0.14.0", default-features = false } -nix = { version = "0.26.0", features = ["fs", "mount", "process", "signal", "user"] } +nix = { version = "0.26.0", features = [ + "fs", + "mount", + "process", + "signal", + "user", +] } nkeys = "0.4.0" num_cpus = "1.16.0" once_cell = "1.19.0" @@ -171,13 +195,25 @@ rand = "0.8.5" refinery = { version = "= 0.8.12", features = ["tokio-postgres"] } regex = "1.10.4" remain = "0.2.13" -reqwest = { version = "0.12.2", default-features = false, features = ["rustls-tls", "json", "multipart"] } +reqwest = { version = "0.12.2", default-features = false, features = [ + "rustls-tls", + "json", + "multipart", +] } ring = "=0.17.5" # Upgrading this is possible, but a pain, so we don't want to pick up every new minor version (see: https://github.com/facebook/buck2/commit/91af40b66960d003067c3d241595fb53d1e636c8) -rust-s3 = { version = "0.34.0-rc4", default-features = false, features = ["tokio-rustls-tls"] } +rust-s3 = { version = "0.34.0-rc4", default-features = false, features = [ + "tokio-rustls-tls", +] } rustls = { version = "0.23.18", default-features = false } # NOTE(nick,fletcher): rustls switched to "aws-lc-rs" as its default crypto provider, but we want ring (i.e. we disable the default feature for "aws-lc-rs") rustls-native-certs = "0.7.0" rustls-pemfile = { version = "2.1.1" } -sea-orm = { version = "1.1.0", features = ["sqlx-postgres", "runtime-tokio-rustls", "macros", "with-chrono", "debug-print"] } +sea-orm = { version = "1.1.0", features = [ + "sqlx-postgres", + "runtime-tokio-rustls", + "macros", + "with-chrono", + "debug-print", +] } self-replace = "1.3.7" serde = { version = "1.0.197", features = ["derive", "rc"] } serde-aux = "4.5.0" @@ -196,12 +232,18 @@ syn = { version = "2.0.55", features = ["full", "extra-traits"] } sysinfo = "0.32.0" tar = "0.4.40" tempfile = "3.10.1" -test-log = { version = "0.2.15", default-features = false, features = ["trace"] } +test-log = { version = "0.2.15", default-features = false, features = [ + "trace", +] } thiserror = "1.0.58" thread-priority = "1.1.0" time = "0.3.36" tokio = { version = "1.37.0", features = ["full"] } -tokio-postgres = { version = "0.7.10", features = ["runtime", "with-chrono-0_4", "with-serde_json-1"] } +tokio-postgres = { version = "0.7.10", features = [ + "runtime", + "with-chrono-0_4", + "with-serde_json-1", +] } tokio-postgres-rustls = { version = "0.12.0" } tokio-serde = { version = "0.9.0", features = ["json"] } tokio-stream = { version = "0.1.15", features = ["sync", "time"] } @@ -211,10 +253,20 @@ tokio-util = { version = "0.7.10", features = ["codec", "rt"] } tokio-vsock = { version = "0.4.0" } toml = { version = "0.8.12" } tower = { version = "0.4.13", features = ["full"] } -tower-http = { version = "0.4.4", features = ["compression-br", "compression-deflate", "compression-gzip", "cors", "trace"] } # todo: pinning back to 0.4.4, upgrade this alongside hyper/http/axum/tokio-tungstenite +tower-http = { version = "0.4.4", features = [ + "compression-br", + "compression-deflate", + "compression-gzip", + "cors", + "trace", +] } # todo: pinning back to 0.4.4, upgrade this alongside hyper/http/axum/tokio-tungstenite tracing = { version = "0.1.40" } tracing-opentelemetry = "0.27.0" -tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "std"] } +tracing-subscriber = { version = "0.3.18", features = [ + "env-filter", + "json", + "std", +] } tracing-tunnel = "0.1.0" trybuild = { version = "1.0.99", features = ["diff"] } tryhard = "0.5.1" diff --git a/bin/auth-api/README.md b/bin/auth-api/README.md index eb05f0155c..44658051e7 100644 --- a/bin/auth-api/README.md +++ b/bin/auth-api/README.md @@ -10,6 +10,13 @@ Use `pnpx prisma` to run prisma commands locally. For example ### JWT Signing Key +### ES256 + +- `ssh-keygen -t ecdsa -b 256 -m PEM -f jwtES256.key` +- `openssl ec -in jwtES256.key -pubout -outform PEM -out jwtES256.key.pub` + +### RS256 (deprecated) + - `ssh-keygen -t rsa -b 4096 -m PEM -f jwtRS256.key` - `openssl rsa -in jwtRS256.key -pubout -outform PEM -out jwtRS256.key.pub` diff --git a/bin/auth-api/src/lib/jwt.ts b/bin/auth-api/src/lib/jwt.ts index 91c2f1d877..6e7997e5f0 100644 --- a/bin/auth-api/src/lib/jwt.ts +++ b/bin/auth-api/src/lib/jwt.ts @@ -1,5 +1,11 @@ /* instructions to generate JWT signing key + run `ssh-keygen -t ecdsa -b 256 -m PEM -f jwtES256.key` +- run `openssl ec -in jwtES256.key -pubout -outform PEM -out jwtES256.key.pub` +- `cat jwtES256.key` +- `cat jwtES256.key.pub` + +For RS256: (deprecated) - run `ssh-keygen -t rsa -b 4096 -m PEM -f jwtRS256.key` # Don't add passphrase - run `openssl rsa -in jwtRS256.key -pubout -outform PEM -out jwtRS256.key.pub` - `cat jwtRS256.key` @@ -9,34 +15,92 @@ instructions to generate JWT signing key import fs from "fs"; import JWT from "jsonwebtoken"; -// load private and public key from either env var or paths set in config -// keys in the repo are also used by SDF to verify jwt is signed correctly and in tests to create/sign jwts -let _JWT_PRIVATE_KEY = process.env.JWT_PRIVATE_KEY; -if (!_JWT_PRIVATE_KEY && process.env.JWT_PRIVATE_KEY_PATH) { - // path is relative to .env file - _JWT_PRIVATE_KEY = fs.readFileSync(`${process.env.JWT_PRIVATE_KEY_PATH}`, 'utf-8'); -} -let _JWT_PUBLIC_KEY = process.env.JWT_PUBLIC_KEY; -if (!_JWT_PUBLIC_KEY && process.env.JWT_PUBLIC_KEY_PATH) { - // path is relative to .env file - _JWT_PUBLIC_KEY = fs.readFileSync(`${process.env.JWT_PUBLIC_KEY_PATH}`, 'utf-8'); -} -if (!_JWT_PRIVATE_KEY) throw new Error('Missing JWT signing private key'); -if (!_JWT_PUBLIC_KEY) throw new Error('Missing JWT signing public key'); +const DEFAULT_ALGO = "RS256"; + +type Algo = "RS256" | "ES256"; + +const jwtAlgo = (algo?: string): Algo => { + switch (algo) { + case "RS256": + case "ES256": + return algo; + default: + return DEFAULT_ALGO; + } +}; + +const keyEnvPaths = { + primary: { + private: "JWT_PRIVATE_KEY", + privatePath: "JWT_PRIVATE_KEY_PATH", + public: "JWT_PUBLIC_KEY", + publicPath: "JWT_PUBLIC_KEY_PATH", + algo: "JWT_ALGO", + }, + secondary: { + private: "JWT_2ND_PRIVATE_KEY", + privatePath: "JWT_2ND_PRIVATE_KEY_PATH", + public: "JWT_2ND_PUBLIC_KEY", + publicPath: "JWT_2ND_PUBLIC_KEY_PATH", + algo: "JWT_2ND_ALGO", + }, +}; + +// load private and public keys from either env var or paths set in config keys +// in the repo are also used by SDF to verify jwt is signed correctly and in +// tests to create/sign jwts + +const prepareKeys = (which: "primary" | "secondary"): { privKey?: string, pubKey?: string, algo: Algo } => { + const privateLiteral = process.env[keyEnvPaths[which].private]; + const privatePath = process.env[keyEnvPaths[which].privatePath]; + + let privKey = privateLiteral ?? (privatePath ? fs.readFileSync(privatePath, 'utf-8') : undefined); + if (privKey) { + privKey = privKey.replace(/\\n/g, '\n'); + } + + const publicLiteral = process.env[keyEnvPaths[which].public]; + const publicPath = process.env[keyEnvPaths[which].publicPath]; + + let pubKey = publicLiteral ?? (publicPath ? fs.readFileSync(publicPath, 'utf-8') : undefined); + if (pubKey) { + pubKey = pubKey.replace(/\\n/g, '\n'); + } + + const algo = jwtAlgo(process.env[keyEnvPaths[which].algo]); + + return { + privKey, + pubKey, + algo, + }; +}; + +const { privKey: primaryPrivKey, pubKey: primaryPubKey, algo } = prepareKeys("primary"); +const { pubKey: secondaryPubKey } = prepareKeys("secondary"); -_JWT_PRIVATE_KEY = _JWT_PRIVATE_KEY.replace(/\\n/g, '\n'); -_JWT_PUBLIC_KEY = _JWT_PUBLIC_KEY.replace(/\\n/g, '\n'); +if (!primaryPrivKey) throw new Error('Missing JWT signing private key'); +if (!primaryPubKey) throw new Error('Missing JWT signing public key'); -export const JWT_PUBLIC_KEY = _JWT_PUBLIC_KEY; +export const JWT_PUBLIC_KEY = primaryPubKey; +export const JWT_2ND_PUBLIC_KEY = secondaryPubKey; export function createJWT( payload: Record, options?: Omit, ) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return JWT.sign(payload, _JWT_PRIVATE_KEY!, { algorithm: "RS256", ...options }); + return JWT.sign(payload, primaryPrivKey!, { algorithm: algo, ...options }); } export function verifyJWT(token: string) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - return JWT.verify(token, _JWT_PUBLIC_KEY!); + try { + return JWT.verify(token, primaryPubKey!); + } catch (err) { + if (secondaryPubKey) { + return JWT.verify(token, secondaryPubKey); + } else { + throw err; + } + } } diff --git a/bin/module-index/BUCK b/bin/module-index/BUCK index 3f25618d7e..02ea393f81 100644 --- a/bin/module-index/BUCK +++ b/bin/module-index/BUCK @@ -10,6 +10,7 @@ rust_binary( deps = [ "//lib/module-index-server:module-index-server", "//lib/si-std:si-std", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/telemetry-application-rs:telemetry-application", "//third-party/rust:clap", "//third-party/rust:color-eyre", diff --git a/bin/module-index/src/args.rs b/bin/module-index/src/args.rs index 6a60d6e6eb..57cd78e1b9 100644 --- a/bin/module-index/src/args.rs +++ b/bin/module-index/src/args.rs @@ -118,6 +118,15 @@ pub(crate) struct Args { /// The path to the JWT public signing key #[arg(long, env)] pub(crate) jwt_public_key: Option, + + #[arg(long, env)] + pub(crate) jwt_public_key_algo: Option, + + #[arg(long, env)] + pub(crate) jwt_secondary_public_key: Option, + + #[arg(long, env)] + pub(crate) jwt_secondary_public_key_algo: Option, // /// Database migration mode on startup // #[arg(long, value_parser = PossibleValuesParser::new(MigrationMode::variants()))] } diff --git a/bin/module-index/src/main.rs b/bin/module-index/src/main.rs index 0f4d1f7ff1..bdf00cae97 100644 --- a/bin/module-index/src/main.rs +++ b/bin/module-index/src/main.rs @@ -54,8 +54,7 @@ async fn async_main() -> Result<()> { let config = Config::try_from(args)?; - let jwt_public_signing_key = - Server::load_jwt_public_signing_key(config.jwt_signing_public_key_path()).await?; + let jwt_public_signing_key = Server::load_jwt_public_signing_key(&config).await?; // our pg pool works for migrations (refinery) but doesnt work for SeaORM :( // so we set up both connections for now... Would like to clean this up diff --git a/bin/sdf/src/args.rs b/bin/sdf/src/args.rs index 131e32901a..d9562de471 100644 --- a/bin/sdf/src/args.rs +++ b/bin/sdf/src/args.rs @@ -125,7 +125,19 @@ pub(crate) struct Args { /// jwt public signing key as a base64 string #[arg(long)] - pub(crate) jwt_public_signing_key_base64: Option, + pub(crate) jwt_public_signing_key_base64: Option, + + /// jwt public signing key algorithm (ES256 or RS256) + #[arg(long)] + pub(crate) jwt_public_signing_key_algo: Option, + + /// jwt secondary public signing key as a base64 string + #[arg(long)] + pub(crate) jwt_secondary_public_signing_key_base64: Option, + + /// jwt secondary public signing key algorithm (ES256 or RS256) + #[arg(long)] + pub(crate) jwt_secondary_public_signing_key_algo: Option, /// The path at which the layer db cache is created/used on disk [e.g. /banana/] #[arg(long)] @@ -338,9 +350,21 @@ impl TryFrom for Config { base64.to_string(), ); } + if let Some(jwt) = args.jwt_public_signing_key_base64 { - config_map.set("jwt_signing_public_key.key_base64", jwt.to_string()); + config_map.set("jwt_signing_public_key.key_base64", jwt); + } + if let Some(algo) = args.jwt_public_signing_key_algo { + config_map.set("jwt_signing_public_key.algo", algo); + } + + if let Some(jwt) = args.jwt_secondary_public_signing_key_base64 { + config_map.set("jwt_secondary_signing_public_key.key_base64", jwt); + } + if let Some(algo) = args.jwt_secondary_public_signing_key_algo { + config_map.set("jwt_secondary_signing_public_key.algo", algo); } + if let Some(layer_cache_disk_path) = args.layer_db_disk_path { config_map.set("layer_db_config.disk_path", layer_cache_disk_path); } diff --git a/component/init/configs/service.toml b/component/init/configs/service.toml index 93214277f1..2dc64d417c 100644 --- a/component/init/configs/service.toml +++ b/component/init/configs/service.toml @@ -18,6 +18,7 @@ runtime_strategy = "LocalFirecracker" [jwt_signing_public_key] key_base64 = "$SI_JWT_KEY_BASE64" +algo = "RS256" [nats] creds = """ diff --git a/lib/dal-test/BUCK b/lib/dal-test/BUCK index cddc24e614..dc76191df2 100644 --- a/lib/dal-test/BUCK +++ b/lib/dal-test/BUCK @@ -14,6 +14,7 @@ rust_library( "//lib/si-data-nats:si-data-nats", "//lib/si-data-pg:si-data-pg", "//lib/si-events-rs:si-events", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-layer-cache:si-layer-cache", "//lib/si-pkg:si-pkg", "//lib/si-runtime-rs:si-runtime", diff --git a/lib/dal-test/Cargo.toml b/lib/dal-test/Cargo.toml index b4ea484ad5..0ec5ea5c40 100644 --- a/lib/dal-test/Cargo.toml +++ b/lib/dal-test/Cargo.toml @@ -20,6 +20,7 @@ si-crypto = { path = "../../lib/si-crypto" } si-data-nats = { path = "../../lib/si-data-nats" } si-data-pg = { path = "../../lib/si-data-pg" } si-events = { path = "../../lib/si-events-rs" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-layer-cache = { path = "../../lib/si-layer-cache" } si-pkg = { path = "../../lib/si-pkg" } si-runtime = { path = "../../lib/si-runtime-rs" } @@ -50,5 +51,5 @@ tokio = { workspace = true } tokio-util = { workspace = true } tracing-opentelemetry = { workspace = true } tracing-subscriber = { workspace = true } -ulid = { workspace = true } +ulid = { workspace = true } uuid = { workspace = true } diff --git a/lib/dal-test/src/lib.rs b/lib/dal-test/src/lib.rs index a1338d116b..862ed09e23 100644 --- a/lib/dal-test/src/lib.rs +++ b/lib/dal-test/src/lib.rs @@ -30,6 +30,7 @@ use std::{ env, fmt, future::IntoFuture, path::{Path, PathBuf}, + str::FromStr, sync::{Arc, Once}, }; @@ -39,8 +40,7 @@ use dal::{ builtins::func, feature_flags::FeatureFlagService, job::processor::{JobQueueProcessor, NatsProcessor}, - DalContext, DalLayerDb, JetstreamStreams, JwtPublicSigningKey, ModelResult, ServicesContext, - Workspace, + DalContext, DalLayerDb, JetstreamStreams, ModelResult, ServicesContext, Workspace, }; use derive_builder::Builder; use jwt_simple::prelude::RS256KeyPair; @@ -51,9 +51,10 @@ use si_crypto::{ }; use si_data_nats::{jetstream, NatsClient, NatsConfig}; use si_data_pg::{PgPool, PgPoolConfig}; +use si_jwt_public_key::{JwtAlgo, JwtConfig, JwtPublicSigningKeyChain}; use si_layer_cache::hybrid_cache::CacheConfig; use si_runtime::DedicatedExecutor; -use si_std::ResultExt; +use si_std::{CanonicalFile, ResultExt}; use telemetry::prelude::*; use tokio::{fs::File, io::AsyncReadExt, sync::Mutex}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; @@ -170,6 +171,7 @@ pub struct Config { module_index_url: String, veritech_encryption_key_path: String, jwt_signing_public_key_path: String, + jwt_signing_public_key_algo: JwtAlgo, jwt_signing_private_key_path: String, postgres_key_path: String, #[builder(default)] @@ -619,13 +621,22 @@ pub fn random_identifier_string() -> String { } /// Returns a JWT public signing key, which is used to verify claims. -pub async fn jwt_public_signing_key() -> Result { - let jwt_signing_public_key_path = { +pub async fn jwt_public_signing_key() -> Result { + let jwt_config = { let context_builder = TEST_CONTEXT_BUILDER.lock().await; let config = context_builder.config()?; - config.jwt_signing_public_key_path.clone() + let key_file = Some(CanonicalFile::from_str( + &config.jwt_signing_public_key_path, + )?); + + JwtConfig { + key_file, + key_base64: None, + algo: config.jwt_signing_public_key_algo, + } }; - let key = JwtPublicSigningKey::load(&jwt_signing_public_key_path).await?; + + let key = JwtPublicSigningKeyChain::from_config(jwt_config, None).await?; Ok(key) } diff --git a/lib/dal/BUCK b/lib/dal/BUCK index 6ed5269456..23dc900301 100644 --- a/lib/dal/BUCK +++ b/lib/dal/BUCK @@ -24,6 +24,7 @@ rust_library( "//lib/si-frontend-types-rs:si-frontend-types", "//lib/si-hash:si-hash", "//lib/si-id:si-id", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-layer-cache:si-layer-cache", "//lib/si-pkg:si-pkg", "//lib/si-runtime-rs:si-runtime", @@ -122,7 +123,7 @@ rust_test( ], crate_root = "tests/integration.rs", srcs = glob([ - "tests/**/*.rs", + "tests/**/*.rs", "tests/integration_test/external/ignition/*.ign", ]), env = { diff --git a/lib/dal/Cargo.toml b/lib/dal/Cargo.toml index 3b0141d22d..ab975feb1e 100644 --- a/lib/dal/Cargo.toml +++ b/lib/dal/Cargo.toml @@ -25,6 +25,7 @@ si-data-pg = { path = "../../lib/si-data-pg" } si-events = { path = "../../lib/si-events-rs" } si-frontend-types = { path = "../../lib/si-frontend-types-rs" } si-hash = { path = "../../lib/si-hash" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-id = { path = "../../lib/si-id" } si-layer-cache = { path = "../../lib/si-layer-cache" } si-pkg = { path = "../../lib/si-pkg" } diff --git a/lib/dal/src/jwt_key.rs b/lib/dal/src/jwt_key.rs deleted file mode 100644 index e5e4f3ea68..0000000000 --- a/lib/dal/src/jwt_key.rs +++ /dev/null @@ -1,153 +0,0 @@ -use si_std::CanonicalFile; -use std::{io::Cursor, path::Path, pin::Pin, sync::Arc}; - -use base64::{engine::general_purpose, Engine}; -use jwt_simple::{ - algorithms::RS256PublicKey, - prelude::{JWTClaims, RSAPublicKeyLike}, -}; -use serde::{Deserialize, Serialize}; -use si_data_pg::{PgError, PgPoolError}; -use telemetry::prelude::*; -use thiserror::Error; -use tokio::{ - fs, - io::{AsyncRead, AsyncReadExt}, - task::JoinError, -}; - -use crate::{TransactionsError, UserClaim, UserPk, WorkspacePk}; - -#[remain::sorted] -#[derive(Error, Debug)] -pub enum JwtKeyError { - #[error("bad nonce bytes")] - BadNonce, - #[error("failed to decode base64 string: {0}")] - Base64Decode(#[from] base64::DecodeError), - #[error("invalid bearer token")] - BearerToken, - #[error("failed to decrypt secret data")] - Decrypt, - #[error("error creating jwt from config")] - FromConfig, - #[error("error generating new keypair")] - GenerateKeyPair, - #[error("io error: {0}")] - Io(#[from] std::io::Error), - #[error("failed to load jwt encryption key from bytes")] - JwtEncryptionKeyParse, - #[error("failure to build signing key from pem: {0}")] - KeyFromPem(String), - #[error("failure to extract metadata from bearer token: {0}")] - Metadata(String), - #[error("no signing keys - bad news for you!")] - NoKeys, - #[error("pg error: {0}")] - Pg(#[from] PgError), - #[error("pg pool error: {0}")] - PgPool(#[from] PgPoolError), - #[error("{0}")] - TaskJoin(#[from] JoinError), - #[error("failed to convert into PEM format")] - ToPem, - #[error("transactions error: {0}")] - Transactions(#[from] TransactionsError), - #[error(transparent)] - UlidDecode(#[from] ulid::DecodeError), - #[error("failed to build string from utf8: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - #[error("failure to verify token: {0}")] - Verify(String), -} - -pub type JwtKeyResult = Result; - -#[derive(Clone, Debug, Default, Deserialize, Serialize)] -pub struct JwtConfig { - pub key_file: Option, - pub key_base64: Option, -} - -#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] -pub struct SiClaims { - pub user_pk: UserPk, - pub workspace_pk: WorkspacePk, -} - -#[derive(Clone, Debug)] -pub struct JwtPublicSigningKey { - inner: Arc, -} - -impl JwtPublicSigningKey { - pub async fn from_config(config: JwtConfig) -> JwtKeyResult { - match (config.key_file, config.key_base64) { - (Some(path), None) => Self::load(path).await, - (None, Some(b64_string)) => Self::decode(b64_string).await, - _ => Err(JwtKeyError::FromConfig), - } - } - - #[instrument(level = "debug", skip_all)] - pub async fn load(path: impl AsRef) -> JwtKeyResult { - trace!( - path = path.as_ref().to_string_lossy().as_ref(), - "loading jwt public signing key" - ); - let mut file = fs::File::open(path).await?; - Self::from_reader(Pin::new(&mut file)).await - } - - #[instrument(level = "debug", skip_all)] - pub async fn decode(key_string: String) -> JwtKeyResult { - let buf = general_purpose::STANDARD.decode(key_string)?; - - Self::from_reader(Pin::new(&mut Cursor::new(&buf))).await - } - - async fn from_reader(mut reader: Pin<&mut impl AsyncRead>) -> JwtKeyResult { - let mut public_key_string = String::new(); - reader.read_to_string(&mut public_key_string).await?; - - let inner = tokio::task::spawn_blocking(move || { - RS256PublicKey::from_pem(&public_key_string) - .map_err(|err| JwtKeyError::KeyFromPem(format!("{err}"))) - }) - .instrument(trace_span!( - "from_pem", - code.namespace = "jwt_simple::algorithms::RS256PublicKey" - )) - .await??; - - Ok(Self { - inner: Arc::new(inner), - }) - } -} - -#[instrument(level = "debug", skip_all)] -pub async fn validate_bearer_token( - public_key: JwtPublicSigningKey, - bearer_token: impl AsRef, -) -> JwtKeyResult> { - let bearer_token = bearer_token.as_ref(); - let token = if let Some(token) = bearer_token.strip_prefix("Bearer ") { - token.to_string() - } else { - return Err(JwtKeyError::BearerToken); - }; - - let claims = tokio::task::spawn_blocking(move || { - public_key - .inner - .verify_token::(&token, None) - .map_err(|err| JwtKeyError::Verify(format!("{err}"))) - }) - .instrument(trace_span!( - "verfy_token", - code.namespace = "jwt_simple::algorithms::RSAPublicKeyLike" - )) - .await??; - Ok(claims) -} diff --git a/lib/dal/src/lib.rs b/lib/dal/src/lib.rs index dd3e37cec2..d206e6f703 100644 --- a/lib/dal/src/lib.rs +++ b/lib/dal/src/lib.rs @@ -40,7 +40,6 @@ pub mod history_event; pub mod input_sources; pub mod jetstream_streams; pub mod job; -pub mod jwt_key; pub mod key_pair; pub mod label_list; pub mod layer_db_types; @@ -94,7 +93,6 @@ pub use func::{ pub use history_event::{HistoryActor, HistoryEvent, HistoryEventError}; pub use jetstream_streams::{JetstreamStreams, JetstreamStreamsError}; pub use job::processor::{JobQueueProcessor, NatsProcessor}; -pub use jwt_key::JwtPublicSigningKey; pub use key_pair::{KeyPair, KeyPairError, KeyPairResult, PublicKey}; pub use label_list::{LabelEntry, LabelList, LabelListError}; pub use prop::{Prop, PropId, PropKind}; diff --git a/lib/dal/src/user.rs b/lib/dal/src/user.rs index 793b70d656..da2c99c689 100644 --- a/lib/dal/src/user.rs +++ b/lib/dal/src/user.rs @@ -2,14 +2,17 @@ use serde::{Deserialize, Serialize}; use si_data_nats::NatsError; use si_data_pg::PgError; use si_events::ViewId; +use si_jwt_public_key::{ + validate_bearer_token, JwtPublicSigningKeyChain, JwtPublicSigningKeyError, SiJwtClaims, +}; use telemetry::prelude::*; use thiserror::Error; use tokio::task::JoinError; use crate::ws_event::{WsEvent, WsEventResult, WsPayload}; use crate::{ - jwt_key::JwtKeyError, standard_model_accessor_ro, ChangeSetId, DalContext, HistoryEvent, - HistoryEventError, JwtPublicSigningKey, Tenancy, Timestamp, TransactionsError, WorkspacePk, + standard_model_accessor_ro, ChangeSetId, DalContext, HistoryEvent, HistoryEventError, Tenancy, + Timestamp, TransactionsError, WorkspacePk, }; const USER_GET_BY_PK: &str = include_str!("queries/user/get_by_pk.sql"); @@ -23,7 +26,7 @@ pub enum UserError { #[error("failed to join long lived async task; bug!")] Join(#[from] JoinError), #[error(transparent)] - JwtKey(#[from] JwtKeyError), + JwtKey(#[from] JwtPublicSigningKeyError), #[error("nats txn error: {0}")] Nats(#[from] NatsError), #[error("user not found in tenancy: {0} {1:?}")] @@ -212,6 +215,15 @@ pub struct UserClaim { pub workspace_pk: WorkspacePk, } +impl From for UserClaim { + fn from(value: SiJwtClaims) -> Self { + Self { + user_pk: value.user_pk.into_raw_id().into(), + workspace_pk: value.workspace_pk.into_raw_id().into(), + } + } +} + impl UserClaim { pub fn new(user_pk: UserPk, workspace_pk: WorkspacePk) -> Self { UserClaim { @@ -221,11 +233,11 @@ impl UserClaim { } pub async fn from_bearer_token( - public_key: JwtPublicSigningKey, + public_key: JwtPublicSigningKeyChain, token: impl AsRef, ) -> UserResult { - let claims = crate::jwt_key::validate_bearer_token(public_key, &token).await?; - Ok(claims.custom) + let claims = validate_bearer_token(public_key, token).await?; + Ok(claims.custom.into()) } } diff --git a/lib/module-index-server/BUCK b/lib/module-index-server/BUCK index 6cbbb2b794..ea2c99e0f0 100644 --- a/lib/module-index-server/BUCK +++ b/lib/module-index-server/BUCK @@ -9,6 +9,7 @@ rust_library( "//lib/si-data-pg:si-data-pg", "//lib/si-pkg:si-pkg", "//lib/si-hash:si-hash", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-posthog-rs:si-posthog", "//lib/si-settings:si-settings", "//lib/si-std:si-std", diff --git a/lib/module-index-server/Cargo.toml b/lib/module-index-server/Cargo.toml index b91ab3d00f..4d58c3931b 100644 --- a/lib/module-index-server/Cargo.toml +++ b/lib/module-index-server/Cargo.toml @@ -14,6 +14,7 @@ buck2-resources = { path = "../../lib/buck2-resources" } module-index-types = { path = "../../lib/module-index-types" } si-data-pg = { path = "../../lib/si-data-pg" } si-hash = { path = "../../lib/si-hash" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-pkg = { path = "../../lib/si-pkg" } si-posthog = { path = "../../lib/si-posthog-rs" } si-settings = { path = "../../lib/si-settings" } diff --git a/lib/module-index-server/src/app_state.rs b/lib/module-index-server/src/app_state.rs index 53aeb60262..fc7a0b4ac6 100644 --- a/lib/module-index-server/src/app_state.rs +++ b/lib/module-index-server/src/app_state.rs @@ -4,11 +4,12 @@ use std::sync::Arc; use axum::extract::FromRef; use s3::creds::Credentials as AwsCredentials; use sea_orm::DatabaseConnection; +use si_jwt_public_key::JwtPublicSigningKeyChain; pub use si_posthog::PosthogClient; use tokio::sync::{mpsc, Mutex}; -use crate::{jwt_key::JwtPublicSigningKey, s3::S3Config}; +use crate::s3::S3Config; #[remain::sorted] #[derive(Debug, Eq, PartialEq)] @@ -18,7 +19,7 @@ pub enum ShutdownSource {} pub struct AppState { /// A PostgreSQL connection pool. pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, aws_creds: AwsCredentials, s3_config: S3Config, @@ -34,7 +35,7 @@ impl AppState { #[allow(clippy::too_many_arguments)] pub fn new( pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, aws_creds: AwsCredentials, s3_config: S3Config, @@ -42,7 +43,7 @@ impl AppState { ) -> Self { Self { pg_pool, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, aws_creds, s3_config, @@ -57,8 +58,8 @@ impl AppState { } /// Gets a reference to the public key used to sign the JWT - pub fn jwt_public_signing_key(&self) -> &JwtPublicSigningKey { - &self.jwt_public_signing_key + pub fn jwt_public_signing_key(&self) -> &JwtPublicSigningKeyChain { + &self.jwt_public_signing_key_chain } /// Gets a reference to the Posthog client. diff --git a/lib/module-index-server/src/config.rs b/lib/module-index-server/src/config.rs index f6cc1cd53a..f228aab15b 100644 --- a/lib/module-index-server/src/config.rs +++ b/lib/module-index-server/src/config.rs @@ -4,6 +4,7 @@ use buck2_resources::Buck2Resources; use derive_builder::Builder; use serde::{Deserialize, Serialize}; use si_data_pg::PgPoolConfig; +use si_jwt_public_key::JwtAlgo; use si_posthog::PosthogConfig; use si_std::{CanonicalFile, CanonicalFileError}; use telemetry::prelude::*; @@ -51,6 +52,12 @@ pub struct Config { instance_id: String, jwt_signing_public_key_path: CanonicalFile, + jwt_signing_public_key_algo: JwtAlgo, + + #[builder(default)] + jwt_secondary_signing_public_key_path: Option, + #[builder(default)] + jwt_secondary_signing_public_key_algo: Option, #[builder(default = "PosthogConfig::default()")] posthog: PosthogConfig, @@ -81,8 +88,25 @@ impl Config { /// Gets a reference to the config's jwt signing public key path. #[must_use] - pub fn jwt_signing_public_key_path(&self) -> &Path { - self.jwt_signing_public_key_path.as_path() + pub fn jwt_signing_public_key_path(&self) -> &CanonicalFile { + &self.jwt_signing_public_key_path + } + + /// Gets a reference to the config's jwt signing public key path. + #[must_use] + pub fn jwt_signing_public_key_algo(&self) -> JwtAlgo { + self.jwt_signing_public_key_algo + } + + /// Gets a reference to the config's jwt secondary signing public key path. + #[must_use] + pub fn jwt_secondary_signing_public_key_path(&self) -> Option<&CanonicalFile> { + self.jwt_secondary_signing_public_key_path.as_ref() + } + + #[must_use] + pub fn jwt_secondary_signing_public_key_algo(&self) -> Option { + self.jwt_secondary_signing_public_key_algo } /// Gets a reference to the config's posthog config. @@ -108,6 +132,12 @@ pub struct ConfigFile { instance_id: String, #[serde(default = "default_jwt_signing_public_key_path")] pub jwt_signing_public_key_path: String, + #[serde(default = "default_jwt_signing_public_key_algo")] + pub jwt_signing_public_key_algo: JwtAlgo, + #[serde(default)] + pub jwt_secondary_signing_public_key_path: Option, + #[serde(default)] + pub jwt_secondary_signing_public_key_algo: Option, #[serde(default)] pub posthog: PosthogConfig, #[serde(default)] @@ -127,6 +157,9 @@ impl Default for ConfigFile { socket_addr: get_default_socket_addr(), instance_id: random_instance_id(), jwt_signing_public_key_path: default_jwt_signing_public_key_path(), + jwt_signing_public_key_algo: default_jwt_signing_public_key_algo(), + jwt_secondary_signing_public_key_path: None, + jwt_secondary_signing_public_key_algo: None, posthog: Default::default(), s3: Default::default(), } @@ -161,6 +194,10 @@ fn default_jwt_signing_public_key_path() -> String { "/run/sdf/jwt_signing_public_key.pem".to_string() } +fn default_jwt_signing_public_key_algo() -> JwtAlgo { + JwtAlgo::RS256 +} + #[allow(clippy::disallowed_methods)] // Used to determine if running in development pub fn detect_and_configure_development(config: &mut ConfigFile) -> Result<()> { if env::var("BUCK_RUN_BUILD_ID").is_ok() || env::var("BUCK_BUILD_ID").is_ok() { diff --git a/lib/module-index-server/src/extract.rs b/lib/module-index-server/src/extract.rs index 5e753e6e93..0cc766f2e1 100644 --- a/lib/module-index-server/src/extract.rs +++ b/lib/module-index-server/src/extract.rs @@ -4,12 +4,9 @@ use axum::{async_trait, extract::FromRequestParts, http::request::Parts, Json}; use hyper::StatusCode; use s3::{Bucket as S3Bucket, Region as AwsRegion}; use sea_orm::{DatabaseTransaction, TransactionTrait}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use ulid::Ulid; +use si_jwt_public_key::SiJwtClaims; use super::app_state::AppState; -use crate::jwt_key::{JwtKeyError, JwtPublicSigningKey}; pub struct ExtractedS3Bucket(pub S3Bucket); @@ -80,36 +77,8 @@ impl FromRequestParts for DbConnection { } } -pub type UserPk = Ulid; -pub type WorkspacePk = Ulid; - -#[derive(Deserialize, Serialize, Debug, Clone, Copy)] -pub struct UserClaim { - pub user_pk: UserPk, - pub workspace_pk: WorkspacePk, -} - -#[remain::sorted] -#[derive(Error, Debug)] -pub enum AuthError { - #[error(transparent)] - JwtKey(#[from] JwtKeyError), -} - -pub type AuthResult = Result; - -impl UserClaim { - pub async fn from_bearer_token( - public_key: JwtPublicSigningKey, - token: impl AsRef, - ) -> AuthResult { - let claims = crate::jwt_key::validate_bearer_token(public_key, &token).await?; - Ok(claims.custom) - } -} - pub struct Authorization { - pub user_claim: UserClaim, + pub user_claim: SiJwtClaims, pub auth_token: String, } @@ -121,7 +90,7 @@ impl FromRequestParts for Authorization { parts: &mut Parts, state: &AppState, ) -> Result { - let jwt_public_signing_key = state.jwt_public_signing_key().clone(); + let jwt_public_signing_key = state.jwt_public_signing_key(); let headers = &parts.headers; let authorization_header_value = headers @@ -130,9 +99,11 @@ impl FromRequestParts for Authorization { let auth_token = authorization_header_value .to_str() .map_err(internal_error)?; - let user_claim = UserClaim::from_bearer_token(jwt_public_signing_key, auth_token) - .await - .map_err(|_| unauthorized_error())?; + let user_claim = + si_jwt_public_key::validate_bearer_token(jwt_public_signing_key.clone(), &auth_token) + .await + .map_err(|_| unauthorized_error())? + .custom; Ok(Self { user_claim, diff --git a/lib/module-index-server/src/jwt_key.rs b/lib/module-index-server/src/jwt_key.rs deleted file mode 100644 index d2816136de..0000000000 --- a/lib/module-index-server/src/jwt_key.rs +++ /dev/null @@ -1,110 +0,0 @@ -use jwt_simple::prelude::{JWTClaims, RS256PublicKey, RSAPublicKeyLike}; -use std::{path::Path, pin::Pin, sync::Arc}; -use telemetry::prelude::*; -use thiserror::Error; - -use tokio::{ - fs, - io::{AsyncRead, AsyncReadExt}, - task::JoinError, -}; - -use crate::extract::UserClaim; - -#[derive(Clone, Debug)] -pub struct JwtPublicSigningKey { - inner: Arc, -} -impl JwtPublicSigningKey { - #[instrument(level = "debug", skip_all)] - pub async fn load(path: impl AsRef) -> JwtKeyResult { - trace!( - path = path.as_ref().to_string_lossy().as_ref(), - "loading jwt public signing key" - ); - let mut file = fs::File::open(path).await?; - Self::from_reader(Pin::new(&mut file)).await - } - - async fn from_reader(mut reader: Pin<&mut impl AsyncRead>) -> JwtKeyResult { - let mut public_key_string = String::new(); - reader.read_to_string(&mut public_key_string).await?; - - let inner = tokio::task::spawn_blocking(move || { - RS256PublicKey::from_pem(&public_key_string) - .map_err(|err| JwtKeyError::KeyFromPem(format!("{err}"))) - }) - .instrument(trace_span!( - "from_pem", - code.namespace = "jwt_simple::algorithms::RS256PublicKey" - )) - .await??; - - Ok(Self { - inner: Arc::new(inner), - }) - } -} - -pub type JwtKeyResult = Result; - -#[remain::sorted] -#[derive(Error, Debug)] -pub enum JwtKeyError { - #[error("bad nonce bytes")] - BadNonce, - #[error("failed to decode base64 string: {0}")] - Base64Decode(#[from] base64::DecodeError), - #[error("invalid bearer token")] - BearerToken, - #[error("failed to decrypt secret data")] - Decrypt, - #[error("error generating new keypair")] - GenerateKeyPair, - #[error("io error: {0}")] - Io(#[from] std::io::Error), - #[error("failed to load jwt encryption key from bytes")] - JwtEncryptionKeyParse, - #[error("failure to build signing key from pem: {0}")] - KeyFromPem(String), - #[error("failure to extract metadata from bearer token: {0}")] - Metadata(String), - #[error("no signing keys - bad news for you!")] - NoKeys, - #[error("{0}")] - TaskJoin(#[from] JoinError), - #[error("failed to convert into PEM format")] - ToPem, - #[error(transparent)] - UlidDecode(#[from] ulid::DecodeError), - #[error("failed to build string from utf8: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - #[error("failure to verify token: {0}")] - Verify(String), -} - -#[instrument(level = "debug", skip_all)] -pub async fn validate_bearer_token( - public_key: JwtPublicSigningKey, - bearer_token: impl AsRef, -) -> JwtKeyResult> { - let bearer_token = bearer_token.as_ref(); - let token = if let Some(token) = bearer_token.strip_prefix("Bearer ") { - token.to_string() - } else { - return Err(JwtKeyError::BearerToken); - }; - - let claims = tokio::task::spawn_blocking(move || { - public_key - .inner - .verify_token::(&token, None) - .map_err(|err| JwtKeyError::Verify(format!("{err}"))) - }) - .instrument(trace_span!( - "verfy_token", - code.namespace = "jwt_simple::algorithms::RSAPublicKeyLike" - )) - .await??; - Ok(claims) -} diff --git a/lib/module-index-server/src/lib.rs b/lib/module-index-server/src/lib.rs index 451d5429b5..a6077ab43a 100644 --- a/lib/module-index-server/src/lib.rs +++ b/lib/module-index-server/src/lib.rs @@ -1,7 +1,6 @@ mod app_state; mod config; mod extract; -mod jwt_key; mod models; mod routes; mod s3; diff --git a/lib/module-index-server/src/server.rs b/lib/module-index-server/src/server.rs index 8e208b029c..e2d357d913 100644 --- a/lib/module-index-server/src/server.rs +++ b/lib/module-index-server/src/server.rs @@ -1,4 +1,4 @@ -use std::{io, net::SocketAddr, path::Path, time::Duration}; +use std::{io, net::SocketAddr, time::Duration}; use super::routes; @@ -8,6 +8,7 @@ use hyper::server::{accept::Accept, conn::AddrIncoming}; use s3::creds::{error::CredentialsError, Credentials as AwsCredentials}; use sea_orm::{ConnectOptions, Database, DatabaseConnection, DbErr}; use si_data_pg::{PgPool, PgPoolConfig, PgPoolError}; +use si_jwt_public_key::{JwtConfig, JwtPublicSigningKeyChain, JwtPublicSigningKeyError}; use si_posthog::{PosthogClient, PosthogConfig}; use telemetry::prelude::*; use thiserror::Error; @@ -21,7 +22,6 @@ use tower_http::trace::{DefaultMakeSpan, TraceLayer}; use crate::{ app_state::{AppState, ShutdownSource}, - jwt_key::{JwtKeyError, JwtPublicSigningKey}, s3::S3Config, Config, }; @@ -43,8 +43,8 @@ pub enum ServerError { DbErr(#[from] DbErr), #[error("hyper server error")] Hyper(#[from] hyper::Error), - #[error("jwt secret key error")] - JwtSecretKey(#[from] JwtKeyError), + #[error("jwt public key error: {0}")] + JwtPublicKey(#[from] JwtPublicSigningKeyError), #[error(transparent)] PgPool(#[from] Box), #[error(transparent)] @@ -61,7 +61,7 @@ impl From for ServerError { } } -type Result = std::result::Result; +type ServerResult = std::result::Result; pub struct Server { config: Config, @@ -75,9 +75,9 @@ impl Server<(), ()> { pub fn http( config: Config, pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key: JwtPublicSigningKeyChain, posthog_client: PosthogClient, - ) -> Result<(Server, broadcast::Receiver<()>)> { + ) -> ServerResult<(Server, broadcast::Receiver<()>)> { // socket_addr // try to load aws creds from a few different places @@ -141,7 +141,7 @@ impl Server<(), ()> { // this creates our si_data_pg::PgPool, which wont work with SeaORM #[instrument(name = "module-index.init.create_pg_pool", level = "info", skip_all)] - pub async fn create_pg_pool(pg_pool_config: &PgPoolConfig) -> Result { + pub async fn create_pg_pool(pg_pool_config: &PgPoolConfig) -> ServerResult { let pool = PgPool::new(pg_pool_config).await?; debug!("successfully started pg pool (note that not all connections may be healthy)"); Ok(pool) @@ -153,7 +153,9 @@ impl Server<(), ()> { level = "info", skip_all )] - pub async fn create_db_connection(pg_pool_config: &PgPoolConfig) -> Result { + pub async fn create_db_connection( + pg_pool_config: &PgPoolConfig, + ) -> ServerResult { let mut opt = ConnectOptions::new(format!( "{protocol}://{username}:{password}@{host}:{port}/{database}", protocol = "postgres", @@ -181,7 +183,7 @@ impl Server<(), ()> { Ok(db) } - pub async fn run_migrations(pg_pool: &PgPool) -> Result<()> { + pub async fn run_migrations(pg_pool: &PgPool) -> ServerResult<()> { Ok(pg_pool .migrate(embedded_migrations::migrations::runner()) .await?) @@ -193,12 +195,27 @@ impl Server<(), ()> { skip_all )] pub async fn load_jwt_public_signing_key( - path: impl AsRef, - ) -> Result { - Ok(JwtPublicSigningKey::load(path).await?) + config: &Config, + ) -> ServerResult { + let primary = JwtConfig { + key_file: Some(config.jwt_signing_public_key_path().to_owned()), + key_base64: None, + algo: config.jwt_signing_public_key_algo(), + }; + + let secondary = config + .jwt_secondary_signing_public_key_path() + .zip(config.jwt_secondary_signing_public_key_algo()) + .map(|(path, algo)| JwtConfig { + key_file: Some(path.to_owned()), + key_base64: None, + algo, + }); + + Ok(JwtPublicSigningKeyChain::from_config(primary, secondary).await?) } - pub async fn start_posthog(config: &PosthogConfig) -> Result { + pub async fn start_posthog(config: &PosthogConfig) -> ServerResult { // TODO(fnichol): this should be threaded through let token = CancellationToken::new(); @@ -216,7 +233,7 @@ where IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, IE: Into>, { - pub async fn run(self) -> Result<()> { + pub async fn run(self) -> ServerResult<()> { let shutdown_rx = self.shutdown_rx; self.inner @@ -240,17 +257,17 @@ where pub fn build_service( pg_pool: DatabaseConnection, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, aws_creds: AwsCredentials, s3_config: S3Config, -) -> Result<(Router, oneshot::Receiver<()>, broadcast::Receiver<()>)> { +) -> ServerResult<(Router, oneshot::Receiver<()>, broadcast::Receiver<()>)> { let (shutdown_tx, shutdown_rx) = mpsc::channel(1); let (shutdown_broadcast_tx, shutdown_broadcast_rx) = broadcast::channel(1); let state = AppState::new( pg_pool, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, aws_creds, s3_config, @@ -273,7 +290,7 @@ pub fn build_service( fn prepare_graceful_shutdown( mut shutdown_rx: mpsc::Receiver, shutdown_broadcast_tx: broadcast::Sender<()>, -) -> Result> { +) -> ServerResult> { let (graceful_shutdown_tx, graceful_shutdown_rx) = oneshot::channel::<()>(); let mut sigterm_watcher = signal::unix::signal(signal::unix::SignalKind::terminate()).map_err(ServerError::Signal)?; diff --git a/lib/sdf-server/BUCK b/lib/sdf-server/BUCK index 5f58ec72ca..9f807d5554 100644 --- a/lib/sdf-server/BUCK +++ b/lib/sdf-server/BUCK @@ -19,6 +19,7 @@ rust_library( "//lib/si-data-spicedb:si-data-spicedb", "//lib/si-events-rs:si-events", "//lib/si-frontend-types-rs:si-frontend-types", + "//lib/si-jwt-public-key:si-jwt-public-key", "//lib/si-layer-cache:si-layer-cache", "//lib/si-pkg:si-pkg", "//lib/si-posthog-rs:si-posthog", diff --git a/lib/sdf-server/Cargo.toml b/lib/sdf-server/Cargo.toml index ac66a9b653..afd7e8c1e1 100644 --- a/lib/sdf-server/Cargo.toml +++ b/lib/sdf-server/Cargo.toml @@ -25,6 +25,7 @@ si-data-pg = { path = "../../lib/si-data-pg" } si-data-spicedb = { path = "../../lib/si-data-spicedb" } si-events = { path = "../../lib/si-events-rs" } si-frontend-types = { path = "../../lib/si-frontend-types-rs" } +si-jwt-public-key = { path = "../../lib/si-jwt-public-key" } si-layer-cache = { path = "../../lib/si-layer-cache" } si-pkg = { path = "../../lib/si-pkg" } si-posthog = { path = "../../lib/si-posthog-rs" } diff --git a/lib/sdf-server/src/app.rs b/lib/sdf-server/src/app.rs index 293cf54470..5ffeed835a 100644 --- a/lib/sdf-server/src/app.rs +++ b/lib/sdf-server/src/app.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use asset_sprayer::AssetSprayer; use audit_database::AuditDatabaseContext; use axum::Router; -use dal::{JwtPublicSigningKey, ServicesContext}; +use dal::ServicesContext; use nats_multiplexer_client::MultiplexerClient; use si_data_spicedb::SpiceDbClient; +use si_jwt_public_key::JwtPublicSigningKeyChain; use si_posthog::PosthogClient; use telemetry::prelude::*; use tokio::sync::RwLock; @@ -25,7 +26,7 @@ impl AxumApp { #[inline] pub fn from_services( services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -40,7 +41,7 @@ impl AxumApp { ) -> Self { Self::inner_from_services( services_context, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, @@ -65,7 +66,7 @@ impl AxumApp { #[inline] pub fn from_services_for_tests( services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -80,7 +81,7 @@ impl AxumApp { ) -> Self { Self::inner_from_services( services_context, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, @@ -103,7 +104,7 @@ impl AxumApp { #[allow(clippy::too_many_arguments)] fn inner_from_services( services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -119,7 +120,7 @@ impl AxumApp { ) -> Self { let state = AppState::new( services_context, - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, diff --git a/lib/sdf-server/src/app_state.rs b/lib/sdf-server/src/app_state.rs index 4cc29053e5..c731cad42c 100644 --- a/lib/sdf-server/src/app_state.rs +++ b/lib/sdf-server/src/app_state.rs @@ -3,9 +3,9 @@ use std::{ops::Deref, sync::Arc}; use asset_sprayer::AssetSprayer; use audit_database::AuditDatabaseContext; use axum::extract::FromRef; -use dal::JwtPublicSigningKey; use nats_multiplexer_client::MultiplexerClient; use si_data_spicedb::SpiceDbClient; +use si_jwt_public_key::JwtPublicSigningKeyChain; use std::fmt; use tokio::sync::{Mutex, RwLock}; use tokio_util::sync::CancellationToken; @@ -26,7 +26,7 @@ pub enum ApplicationRuntimeMode { pub struct AppState { services_context: ServicesContext, broadcast_groups: BroadcastGroups, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: String, // TODO(victor) store the auth client on state instead of just the URL asset_sprayer: Option, @@ -44,7 +44,7 @@ impl AppState { #[allow(clippy::too_many_arguments)] pub fn new( services_context: impl Into, - jwt_public_signing_key: impl Into, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: impl Into, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -65,7 +65,7 @@ impl AppState { Self { services_context: services_context.into(), - jwt_public_signing_key: jwt_public_signing_key.into(), + jwt_public_signing_key_chain, broadcast_groups: Default::default(), posthog_client: posthog_client.into(), auth_api_url: auth_api_url.as_ref().to_string(), @@ -97,8 +97,8 @@ impl AppState { self.asset_sprayer.as_ref() } - pub fn jwt_public_signing_key(&self) -> &JwtPublicSigningKey { - &self.jwt_public_signing_key + pub fn jwt_public_signing_key_chain(&self) -> &JwtPublicSigningKeyChain { + &self.jwt_public_signing_key_chain } pub fn for_tests(&self) -> bool { diff --git a/lib/sdf-server/src/config.rs b/lib/sdf-server/src/config.rs index 2c00ad6243..1f68fcd110 100644 --- a/lib/sdf-server/src/config.rs +++ b/lib/sdf-server/src/config.rs @@ -1,9 +1,9 @@ use asset_sprayer::config::{AssetSprayerConfig, SIOpenAIConfig}; use audit_database::AuditDatabaseConfig; -use dal::jwt_key::JwtConfig; use serde_with::{DeserializeFromStr, SerializeDisplay}; use si_crypto::VeritechCryptoConfig; use si_data_spicedb::SpiceDbConfig; +use si_jwt_public_key::{JwtAlgo, JwtConfig}; use si_layer_cache::{db::LayerDbConfig, error::LayerDbError}; use std::collections::HashSet; use std::{ @@ -130,6 +130,9 @@ pub struct Config { #[builder(default)] jwt_signing_public_key: JwtConfig, + #[builder(default)] + jwt_secondary_signing_public_key: Option, + #[builder(default = "default_layer_db_config()")] layer_db_config: LayerDbConfig, @@ -191,6 +194,10 @@ impl Config { &self.jwt_signing_public_key } + pub fn jwt_secondary_signing_public_key(&self) -> Option<&JwtConfig> { + self.jwt_secondary_signing_public_key.as_ref() + } + /// Gets a reference to the config's cyclone public key path. #[must_use] pub fn crypto(&self) -> &VeritechCryptoConfig { @@ -298,6 +305,8 @@ pub struct ConfigFile { #[serde(default)] pub jwt_signing_public_key: JwtConfig, #[serde(default)] + pub jwt_secondary_signing_public_key: Option, + #[serde(default)] pub crypto: VeritechCryptoConfig, #[serde(default = "default_pkgs_path")] pub pkgs_path: String, @@ -335,6 +344,7 @@ impl Default for ConfigFile { nats: Default::default(), migration_mode: Default::default(), jwt_signing_public_key: Default::default(), + jwt_secondary_signing_public_key: Default::default(), crypto: Default::default(), pkgs_path: default_pkgs_path(), posthog: Default::default(), @@ -371,6 +381,7 @@ impl TryFrom for Config { incoming_stream: IncomingStream::default(), // TODO this OK? migration_mode: value.migration_mode, jwt_signing_public_key: value.jwt_signing_public_key, + jwt_secondary_signing_public_key: value.jwt_secondary_signing_public_key, crypto: value.crypto, pkgs_path: value.pkgs_path.try_into()?, posthog: value.posthog, @@ -463,7 +474,7 @@ fn buck2_development(config: &mut ConfigFile) -> Result<()> { #[allow(clippy::disallowed_methods)] // Used in development with a local auth services // Note(victor): If the user has set a custom auth ip url via env variable we assume dev mode - let jwt_signing_public_key_path = if env::var("SI_AUTH_API_URL").is_ok() { + let jwt_primary_signing_public_key_path = if env::var("SI_AUTH_API_URL").is_ok() { resources .get_ends_with("dev.jwt_signing_public_key.pem") .map_err(ConfigError::development)? @@ -498,7 +509,7 @@ fn buck2_development(config: &mut ConfigFile) -> Result<()> { .to_string(); warn!( - jwt_signing_public_key_path = jwt_signing_public_key_path.as_str(), + jwt_signing_public_key_path = jwt_primary_signing_public_key_path.as_str(), veritech_encryption_key_path = veritech_encryption_key_path.as_str(), symmetric_crypto_service_key = symmetric_crypto_service_key.as_str(), postgres_cert = postgres_cert.as_str(), @@ -507,8 +518,9 @@ fn buck2_development(config: &mut ConfigFile) -> Result<()> { ); config.jwt_signing_public_key = JwtConfig { - key_file: Some(jwt_signing_public_key_path.try_into()?), + key_file: Some(jwt_primary_signing_public_key_path.try_into()?), key_base64: None, + algo: JwtAlgo::RS256, }; config.crypto.encryption_key_file = veritech_encryption_key_path.parse().ok(); config.symmetric_crypto_service = SymmetricCryptoServiceConfigFile { @@ -572,6 +584,7 @@ fn cargo_development(dir: String, config: &mut ConfigFile) -> Result<()> { config.jwt_signing_public_key = JwtConfig { key_file: Some(jwt_signing_public_key_path.try_into()?), key_base64: None, + algo: JwtAlgo::RS256, }; config.crypto.encryption_key_file = veritech_encryption_key_path.parse().ok(); config.symmetric_crypto_service = SymmetricCryptoServiceConfigFile { diff --git a/lib/sdf-server/src/extract.rs b/lib/sdf-server/src/extract.rs index f387b0ab91..66e00c105f 100644 --- a/lib/sdf-server/src/extract.rs +++ b/lib/sdf-server/src/extract.rs @@ -211,7 +211,7 @@ impl FromRequestParts for Authorization { let HandlerContext(builder) = HandlerContext::from_request_parts(parts, state).await?; let mut ctx = builder.build_default().await.map_err(internal_error)?; - let jwt_public_signing_key = state.jwt_public_signing_key().clone(); + let jwt_public_signing_key = state.jwt_public_signing_key_chain().clone(); let headers = &parts.headers; let authorization_header_value = headers @@ -251,7 +251,7 @@ impl FromRequestParts for WsAuthorization { ) -> Result { let HandlerContext(builder) = HandlerContext::from_request_parts(parts, state).await?; let mut ctx = builder.build_default().await.map_err(internal_error)?; - let jwt_public_signing_key = state.jwt_public_signing_key().clone(); + let jwt_public_signing_key = state.jwt_public_signing_key_chain().clone(); let query: Query> = Query::from_request_parts(parts, state) .await diff --git a/lib/sdf-server/src/init.rs b/lib/sdf-server/src/init.rs index 442298fd77..81555258ee 100644 --- a/lib/sdf-server/src/init.rs +++ b/lib/sdf-server/src/init.rs @@ -1,8 +1,8 @@ use std::{path::PathBuf, sync::Arc}; use dal::{ - feature_flags::FeatureFlagService, jwt_key::JwtConfig, DalLayerDb, DedicatedExecutor, - JetstreamStreams, JobQueueProcessor, JwtPublicSigningKey, NatsProcessor, ServicesContext, + feature_flags::FeatureFlagService, DalLayerDb, DedicatedExecutor, JetstreamStreams, + JobQueueProcessor, NatsProcessor, ServicesContext, }; use rebaser_client::RebaserClient; use si_crypto::{ @@ -11,6 +11,7 @@ use si_crypto::{ }; use si_data_nats::{NatsClient, NatsConfig}; use si_data_pg::{PgPool, PgPoolConfig}; +use si_jwt_public_key::{JwtConfig, JwtPublicSigningKeyChain, JwtPublicSigningKeyError}; use si_layer_cache::{ db::{LayerDbConfig, LayerDbGracefulShutdown}, LayerDb, @@ -32,7 +33,7 @@ pub enum InitError { #[error("failed to initialize a dal jetstream streams: {0}")] DalJetstreamStreams(#[source] dal::JetstreamStreamsError), #[error("jwt key error")] - JwtKey(#[from] dal::jwt_key::JwtKeyError), + JwtKey(#[from] JwtPublicSigningKeyError), #[error("layer cache error: {0}")] LayerCache(#[from] si_layer_cache::LayerDbError), #[error("failed to initialize a nats client: {0}")] @@ -195,9 +196,10 @@ pub(crate) async fn initialize_layer_db( skip_all )] pub(crate) async fn load_jwt_public_signing_key( - config: JwtConfig, -) -> InitResult { - Ok(JwtPublicSigningKey::from_config(config).await?) + primary: JwtConfig, + secondary: Option, +) -> InitResult { + Ok(JwtPublicSigningKeyChain::from_config(primary, secondary).await?) } pub(crate) fn initialize_posthog( diff --git a/lib/sdf-server/src/server.rs b/lib/sdf-server/src/server.rs index 7cbd92667e..8c1e8d0f9a 100644 --- a/lib/sdf-server/src/server.rs +++ b/lib/sdf-server/src/server.rs @@ -3,11 +3,12 @@ use std::{fmt, future::IntoFuture as _, net::SocketAddr, path::PathBuf, sync::Ar use asset_sprayer::AssetSprayer; use audit_database::AuditDatabaseContext; use axum::{async_trait, routing::IntoMakeService, Router}; -use dal::{JwtPublicSigningKey, ServicesContext}; +use dal::ServicesContext; use hyper::server::accept::Accept; use nats_multiplexer::Multiplexer; use nats_multiplexer_client::MultiplexerClient; use si_data_spicedb::SpiceDbClient; +use si_jwt_public_key::JwtPublicSigningKeyChain; use si_posthog::PosthogClient; use telemetry::prelude::*; use tokio::{ @@ -83,8 +84,11 @@ impl Server { let (services_context, layer_db_graceful_shutdown) = init::services_context_from_config(&config, helping_tasks_token.clone()).await?; - let jwt_public_signing_key = - init::load_jwt_public_signing_key(config.jwt_signing_public_key().clone()).await?; + let jwt_public_signing_key = init::load_jwt_public_signing_key( + config.jwt_signing_public_key().clone(), + config.jwt_secondary_signing_public_key().cloned(), + ) + .await?; let (posthog_sender, posthog_client) = init::initialize_posthog(config.posthog(), helping_tasks_token.clone())?; @@ -155,7 +159,7 @@ impl Server { instance_id: impl Into, incoming_stream: IncomingStream, services_context: ServicesContext, - jwt_public_signing_key: JwtPublicSigningKey, + jwt_public_signing_key_chain: JwtPublicSigningKeyChain, posthog_client: PosthogClient, auth_api_url: impl AsRef, asset_sprayer: Option, @@ -170,7 +174,7 @@ impl Server { ) -> ServerResult { let app = AxumApp::from_services( services_context.clone(), - jwt_public_signing_key, + jwt_public_signing_key_chain, posthog_client, auth_api_url, asset_sprayer, diff --git a/lib/si-jwt-public-key/BUCK b/lib/si-jwt-public-key/BUCK new file mode 100644 index 0000000000..2c8f75f0a1 --- /dev/null +++ b/lib/si-jwt-public-key/BUCK @@ -0,0 +1,23 @@ +load( + "@prelude-si//:macros.bzl", + "rust_library", + "rust_test", +) + +rust_library( + name = "si-jwt-public-key", + deps = [ + "//lib/si-events-rs:si-events", + "//lib/si-std:si-std", + "//lib/telemetry-rs:telemetry", + "//third-party/rust:base64", + "//third-party/rust:jwt-simple", + "//third-party/rust:remain", + "//third-party/rust:serde", + "//third-party/rust:thiserror", + "//third-party/rust:tokio", + ], + srcs = glob([ + "src/**/*.rs", + ]), +) diff --git a/lib/si-jwt-public-key/Cargo.toml b/lib/si-jwt-public-key/Cargo.toml new file mode 100644 index 0000000000..b7978ee520 --- /dev/null +++ b/lib/si-jwt-public-key/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "si-jwt-public-key" +version.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true +rust-version.workspace = true +publish.workspace = true + +[dependencies] +si-events = { path = "../../lib/si-events-rs" } +si-std = { path = "../../lib/si-std" } +telemetry = { path = "../../lib/telemetry-rs" } + +base64 = { workspace = true } +jwt-simple = { workspace = true } +remain = { workspace = true } +serde = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } diff --git a/lib/si-jwt-public-key/src/lib.rs b/lib/si-jwt-public-key/src/lib.rs new file mode 100644 index 0000000000..f8aa70316b --- /dev/null +++ b/lib/si-jwt-public-key/src/lib.rs @@ -0,0 +1,355 @@ +use core::str; +use si_events::{UserPk, WorkspacePk}; +use si_std::CanonicalFile; +use std::sync::Arc; + +use base64::{engine::general_purpose, Engine}; +use jwt_simple::{common::VerificationOptions, prelude::*}; +use serde::{Deserialize, Serialize}; +use telemetry::prelude::*; +use thiserror::Error; +use tokio::{fs, io::AsyncReadExt, task::JoinError}; + +#[remain::sorted] +#[derive(Error, Debug)] +pub enum JwtPublicSigningKeyError { + #[error("failed to decode base64 string: {0}")] + Base64Decode(#[from] base64::DecodeError), + #[error("invalid bearer token")] + BearerToken, + #[error("error creating jwt from config")] + FromConfig, + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("JWT error: {0}")] + Jwt(#[from] jwt_simple::Error), + #[error("{0}")] + TaskJoin(#[from] JoinError), + #[error("Unsupported JWT signing algorithm: {0}")] + UnsupportedAlgo(String), + #[error("failed to build string from utf8: {0}")] + Utf8(#[from] std::str::Utf8Error), + #[error("failure to verify token: {0}")] + Verify(String), + #[error("failure to verify against secondary token: first error: {0}, second error: {1}")] + VerifySecondaryFail(String, String), +} + +pub type JwtKeyResult = Result; + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct JwtConfig { + pub key_file: Option, + pub key_base64: Option, + pub algo: JwtAlgo, +} + +impl JwtConfig { + pub async fn to_pem(self) -> JwtKeyResult { + Ok(match (self.key_file.as_ref(), self.key_base64.as_deref()) { + (None, Some(key_base64)) => { + let buf = general_purpose::STANDARD.decode(key_base64)?; + str::from_utf8(&buf)?.to_string() + } + (Some(key_file), None) => { + let mut file = fs::File::open(key_file).await?; + let mut buf = String::new(); + file.read_to_string(&mut buf).await?; + + buf + } + _ => Err(JwtPublicSigningKeyError::FromConfig)?, + }) + } + + pub async fn into_verify(self) -> JwtKeyResult> { + let algo = self.algo; + let pem = self.to_pem().await?; + + Ok(match algo { + JwtAlgo::ES256 => { + Arc::new(ES256PublicKey::from_pem(&pem)?) as Arc + } + JwtAlgo::RS256 => { + Arc::new(RS256PublicKey::from_pem(&pem)?) as Arc + } + }) + } +} + +#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)] +pub struct SiJwtClaims { + pub user_pk: UserPk, + pub workspace_pk: WorkspacePk, +} + +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub enum JwtAlgo { + #[default] + ES256, + RS256, +} + +pub trait JwtPublicKeyVerify: std::fmt::Debug + Send + Sync { + fn algo(&self) -> JwtAlgo; + fn verify( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult>; +} + +impl JwtPublicKeyVerify for RS256PublicKey { + fn algo(&self) -> JwtAlgo { + JwtAlgo::RS256 + } + + fn verify( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult> { + self.verify_token(token, options) + .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) + } +} + +impl JwtPublicKeyVerify for ES256PublicKey { + fn algo(&self) -> JwtAlgo { + JwtAlgo::ES256 + } + + fn verify( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult> { + self.verify_token(token, options) + .map_err(|err| JwtPublicSigningKeyError::Verify(format!("{err}"))) + } +} + +#[derive(Clone, Debug)] +pub struct JwtPublicSigningKeyChain { + primary: Arc, + secondary: Option>, +} + +impl JwtPublicSigningKeyChain { + pub async fn from_config( + primary: JwtConfig, + secondary: Option, + ) -> JwtKeyResult { + Ok(Self { + primary: primary.into_verify().await?, + secondary: match secondary { + Some(jwt_cfg) => Some(jwt_cfg.into_verify().await?), + None => None, + }, + }) + } + + /// Attempt to verify that this token was signed by either the primary or + /// secondary key(s) + pub fn verify_token( + &self, + token: &str, + options: Option, + ) -> JwtKeyResult> { + match self.primary.verify(token, options.clone()) { + Ok(claims) => Ok(claims), + Err(err) => match self.secondary.as_ref() { + Some(secondary) => match secondary.verify(token, options) { + Ok(claims) => Ok(claims), + Err(second_err) => Err(JwtPublicSigningKeyError::VerifySecondaryFail( + err.to_string(), + second_err.to_string(), + )), + }, + None => Err(err), + }, + } + } +} + +#[instrument(level = "debug", skip_all)] +pub async fn validate_bearer_token( + public_key: JwtPublicSigningKeyChain, + bearer_token: impl AsRef, +) -> JwtKeyResult> { + let bearer_token = bearer_token.as_ref(); + let token = if let Some(token) = bearer_token.strip_prefix("Bearer ") { + token.to_string() + } else { + return Err(JwtPublicSigningKeyError::BearerToken); + }; + + let claims = + tokio::task::spawn_blocking(move || public_key.verify_token(&token, None)).await??; + + Ok(claims) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn validate_with_primary_rs256() { + println!("generating key..."); + let key_pair = RS256KeyPair::generate(2048).expect("generate key pair"); + println!("done"); + + let pub_key = key_pair.public_key(); + let pub_key_pem = pub_key.to_pem().expect("get pub key pem"); + let pub_key_base64 = general_purpose::STANDARD.encode(pub_key_pem); + + let si_claim = SiJwtClaims { + user_pk: UserPk::generate(), + workspace_pk: WorkspacePk::generate(), + }; + + let claims = JWTClaims { + issued_at: None, + expires_at: None, + invalid_before: None, + issuer: None, + subject: None, + audiences: None, + jwt_id: None, + nonce: None, + custom: si_claim.clone(), + }; + + let signed = key_pair.sign(claims).expect("sign the key"); + let bearer_token = format!("Bearer {signed}"); + + let primary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64), + algo: JwtAlgo::RS256, + }; + + let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, None) + .await + .expect("make key chain"); + + let claims = validate_bearer_token(key_chain, &bearer_token) + .await + .expect("shoudl validate"); + + assert_eq!(si_claim, claims.custom); + } + + #[tokio::test] + async fn validate_with_primary_es256() { + println!("generating key..."); + let key_pair = ES256KeyPair::generate(); + let key_pair_2 = ES256KeyPair::generate(); + println!("done"); + + let pub_key = key_pair.public_key(); + let pub_key_pem = pub_key.to_pem().expect("get pub key pem"); + let pub_key_base64 = general_purpose::STANDARD.encode(pub_key_pem); + + let si_claim = SiJwtClaims { + user_pk: UserPk::generate(), + workspace_pk: WorkspacePk::generate(), + }; + + let claims = JWTClaims { + issued_at: None, + expires_at: None, + invalid_before: None, + issuer: None, + subject: None, + audiences: None, + jwt_id: None, + nonce: None, + custom: si_claim.clone(), + }; + + let signed = key_pair.sign(claims.clone()).expect("sign the key"); + let bearer_token = format!("Bearer {signed}"); + + let primary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64), + algo: JwtAlgo::ES256, + }; + + let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, None) + .await + .expect("make key chain"); + + let claims = validate_bearer_token(key_chain.clone(), &bearer_token) + .await + .expect("shoudl validate"); + + assert_eq!(si_claim, claims.custom); + + // Just confirm it fails with the wrong key + let signed_bad = key_pair_2.sign(claims).expect("sign the key"); + let bearer_bad = format!("Bearer {signed_bad}"); + let result = validate_bearer_token(key_chain, &bearer_bad).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn validate_with_secondary_rs256() { + println!("generating keys..."); + let key_pair_es256 = ES256KeyPair::generate(); + let key_pair_rs256 = RS256KeyPair::generate(2048).expect("generate rs256 key"); + println!("done"); + + let pub_key_es256 = key_pair_es256.public_key(); + let pub_key_pem = pub_key_es256.to_pem().expect("get pub key pem"); + let pub_key_base64_es256 = general_purpose::STANDARD.encode(pub_key_pem); + + let pub_key_rs256 = key_pair_rs256.public_key(); + let pub_key_pem = pub_key_rs256.to_pem().expect("get pub key pem"); + let pub_key_base64_rs256 = general_purpose::STANDARD.encode(pub_key_pem); + + let si_claim = SiJwtClaims { + user_pk: UserPk::generate(), + workspace_pk: WorkspacePk::generate(), + }; + + let claims = JWTClaims { + issued_at: None, + expires_at: None, + invalid_before: None, + issuer: None, + subject: None, + audiences: None, + jwt_id: None, + nonce: None, + custom: si_claim.clone(), + }; + + let signed = key_pair_rs256.sign(claims.clone()).expect("sign the key"); + let bearer_token = format!("Bearer {signed}"); + + let primary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64_es256), + algo: JwtAlgo::ES256, + }; + + let secondary_cfg = JwtConfig { + key_file: None, + key_base64: Some(pub_key_base64_rs256), + algo: JwtAlgo::RS256, + }; + + let key_chain = JwtPublicSigningKeyChain::from_config(primary_cfg, Some(secondary_cfg)) + .await + .expect("make key chain"); + + let claims = validate_bearer_token(key_chain.clone(), &bearer_token) + .await + .expect("shoudl validate"); + + assert_eq!(si_claim, claims.custom); + } +}