From 9b618d24476967d364835d04010d9076a80aeb9c Mon Sep 17 00:00:00 2001 From: Prajjwal Kumar Date: Fri, 3 Nov 2023 18:37:31 +0530 Subject: [PATCH] feat(router): Add Smart Routing to route payments efficiently (#2665) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: shashank_attarde Co-authored-by: Aprabhat19 Co-authored-by: Amisha Prabhat <55580080+Aprabhat19@users.noreply.github.com> --- Cargo.lock | 236 +++ crates/api_models/Cargo.toml | 7 +- crates/api_models/src/admin.rs | 66 - crates/api_models/src/lib.rs | 1 + crates/api_models/src/routing.rs | 594 +++++++ crates/common_utils/Cargo.toml | 1 + crates/common_utils/src/lib.rs | 2 + crates/common_utils/src/static_cache.rs | 91 + crates/diesel_models/src/enums.rs | 22 + crates/diesel_models/src/lib.rs | 1 + crates/diesel_models/src/query.rs | 1 + .../src/query/routing_algorithm.rs | 200 +++ crates/diesel_models/src/routing_algorithm.rs | 37 + crates/diesel_models/src/schema.rs | 23 + crates/euclid/Cargo.toml | 38 + crates/euclid/benches/backends.rs | 93 ++ crates/euclid/src/backend.rs | 25 + crates/euclid/src/backend/inputs.rs | 39 + crates/euclid/src/backend/interpreter.rs | 180 ++ .../euclid/src/backend/interpreter/types.rs | 81 + crates/euclid/src/backend/vir_interpreter.rs | 583 +++++++ .../src/backend/vir_interpreter/types.rs | 126 ++ crates/euclid/src/dssa.rs | 7 + crates/euclid/src/dssa/analyzer.rs | 447 +++++ crates/euclid/src/dssa/graph.rs | 1478 +++++++++++++++++ crates/euclid/src/dssa/state_machine.rs | 714 ++++++++ crates/euclid/src/dssa/truth.rs | 29 + crates/euclid/src/dssa/types.rs | 158 ++ crates/euclid/src/dssa/utils.rs | 1 + crates/euclid/src/enums.rs | 191 +++ crates/euclid/src/frontend.rs | 3 + crates/euclid/src/frontend/ast.rs | 156 ++ crates/euclid/src/frontend/ast/lowering.rs | 377 +++++ crates/euclid/src/frontend/ast/parser.rs | 441 +++++ crates/euclid/src/frontend/dir.rs | 803 +++++++++ crates/euclid/src/frontend/dir/enums.rs | 321 ++++ crates/euclid/src/frontend/dir/lowering.rs | 295 ++++ .../euclid/src/frontend/dir/transformers.rs | 166 ++ crates/euclid/src/frontend/vir.rs | 37 + crates/euclid/src/lib.rs | 7 + crates/euclid/src/types.rs | 318 ++++ crates/euclid/src/types/transformers.rs | 1 + crates/euclid/src/utils.rs | 3 + crates/euclid/src/utils/dense_map.rs | 224 +++ crates/euclid_macros/Cargo.toml | 16 + crates/euclid_macros/src/inner.rs | 5 + crates/euclid_macros/src/inner/enum_nums.rs | 47 + crates/euclid_macros/src/inner/knowledge.rs | 680 ++++++++ crates/euclid_macros/src/lib.rs | 16 + crates/euclid_wasm/Cargo.toml | 37 + crates/euclid_wasm/src/lib.rs | 227 +++ crates/euclid_wasm/src/types.rs | 7 + crates/euclid_wasm/src/utils.rs | 17 + crates/kgraph_utils/Cargo.toml | 27 + crates/kgraph_utils/benches/evaluation.rs | 113 ++ crates/kgraph_utils/src/error.rs | 14 + crates/kgraph_utils/src/lib.rs | 3 + crates/kgraph_utils/src/mca.rs | 739 +++++++++ crates/kgraph_utils/src/transformers.rs | 724 ++++++++ crates/router/Cargo.toml | 13 +- .../compatibility/stripe/payment_intents.rs | 12 +- .../stripe/payment_intents/types.rs | 14 +- .../src/compatibility/stripe/setup_intents.rs | 4 + .../stripe/setup_intents/types.rs | 14 +- crates/router/src/consts.rs | 2 + crates/router/src/core.rs | 1 + crates/router/src/core/admin.rs | 79 +- crates/router/src/core/errors.rs | 46 + .../router/src/core/payment_methods/cards.rs | 136 +- crates/router/src/core/payments.rs | 717 ++++++-- crates/router/src/core/payments/routing.rs | 950 +++++++++++ .../src/core/payments/routing/transformers.rs | 121 ++ crates/router/src/core/routing.rs | 713 ++++++++ crates/router/src/core/routing/helpers.rs | 479 ++++++ .../router/src/core/routing/transformers.rs | 86 + crates/router/src/core/webhooks.rs | 2 + crates/router/src/db.rs | 2 + crates/router/src/db/routing_algorithm.rs | 199 +++ crates/router/src/lib.rs | 1 + crates/router/src/routes.rs | 4 + crates/router/src/routes/app.rs | 39 + crates/router/src/routes/lock_utils.rs | 12 + crates/router/src/routes/payments.rs | 9 + crates/router/src/routes/routing.rs | 298 ++++ crates/router/src/types/api.rs | 20 +- crates/router/src/types/api/admin.rs | 4 +- crates/router/src/types/api/routing.rs | 41 + crates/router/src/types/storage.rs | 59 +- .../src/types/storage/routing_algorithm.rs | 3 + crates/router/src/types/transformers.rs | 165 +- crates/router/src/workflows/payment_sync.rs | 1 + crates/router/tests/payments.rs | 2 + crates/router/tests/payments2.rs | 2 + crates/router_env/src/logger/types.rs | 20 + .../down.sql | 4 + .../up.sql | 19 + 96 files changed, 15366 insertions(+), 223 deletions(-) create mode 100644 crates/api_models/src/routing.rs create mode 100644 crates/common_utils/src/static_cache.rs create mode 100644 crates/diesel_models/src/query/routing_algorithm.rs create mode 100644 crates/diesel_models/src/routing_algorithm.rs create mode 100644 crates/euclid/Cargo.toml create mode 100644 crates/euclid/benches/backends.rs create mode 100644 crates/euclid/src/backend.rs create mode 100644 crates/euclid/src/backend/inputs.rs create mode 100644 crates/euclid/src/backend/interpreter.rs create mode 100644 crates/euclid/src/backend/interpreter/types.rs create mode 100644 crates/euclid/src/backend/vir_interpreter.rs create mode 100644 crates/euclid/src/backend/vir_interpreter/types.rs create mode 100644 crates/euclid/src/dssa.rs create mode 100644 crates/euclid/src/dssa/analyzer.rs create mode 100644 crates/euclid/src/dssa/graph.rs create mode 100644 crates/euclid/src/dssa/state_machine.rs create mode 100644 crates/euclid/src/dssa/truth.rs create mode 100644 crates/euclid/src/dssa/types.rs create mode 100644 crates/euclid/src/dssa/utils.rs create mode 100644 crates/euclid/src/enums.rs create mode 100644 crates/euclid/src/frontend.rs create mode 100644 crates/euclid/src/frontend/ast.rs create mode 100644 crates/euclid/src/frontend/ast/lowering.rs create mode 100644 crates/euclid/src/frontend/ast/parser.rs create mode 100644 crates/euclid/src/frontend/dir.rs create mode 100644 crates/euclid/src/frontend/dir/enums.rs create mode 100644 crates/euclid/src/frontend/dir/lowering.rs create mode 100644 crates/euclid/src/frontend/dir/transformers.rs create mode 100644 crates/euclid/src/frontend/vir.rs create mode 100644 crates/euclid/src/lib.rs create mode 100644 crates/euclid/src/types.rs create mode 100644 crates/euclid/src/types/transformers.rs create mode 100644 crates/euclid/src/utils.rs create mode 100644 crates/euclid/src/utils/dense_map.rs create mode 100644 crates/euclid_macros/Cargo.toml create mode 100644 crates/euclid_macros/src/inner.rs create mode 100644 crates/euclid_macros/src/inner/enum_nums.rs create mode 100644 crates/euclid_macros/src/inner/knowledge.rs create mode 100644 crates/euclid_macros/src/lib.rs create mode 100644 crates/euclid_wasm/Cargo.toml create mode 100644 crates/euclid_wasm/src/lib.rs create mode 100644 crates/euclid_wasm/src/types.rs create mode 100644 crates/euclid_wasm/src/utils.rs create mode 100644 crates/kgraph_utils/Cargo.toml create mode 100644 crates/kgraph_utils/benches/evaluation.rs create mode 100644 crates/kgraph_utils/src/error.rs create mode 100644 crates/kgraph_utils/src/lib.rs create mode 100644 crates/kgraph_utils/src/mca.rs create mode 100644 crates/kgraph_utils/src/transformers.rs create mode 100644 crates/router/src/core/payments/routing.rs create mode 100644 crates/router/src/core/payments/routing/transformers.rs create mode 100644 crates/router/src/core/routing.rs create mode 100644 crates/router/src/core/routing/helpers.rs create mode 100644 crates/router/src/core/routing/transformers.rs create mode 100644 crates/router/src/db/routing_algorithm.rs create mode 100644 crates/router/src/routes/routing.rs create mode 100644 crates/router/src/types/api/routing.rs create mode 100644 crates/router/src/types/storage/routing_algorithm.rs create mode 100644 migrations/2023-10-19-101558_create_routing_algorithm_table/down.sql create mode 100644 migrations/2023-10-19-101558_create_routing_algorithm_table/up.sql diff --git a/Cargo.lock b/Cargo.lock index 665703f3d505..886a8b50acc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -376,6 +376,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstyle" version = "1.0.0" @@ -397,6 +403,7 @@ dependencies = [ "common_enums", "common_utils", "error-stack", + "euclid", "masking", "mime", "reqwest", @@ -1343,6 +1350,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.0.83" @@ -1413,6 +1426,33 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "ciborium" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" + +[[package]] +name = "ciborium-ll" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.3.4" @@ -1497,6 +1537,7 @@ dependencies = [ "reqwest", "ring", "router_env", + "rustc-hash", "serde", "serde_json", "serde_urlencoded", @@ -1615,6 +1656,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.8" @@ -2022,6 +2099,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "erased-serde" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c138974f9d5e7fe373eb04df7cae98833802ae4b11c24ac7039a21d5af4b26c" +dependencies = [ + "serde", +] + [[package]] name = "errno" version = "0.3.4" @@ -2063,6 +2149,52 @@ dependencies = [ "serde", ] +[[package]] +name = "euclid" +version = "0.1.0" +dependencies = [ + "common_enums", + "criterion", + "erased-serde", + "euclid_macros", + "frunk", + "frunk_core", + "nom", + "once_cell", + "rustc-hash", + "serde", + "serde_json", + "strum 0.25.0", + "thiserror", +] + +[[package]] +name = "euclid_macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "rustc-hash", + "strum 0.24.1", + "syn 1.0.109", +] + +[[package]] +name = "euclid_wasm" +version = "0.1.0" +dependencies = [ + "api_models", + "euclid", + "getrandom 0.2.10", + "kgraph_utils", + "once_cell", + "ron-parser", + "serde", + "serde-wasm-bindgen", + "strum 0.25.0", + "wasm-bindgen", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -2415,8 +2547,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -2497,6 +2631,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + [[package]] name = "hashbrown" version = "0.12.3" @@ -2811,6 +2951,17 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" +[[package]] +name = "is-terminal" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" +dependencies = [ + "hermit-abi", + "rustix 0.38.17", + "windows-sys", +] + [[package]] name = "itertools" version = "0.10.5" @@ -2905,6 +3056,19 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "kgraph_utils" +version = "0.1.0" +dependencies = [ + "api_models", + "criterion", + "euclid", + "masking", + "serde", + "serde_json", + "thiserror", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -3365,6 +3529,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d11de466f4a3006fe8a5e7ec84e93b79c70cb992ae0aa0eb631ad2df8abfe2" +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + [[package]] name = "opaque-debug" version = "0.3.0" @@ -3729,6 +3899,34 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + [[package]] name = "png" version = "0.16.8" @@ -4216,6 +4414,19 @@ dependencies = [ "serde", ] +[[package]] +name = "ron-parser" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c7280c46017fafbe4275179689e446a9b0db3bd91ea61aaee22841ef618405a" +dependencies = [ + "nom", + "serde", + "serde-wasm-bindgen", + "serde_json", + "wasm-bindgen", +] + [[package]] name = "router" version = "0.2.0" @@ -4248,6 +4459,7 @@ dependencies = [ "dyn-clone", "encoding_rs", "error-stack", + "euclid", "external_services", "futures", "hex", @@ -4257,6 +4469,7 @@ dependencies = [ "infer 0.13.0", "josekit", "jsonwebtoken", + "kgraph_utils", "literally", "masking", "maud", @@ -4268,6 +4481,7 @@ dependencies = [ "openssl", "qrcode", "rand 0.8.5", + "rand_chacha 0.3.1", "redis_interface", "regex", "reqwest", @@ -4275,6 +4489,7 @@ dependencies = [ "router_derive", "router_env", "roxmltree", + "rustc-hash", "scheduler", "serde", "serde_json", @@ -4651,6 +4866,17 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-wasm-bindgen" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3b143e2833c57ab9ad3ea280d21fd34e285a42837aeb0ee301f4f41890fa00e" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + [[package]] name = "serde_derive" version = "1.0.188" @@ -5349,6 +5575,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.6.0" diff --git a/crates/api_models/Cargo.toml b/crates/api_models/Cargo.toml index ce61d30d36f5..d15fdeabf387 100644 --- a/crates/api_models/Cargo.toml +++ b/crates/api_models/Cargo.toml @@ -9,8 +9,12 @@ license.workspace = true [features] default = ["payouts"] +business_profile_routing = [] +connector_choice_bcompat = [] errors = ["dep:actix-web", "dep:reqwest"] -dummy_connector = ["common_enums/dummy_connector"] +backwards_compatibility = ["connector_choice_bcompat"] +connector_choice_mca_id = ["euclid/connector_choice_mca_id"] +dummy_connector = ["common_enums/dummy_connector", "euclid/dummy_connector"] detailed_errors = [] payouts = [] @@ -32,5 +36,6 @@ thiserror = "1.0.40" cards = { version = "0.1.0", path = "../cards" } common_enums = { path = "../common_enums" } common_utils = { version = "0.1.0", path = "../common_utils" } +euclid = { version = "0.1.0", path = "../euclid" } masking = { version = "0.1.0", path = "../masking" } router_derive = { version = "0.1.0", path = "../router_derive" } diff --git a/crates/api_models/src/admin.rs b/crates/api_models/src/admin.rs index b1a258e6b26c..037d223754a0 100644 --- a/crates/api_models/src/admin.rs +++ b/crates/api_models/src/admin.rs @@ -443,72 +443,6 @@ pub mod payout_routing_algorithm { } } -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(tag = "type", content = "data", rename_all = "snake_case")] -pub enum RoutingAlgorithm { - Single(RoutableConnectorChoice), -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(untagged)] -pub enum RoutableConnectorChoice { - ConnectorName(api_enums::RoutableConnectors), - ConnectorId { - merchant_connector_id: String, - connector: api_enums::RoutableConnectors, - }, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde( - tag = "type", - content = "data", - rename_all = "snake_case", - from = "StraightThroughAlgorithmSerde", - into = "StraightThroughAlgorithmSerde" -)] -pub enum StraightThroughAlgorithm { - Single(RoutableConnectorChoice), -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(tag = "type", content = "data", rename_all = "snake_case")] -pub enum StraightThroughAlgorithmInner { - Single(RoutableConnectorChoice), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum StraightThroughAlgorithmSerde { - Direct(StraightThroughAlgorithmInner), - Nested { - algorithm: StraightThroughAlgorithmInner, - }, -} - -impl From for StraightThroughAlgorithm { - fn from(value: StraightThroughAlgorithmSerde) -> Self { - let inner = match value { - StraightThroughAlgorithmSerde::Direct(algorithm) => algorithm, - StraightThroughAlgorithmSerde::Nested { algorithm } => algorithm, - }; - - match inner { - StraightThroughAlgorithmInner::Single(conn) => Self::Single(conn), - } - } -} - -impl From for StraightThroughAlgorithmSerde { - fn from(value: StraightThroughAlgorithm) -> Self { - let inner = match value { - StraightThroughAlgorithm::Single(conn) => StraightThroughAlgorithmInner::Single(conn), - }; - - Self::Nested { algorithm: inner } - } -} - #[derive(Clone, Debug, Deserialize, ToSchema, Serialize, PartialEq)] #[serde(deny_unknown_fields)] pub struct PrimaryBusinessDetails { diff --git a/crates/api_models/src/lib.rs b/crates/api_models/src/lib.rs index dab1b46adbad..ec272514e38a 100644 --- a/crates/api_models/src/lib.rs +++ b/crates/api_models/src/lib.rs @@ -17,5 +17,6 @@ pub mod payments; #[cfg(feature = "payouts")] pub mod payouts; pub mod refunds; +pub mod routing; pub mod verifications; pub mod webhooks; diff --git a/crates/api_models/src/routing.rs b/crates/api_models/src/routing.rs new file mode 100644 index 000000000000..95d4c5e10ece --- /dev/null +++ b/crates/api_models/src/routing.rs @@ -0,0 +1,594 @@ +use std::fmt::Debug; + +use common_utils::errors::ParsingError; +use error_stack::IntoReport; +use euclid::{ + dssa::types::EuclidAnalysable, + enums as euclid_enums, + frontend::{ + ast, + dir::{DirKeyKind, EuclidDirFilter}, + }, +}; +use serde::{Deserialize, Serialize}; + +use crate::enums::{self, RoutableConnectors}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +pub enum ConnectorSelection { + Priority(Vec), + VolumeSplit(Vec), +} + +impl ConnectorSelection { + pub fn get_connector_list(&self) -> Vec { + match self { + Self::Priority(list) => list.clone(), + Self::VolumeSplit(splits) => { + splits.iter().map(|split| split.connector.clone()).collect() + } + } + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct RoutingConfigRequest { + pub name: Option, + pub description: Option, + pub algorithm: Option, + pub profile_id: Option, +} + +#[cfg(feature = "business_profile_routing")] +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct RoutingRetrieveQuery { + pub limit: Option, + pub offset: Option, + + pub profile_id: Option, +} + +#[cfg(feature = "business_profile_routing")] +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct RoutingRetrieveLinkQuery { + pub profile_id: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct RoutingRetrieveResponse { + pub algorithm: Option, +} + +#[derive(Debug, serde::Serialize)] +#[serde(untagged)] +pub enum LinkedRoutingConfigRetrieveResponse { + MerchantAccountBased(RoutingRetrieveResponse), + ProfileBased(Vec), +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct MerchantRoutingAlgorithm { + pub id: String, + #[cfg(feature = "business_profile_routing")] + pub profile_id: String, + pub name: String, + pub description: String, + pub algorithm: RoutingAlgorithm, + pub created_at: i64, + pub modified_at: i64, +} + +impl EuclidDirFilter for ConnectorSelection { + const ALLOWED: &'static [DirKeyKind] = &[ + DirKeyKind::PaymentMethod, + DirKeyKind::CardBin, + DirKeyKind::CardType, + DirKeyKind::CardNetwork, + DirKeyKind::PayLaterType, + DirKeyKind::WalletType, + DirKeyKind::UpiType, + DirKeyKind::BankRedirectType, + DirKeyKind::BankDebitType, + DirKeyKind::CryptoType, + DirKeyKind::MetaData, + DirKeyKind::PaymentAmount, + DirKeyKind::PaymentCurrency, + DirKeyKind::AuthenticationType, + DirKeyKind::MandateAcceptanceType, + DirKeyKind::MandateType, + DirKeyKind::PaymentType, + DirKeyKind::SetupFutureUsage, + DirKeyKind::CaptureMethod, + DirKeyKind::BillingCountry, + DirKeyKind::BusinessCountry, + DirKeyKind::BusinessLabel, + DirKeyKind::MetaData, + DirKeyKind::RewardType, + DirKeyKind::VoucherType, + DirKeyKind::CardRedirectType, + DirKeyKind::BankTransferType, + ]; +} + +impl EuclidAnalysable for ConnectorSelection { + fn get_dir_value_for_analysis( + &self, + rule_name: String, + ) -> Vec<(euclid::frontend::dir::DirValue, euclid::types::Metadata)> { + self.get_connector_list() + .into_iter() + .map(|connector_choice| { + let connector_name = connector_choice.connector.to_string(); + #[cfg(not(feature = "connector_choice_mca_id"))] + let sub_label = connector_choice.sub_label.clone(); + #[cfg(feature = "connector_choice_mca_id")] + let mca_id = connector_choice.merchant_connector_id.clone(); + + ( + euclid::frontend::dir::DirValue::Connector(Box::new(connector_choice.into())), + std::collections::HashMap::from_iter([( + "CONNECTOR_SELECTION".to_string(), + #[cfg(feature = "connector_choice_mca_id")] + serde_json::json!({ + "rule_name": rule_name, + "connector_name": connector_name, + "mca_id": mca_id, + }), + #[cfg(not(feature = "connector_choice_mca_id"))] + serde_json ::json!({ + "rule_name": rule_name, + "connector_name": connector_name, + "sub_label": sub_label, + }), + )]), + ) + }) + .collect() + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ConnectorVolumeSplit { + pub connector: RoutableConnectorChoice, + pub split: u8, +} + +#[cfg(feature = "connector_choice_bcompat")] +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub enum RoutableChoiceKind { + OnlyConnector, + FullStruct, +} + +#[cfg(feature = "connector_choice_bcompat")] +#[derive(Debug, serde::Deserialize, serde::Serialize)] +#[serde(untagged)] +pub enum RoutableChoiceSerde { + OnlyConnector(Box), + FullStruct { + connector: RoutableConnectors, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: Option, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: Option, + }, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[cfg_attr( + feature = "connector_choice_bcompat", + serde(from = "RoutableChoiceSerde"), + serde(into = "RoutableChoiceSerde") +)] +#[cfg_attr(not(feature = "connector_choice_bcompat"), derive(PartialEq, Eq))] +pub struct RoutableConnectorChoice { + #[cfg(feature = "connector_choice_bcompat")] + pub choice_kind: RoutableChoiceKind, + pub connector: RoutableConnectors, + #[cfg(feature = "connector_choice_mca_id")] + pub merchant_connector_id: Option, + #[cfg(not(feature = "connector_choice_mca_id"))] + pub sub_label: Option, +} + +impl ToString for RoutableConnectorChoice { + fn to_string(&self) -> String { + #[cfg(feature = "connector_choice_mca_id")] + let base = self.connector.to_string(); + + #[cfg(not(feature = "connector_choice_mca_id"))] + let base = { + let mut sub_base = self.connector.to_string(); + if let Some(ref label) = self.sub_label { + sub_base.push('_'); + sub_base.push_str(label); + } + + sub_base + }; + + base + } +} + +#[cfg(feature = "connector_choice_bcompat")] +impl PartialEq for RoutableConnectorChoice { + fn eq(&self, other: &Self) -> bool { + #[cfg(not(feature = "connector_choice_mca_id"))] + { + self.connector.eq(&other.connector) && self.sub_label.eq(&other.sub_label) + } + + #[cfg(feature = "connector_choice_mca_id")] + { + self.connector.eq(&other.connector) + && self.merchant_connector_id.eq(&other.merchant_connector_id) + } + } +} + +#[cfg(feature = "connector_choice_bcompat")] +impl Eq for RoutableConnectorChoice {} + +#[cfg(feature = "connector_choice_bcompat")] +impl From for RoutableConnectorChoice { + fn from(value: RoutableChoiceSerde) -> Self { + match value { + RoutableChoiceSerde::OnlyConnector(connector) => Self { + choice_kind: RoutableChoiceKind::OnlyConnector, + connector: *connector, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: None, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: None, + }, + + RoutableChoiceSerde::FullStruct { + connector, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label, + } => Self { + choice_kind: RoutableChoiceKind::FullStruct, + connector, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label, + }, + } + } +} + +#[cfg(feature = "connector_choice_bcompat")] +impl From for RoutableChoiceSerde { + fn from(value: RoutableConnectorChoice) -> Self { + match value.choice_kind { + RoutableChoiceKind::OnlyConnector => Self::OnlyConnector(Box::new(value.connector)), + RoutableChoiceKind::FullStruct => Self::FullStruct { + connector: value.connector, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: value.merchant_connector_id, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: value.sub_label, + }, + } + } +} + +impl From for ast::ConnectorChoice { + fn from(value: RoutableConnectorChoice) -> Self { + Self { + connector: match value.connector { + #[cfg(feature = "dummy_connector")] + RoutableConnectors::DummyConnector1 => euclid_enums::Connector::DummyConnector1, + #[cfg(feature = "dummy_connector")] + RoutableConnectors::DummyConnector2 => euclid_enums::Connector::DummyConnector2, + #[cfg(feature = "dummy_connector")] + RoutableConnectors::DummyConnector3 => euclid_enums::Connector::DummyConnector3, + #[cfg(feature = "dummy_connector")] + RoutableConnectors::DummyConnector4 => euclid_enums::Connector::DummyConnector4, + #[cfg(feature = "dummy_connector")] + RoutableConnectors::DummyConnector5 => euclid_enums::Connector::DummyConnector5, + #[cfg(feature = "dummy_connector")] + RoutableConnectors::DummyConnector6 => euclid_enums::Connector::DummyConnector6, + #[cfg(feature = "dummy_connector")] + RoutableConnectors::DummyConnector7 => euclid_enums::Connector::DummyConnector7, + RoutableConnectors::Aci => euclid_enums::Connector::Aci, + RoutableConnectors::Adyen => euclid_enums::Connector::Adyen, + RoutableConnectors::Airwallex => euclid_enums::Connector::Airwallex, + RoutableConnectors::Authorizedotnet => euclid_enums::Connector::Authorizedotnet, + RoutableConnectors::Bitpay => euclid_enums::Connector::Bitpay, + RoutableConnectors::Bambora => euclid_enums::Connector::Bambora, + RoutableConnectors::Bluesnap => euclid_enums::Connector::Bluesnap, + RoutableConnectors::Boku => euclid_enums::Connector::Boku, + RoutableConnectors::Braintree => euclid_enums::Connector::Braintree, + RoutableConnectors::Cashtocode => euclid_enums::Connector::Cashtocode, + RoutableConnectors::Checkout => euclid_enums::Connector::Checkout, + RoutableConnectors::Coinbase => euclid_enums::Connector::Coinbase, + RoutableConnectors::Cryptopay => euclid_enums::Connector::Cryptopay, + RoutableConnectors::Cybersource => euclid_enums::Connector::Cybersource, + RoutableConnectors::Dlocal => euclid_enums::Connector::Dlocal, + RoutableConnectors::Fiserv => euclid_enums::Connector::Fiserv, + RoutableConnectors::Forte => euclid_enums::Connector::Forte, + RoutableConnectors::Globalpay => euclid_enums::Connector::Globalpay, + RoutableConnectors::Globepay => euclid_enums::Connector::Globepay, + RoutableConnectors::Gocardless => euclid_enums::Connector::Gocardless, + RoutableConnectors::Helcim => euclid_enums::Connector::Helcim, + RoutableConnectors::Iatapay => euclid_enums::Connector::Iatapay, + RoutableConnectors::Klarna => euclid_enums::Connector::Klarna, + RoutableConnectors::Mollie => euclid_enums::Connector::Mollie, + RoutableConnectors::Multisafepay => euclid_enums::Connector::Multisafepay, + RoutableConnectors::Nexinets => euclid_enums::Connector::Nexinets, + RoutableConnectors::Nmi => euclid_enums::Connector::Nmi, + RoutableConnectors::Noon => euclid_enums::Connector::Noon, + RoutableConnectors::Nuvei => euclid_enums::Connector::Nuvei, + RoutableConnectors::Opennode => euclid_enums::Connector::Opennode, + RoutableConnectors::Payme => euclid_enums::Connector::Payme, + RoutableConnectors::Paypal => euclid_enums::Connector::Paypal, + RoutableConnectors::Payu => euclid_enums::Connector::Payu, + RoutableConnectors::Powertranz => euclid_enums::Connector::Powertranz, + RoutableConnectors::Rapyd => euclid_enums::Connector::Rapyd, + RoutableConnectors::Shift4 => euclid_enums::Connector::Shift4, + RoutableConnectors::Square => euclid_enums::Connector::Square, + RoutableConnectors::Stax => euclid_enums::Connector::Stax, + RoutableConnectors::Stripe => euclid_enums::Connector::Stripe, + RoutableConnectors::Trustpay => euclid_enums::Connector::Trustpay, + RoutableConnectors::Tsys => euclid_enums::Connector::Tsys, + RoutableConnectors::Volt => euclid_enums::Connector::Volt, + RoutableConnectors::Wise => euclid_enums::Connector::Wise, + RoutableConnectors::Worldline => euclid_enums::Connector::Worldline, + RoutableConnectors::Worldpay => euclid_enums::Connector::Worldpay, + RoutableConnectors::Zen => euclid_enums::Connector::Zen, + }, + + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: value.sub_label, + } + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct DetailedConnectorChoice { + pub connector: RoutableConnectors, + pub business_label: Option, + pub business_country: Option, + pub business_sub_label: Option, +} + +impl DetailedConnectorChoice { + pub fn get_connector_label(&self) -> Option { + self.business_country + .as_ref() + .zip(self.business_label.as_ref()) + .map(|(business_country, business_label)| { + let mut base_label = format!( + "{}_{:?}_{}", + self.connector, business_country, business_label + ); + + if let Some(ref sub_label) = self.business_sub_label { + base_label.push('_'); + base_label.push_str(sub_label); + } + + base_label + }) + } +} + +#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize, strum::Display)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum RoutingAlgorithmKind { + Single, + Priority, + VolumeSplit, + Advanced, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde( + tag = "type", + content = "data", + rename_all = "snake_case", + try_from = "RoutingAlgorithmSerde" +)] +pub enum RoutingAlgorithm { + Single(Box), + Priority(Vec), + VolumeSplit(Vec), + Advanced(euclid::frontend::ast::Program), +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +pub enum RoutingAlgorithmSerde { + Single(Box), + Priority(Vec), + VolumeSplit(Vec), + Advanced(euclid::frontend::ast::Program), +} + +impl TryFrom for RoutingAlgorithm { + type Error = error_stack::Report; + + fn try_from(value: RoutingAlgorithmSerde) -> Result { + match &value { + RoutingAlgorithmSerde::Priority(i) if i.is_empty() => { + Err(ParsingError::StructParseFailure( + "Connectors list can't be empty for Priority Algorithm", + )) + .into_report()? + } + RoutingAlgorithmSerde::VolumeSplit(i) if i.is_empty() => { + Err(ParsingError::StructParseFailure( + "Connectors list can't be empty for Volume split Algorithm", + )) + .into_report()? + } + _ => {} + }; + Ok(match value { + RoutingAlgorithmSerde::Single(i) => Self::Single(i), + RoutingAlgorithmSerde::Priority(i) => Self::Priority(i), + RoutingAlgorithmSerde::VolumeSplit(i) => Self::VolumeSplit(i), + RoutingAlgorithmSerde::Advanced(i) => Self::Advanced(i), + }) + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde( + tag = "type", + content = "data", + rename_all = "snake_case", + try_from = "StraightThroughAlgorithmSerde", + into = "StraightThroughAlgorithmSerde" +)] +pub enum StraightThroughAlgorithm { + Single(Box), + Priority(Vec), + VolumeSplit(Vec), +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +pub enum StraightThroughAlgorithmInner { + Single(Box), + Priority(Vec), + VolumeSplit(Vec), +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub enum StraightThroughAlgorithmSerde { + Direct(StraightThroughAlgorithmInner), + Nested { + algorithm: StraightThroughAlgorithmInner, + }, +} + +impl TryFrom for StraightThroughAlgorithm { + type Error = error_stack::Report; + + fn try_from(value: StraightThroughAlgorithmSerde) -> Result { + let inner = match value { + StraightThroughAlgorithmSerde::Direct(algorithm) => algorithm, + StraightThroughAlgorithmSerde::Nested { algorithm } => algorithm, + }; + + match &inner { + StraightThroughAlgorithmInner::Priority(i) if i.is_empty() => { + Err(ParsingError::StructParseFailure( + "Connectors list can't be empty for Priority Algorithm", + )) + .into_report()? + } + StraightThroughAlgorithmInner::VolumeSplit(i) if i.is_empty() => { + Err(ParsingError::StructParseFailure( + "Connectors list can't be empty for Volume split Algorithm", + )) + .into_report()? + } + _ => {} + }; + + Ok(match inner { + StraightThroughAlgorithmInner::Single(single) => Self::Single(single), + StraightThroughAlgorithmInner::Priority(plist) => Self::Priority(plist), + StraightThroughAlgorithmInner::VolumeSplit(vsplit) => Self::VolumeSplit(vsplit), + }) + } +} + +impl From for StraightThroughAlgorithmSerde { + fn from(value: StraightThroughAlgorithm) -> Self { + let inner = match value { + StraightThroughAlgorithm::Single(conn) => StraightThroughAlgorithmInner::Single(conn), + StraightThroughAlgorithm::Priority(plist) => { + StraightThroughAlgorithmInner::Priority(plist) + } + StraightThroughAlgorithm::VolumeSplit(vsplit) => { + StraightThroughAlgorithmInner::VolumeSplit(vsplit) + } + }; + + Self::Nested { algorithm: inner } + } +} + +impl From for RoutingAlgorithm { + fn from(value: StraightThroughAlgorithm) -> Self { + match value { + StraightThroughAlgorithm::Single(conn) => Self::Single(conn), + StraightThroughAlgorithm::Priority(conns) => Self::Priority(conns), + StraightThroughAlgorithm::VolumeSplit(splits) => Self::VolumeSplit(splits), + } + } +} + +impl RoutingAlgorithm { + pub fn get_kind(&self) -> RoutingAlgorithmKind { + match self { + Self::Single(_) => RoutingAlgorithmKind::Single, + Self::Priority(_) => RoutingAlgorithmKind::Priority, + Self::VolumeSplit(_) => RoutingAlgorithmKind::VolumeSplit, + Self::Advanced(_) => RoutingAlgorithmKind::Advanced, + } + } +} + +#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] +pub struct RoutingAlgorithmRef { + pub algorithm_id: Option, + pub timestamp: i64, + pub config_algo_id: Option, + pub surcharge_config_algo_id: Option, +} + +impl RoutingAlgorithmRef { + pub fn update_algorithm_id(&mut self, new_id: String) { + self.algorithm_id = Some(new_id); + self.timestamp = common_utils::date_time::now_unix_timestamp(); + } + + pub fn update_conditional_config_id(&mut self, ids: String) { + self.config_algo_id = Some(ids); + self.timestamp = common_utils::date_time::now_unix_timestamp(); + } + + pub fn update_surcharge_config_id(&mut self, ids: String) { + self.surcharge_config_algo_id = Some(ids); + self.timestamp = common_utils::date_time::now_unix_timestamp(); + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + +pub struct RoutingDictionaryRecord { + pub id: String, + #[cfg(feature = "business_profile_routing")] + pub profile_id: String, + pub name: String, + pub kind: RoutingAlgorithmKind, + pub description: String, + pub created_at: i64, + pub modified_at: i64, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct RoutingDictionary { + pub merchant_id: String, + pub active_id: Option, + pub records: Vec, +} + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(untagged)] +pub enum RoutingKind { + Config(RoutingDictionary), + RoutingAlgorithm(Vec), +} diff --git a/crates/common_utils/Cargo.toml b/crates/common_utils/Cargo.toml index e319cf86ccd0..c1fd91a351c7 100644 --- a/crates/common_utils/Cargo.toml +++ b/crates/common_utils/Cargo.toml @@ -28,6 +28,7 @@ rand = "0.8.5" regex = "1.8.4" reqwest = { version = "0.11.18", features = ["json", "native-tls", "gzip", "multipart"] } ring = { version = "0.16.20", features = ["std"] } +rustc-hash = "1.1.0" serde = { version = "1.0.163", features = ["derive"] } serde_json = "1.0.96" serde_urlencoded = "0.7.1" diff --git a/crates/common_utils/src/lib.rs b/crates/common_utils/src/lib.rs index ca6bba480063..724c3bca0a27 100644 --- a/crates/common_utils/src/lib.rs +++ b/crates/common_utils/src/lib.rs @@ -13,6 +13,8 @@ pub mod pii; pub mod request; #[cfg(feature = "signals")] pub mod signals; +#[allow(missing_docs)] // Todo: add docs +pub mod static_cache; pub mod types; pub mod validation; diff --git a/crates/common_utils/src/static_cache.rs b/crates/common_utils/src/static_cache.rs new file mode 100644 index 000000000000..ca608fa9a3b5 --- /dev/null +++ b/crates/common_utils/src/static_cache.rs @@ -0,0 +1,91 @@ +use std::sync::{Arc, RwLock}; + +use once_cell::sync::Lazy; +use rustc_hash::FxHashMap; + +#[derive(Debug)] +pub struct CacheEntry { + data: Arc, + timestamp: i64, +} + +#[derive(Debug, Clone, thiserror::Error)] +pub enum CacheError { + #[error("Could not acquire the lock for cache entry")] + CouldNotAcquireLock, + #[error("Entry not found in cache")] + EntryNotFound, +} + +#[derive(Debug)] +pub struct StaticCache { + data: Lazy>>>, +} + +impl StaticCache +where + T: Send, +{ + pub const fn new() -> Self { + Self { + data: Lazy::new(|| RwLock::new(FxHashMap::default())), + } + } + + pub fn present(&self, key: &String) -> Result { + let the_map = self + .data + .read() + .map_err(|_| CacheError::CouldNotAcquireLock)?; + + Ok(the_map.get(key).is_some()) + } + + pub fn expired(&self, key: &String, timestamp: i64) -> Result { + let the_map = self + .data + .read() + .map_err(|_| CacheError::CouldNotAcquireLock)?; + + Ok(match the_map.get(key) { + None => false, + Some(entry) => timestamp > entry.timestamp, + }) + } + + pub fn retrieve(&self, key: &String) -> Result, CacheError> { + let the_map = self + .data + .read() + .map_err(|_| CacheError::CouldNotAcquireLock)?; + + let cache_entry = the_map.get(key).ok_or(CacheError::EntryNotFound)?; + + Ok(Arc::clone(&cache_entry.data)) + } + + pub fn save(&self, key: String, data: T, timestamp: i64) -> Result<(), CacheError> { + let mut the_map = self + .data + .write() + .map_err(|_| CacheError::CouldNotAcquireLock)?; + + let entry = CacheEntry { + data: Arc::new(data), + timestamp, + }; + + the_map.insert(key, entry); + Ok(()) + } + + pub fn clear(&self) -> Result<(), CacheError> { + let mut the_map = self + .data + .write() + .map_err(|_| CacheError::CouldNotAcquireLock)?; + + the_map.clear(); + Ok(()) + } +} diff --git a/crates/diesel_models/src/enums.rs b/crates/diesel_models/src/enums.rs index b73eeefbb10b..0e06a324f038 100644 --- a/crates/diesel_models/src/enums.rs +++ b/crates/diesel_models/src/enums.rs @@ -14,6 +14,7 @@ pub mod diesel_exports { DbPaymentType as PaymentType, DbPayoutStatus as PayoutStatus, DbPayoutType as PayoutType, DbProcessTrackerStatus as ProcessTrackerStatus, DbReconStatus as ReconStatus, DbRefundStatus as RefundStatus, DbRefundType as RefundType, + DbRoutingAlgorithmKind as RoutingAlgorithmKind, }; } pub use common_enums::*; @@ -21,6 +22,27 @@ use common_utils::pii; use diesel::serialize::{Output, ToSql}; use time::PrimitiveDateTime; +#[derive( + Clone, + Copy, + Debug, + Eq, + PartialEq, + serde::Deserialize, + serde::Serialize, + strum::Display, + strum::EnumString, +)] +#[router_derive::diesel_enum(storage_type = "pg_enum")] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum RoutingAlgorithmKind { + Single, + Priority, + VolumeSplit, + Advanced, +} + #[derive( Clone, Copy, diff --git a/crates/diesel_models/src/lib.rs b/crates/diesel_models/src/lib.rs index 528446678015..2d459499a1bd 100644 --- a/crates/diesel_models/src/lib.rs +++ b/crates/diesel_models/src/lib.rs @@ -34,6 +34,7 @@ pub mod process_tracker; pub mod query; pub mod refund; pub mod reverse_lookup; +pub mod routing_algorithm; #[allow(unused_qualifications)] pub mod schema; diff --git a/crates/diesel_models/src/query.rs b/crates/diesel_models/src/query.rs index 6b705e29873e..aeb09b969f13 100644 --- a/crates/diesel_models/src/query.rs +++ b/crates/diesel_models/src/query.rs @@ -26,3 +26,4 @@ pub mod payouts; pub mod process_tracker; pub mod refund; pub mod reverse_lookup; +pub mod routing_algorithm; diff --git a/crates/diesel_models/src/query/routing_algorithm.rs b/crates/diesel_models/src/query/routing_algorithm.rs new file mode 100644 index 000000000000..533ac7194c41 --- /dev/null +++ b/crates/diesel_models/src/query/routing_algorithm.rs @@ -0,0 +1,200 @@ +use async_bb8_diesel::AsyncRunQueryDsl; +use diesel::{associations::HasTable, BoolExpressionMethods, ExpressionMethods, QueryDsl}; +use error_stack::{IntoReport, ResultExt}; +use router_env::tracing::{self, instrument}; +use time::PrimitiveDateTime; + +use crate::{ + enums, + errors::DatabaseError, + query::generics, + routing_algorithm::{RoutingAlgorithm, RoutingAlgorithmMetadata, RoutingProfileMetadata}, + schema::routing_algorithm::dsl, + PgPooledConn, StorageResult, +}; + +impl RoutingAlgorithm { + #[instrument(skip(conn))] + pub async fn insert(self, conn: &PgPooledConn) -> StorageResult { + generics::generic_insert(conn, self).await + } + + #[instrument(skip(conn))] + pub async fn find_by_algorithm_id_merchant_id( + conn: &PgPooledConn, + algorithm_id: &str, + merchant_id: &str, + ) -> StorageResult { + generics::generic_find_one::<::Table, _, _>( + conn, + dsl::algorithm_id + .eq(algorithm_id.to_owned()) + .and(dsl::merchant_id.eq(merchant_id.to_owned())), + ) + .await + } + + #[instrument(skip(conn))] + pub async fn find_by_algorithm_id_profile_id( + conn: &PgPooledConn, + algorithm_id: &str, + profile_id: &str, + ) -> StorageResult { + generics::generic_find_one::<::Table, _, _>( + conn, + dsl::algorithm_id + .eq(algorithm_id.to_owned()) + .and(dsl::profile_id.eq(profile_id.to_owned())), + ) + .await + } + + #[instrument(skip(conn))] + pub async fn find_metadata_by_algorithm_id_profile_id( + conn: &PgPooledConn, + algorithm_id: &str, + profile_id: &str, + ) -> StorageResult { + Self::table() + .select(( + dsl::profile_id, + dsl::algorithm_id, + dsl::name, + dsl::description, + dsl::kind, + dsl::created_at, + dsl::modified_at, + )) + .filter( + dsl::algorithm_id + .eq(algorithm_id.to_owned()) + .and(dsl::profile_id.eq(profile_id.to_owned())), + ) + .limit(1) + .load_async::<( + String, + String, + String, + Option, + enums::RoutingAlgorithmKind, + PrimitiveDateTime, + PrimitiveDateTime, + )>(conn) + .await + .into_report() + .change_context(DatabaseError::Others)? + .into_iter() + .next() + .ok_or(DatabaseError::NotFound) + .into_report() + .map( + |(profile_id, algorithm_id, name, description, kind, created_at, modified_at)| { + RoutingProfileMetadata { + profile_id, + algorithm_id, + name, + description, + kind, + created_at, + modified_at, + } + }, + ) + } + + #[instrument(skip(conn))] + pub async fn list_metadata_by_profile_id( + conn: &PgPooledConn, + profile_id: &str, + limit: i64, + offset: i64, + ) -> StorageResult> { + Ok(Self::table() + .select(( + dsl::algorithm_id, + dsl::name, + dsl::description, + dsl::kind, + dsl::created_at, + dsl::modified_at, + )) + .filter(dsl::profile_id.eq(profile_id.to_owned())) + .limit(limit) + .offset(offset) + .load_async::<( + String, + String, + Option, + enums::RoutingAlgorithmKind, + PrimitiveDateTime, + PrimitiveDateTime, + )>(conn) + .await + .into_report() + .change_context(DatabaseError::Others)? + .into_iter() + .map( + |(algorithm_id, name, description, kind, created_at, modified_at)| { + RoutingAlgorithmMetadata { + algorithm_id, + name, + description, + kind, + created_at, + modified_at, + } + }, + ) + .collect()) + } + + #[instrument(skip(conn))] + pub async fn list_metadata_by_merchant_id( + conn: &PgPooledConn, + merchant_id: &str, + limit: i64, + offset: i64, + ) -> StorageResult> { + Ok(Self::table() + .select(( + dsl::profile_id, + dsl::algorithm_id, + dsl::name, + dsl::description, + dsl::kind, + dsl::created_at, + dsl::modified_at, + )) + .filter(dsl::merchant_id.eq(merchant_id.to_owned())) + .limit(limit) + .offset(offset) + .order(dsl::modified_at.desc()) + .load_async::<( + String, + String, + String, + Option, + enums::RoutingAlgorithmKind, + PrimitiveDateTime, + PrimitiveDateTime, + )>(conn) + .await + .into_report() + .change_context(DatabaseError::Others)? + .into_iter() + .map( + |(profile_id, algorithm_id, name, description, kind, created_at, modified_at)| { + RoutingProfileMetadata { + profile_id, + algorithm_id, + name, + description, + kind, + created_at, + modified_at, + } + }, + ) + .collect()) + } +} diff --git a/crates/diesel_models/src/routing_algorithm.rs b/crates/diesel_models/src/routing_algorithm.rs new file mode 100644 index 000000000000..09f9baf7edb9 --- /dev/null +++ b/crates/diesel_models/src/routing_algorithm.rs @@ -0,0 +1,37 @@ +use diesel::{Identifiable, Insertable, Queryable}; +use serde::{Deserialize, Serialize}; + +use crate::{enums, schema::routing_algorithm}; + +#[derive(Clone, Debug, Identifiable, Insertable, Queryable, Serialize, Deserialize)] +#[diesel(table_name = routing_algorithm, primary_key(algorithm_id))] +pub struct RoutingAlgorithm { + pub algorithm_id: String, + pub profile_id: String, + pub merchant_id: String, + pub name: String, + pub description: Option, + pub kind: enums::RoutingAlgorithmKind, + pub algorithm_data: serde_json::Value, + pub created_at: time::PrimitiveDateTime, + pub modified_at: time::PrimitiveDateTime, +} + +pub struct RoutingAlgorithmMetadata { + pub algorithm_id: String, + pub name: String, + pub description: Option, + pub kind: enums::RoutingAlgorithmKind, + pub created_at: time::PrimitiveDateTime, + pub modified_at: time::PrimitiveDateTime, +} + +pub struct RoutingProfileMetadata { + pub profile_id: String, + pub algorithm_id: String, + pub name: String, + pub description: Option, + pub kind: enums::RoutingAlgorithmKind, + pub created_at: time::PrimitiveDateTime, + pub modified_at: time::PrimitiveDateTime, +} diff --git a/crates/diesel_models/src/schema.rs b/crates/diesel_models/src/schema.rs index e214fa364ddd..2923c719c8f7 100644 --- a/crates/diesel_models/src/schema.rs +++ b/crates/diesel_models/src/schema.rs @@ -874,6 +874,28 @@ diesel::table! { } } +diesel::table! { + use diesel::sql_types::*; + use crate::enums::diesel_exports::*; + + routing_algorithm (algorithm_id) { + #[max_length = 64] + algorithm_id -> Varchar, + #[max_length = 64] + profile_id -> Varchar, + #[max_length = 64] + merchant_id -> Varchar, + #[max_length = 64] + name -> Varchar, + #[max_length = 256] + description -> Nullable, + kind -> RoutingAlgorithmKind, + algorithm_data -> Jsonb, + created_at -> Timestamp, + modified_at -> Timestamp, + } +} + diesel::allow_tables_to_appear_in_same_query!( address, api_keys, @@ -902,4 +924,5 @@ diesel::allow_tables_to_appear_in_same_query!( process_tracker, refund, reverse_lookup, + routing_algorithm, ); diff --git a/crates/euclid/Cargo.toml b/crates/euclid/Cargo.toml new file mode 100644 index 000000000000..f0e24b1ff63c --- /dev/null +++ b/crates/euclid/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "euclid" +description = "DSL for static routing" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true + +[dependencies] +frunk = "0.4.1" +frunk_core = "0.4.1" +nom = { version = "7.1.3", features = ["alloc"], optional = true } +once_cell = "1.18.0" +rustc-hash = "1.1.0" +serde = { version = "1.0.163", features = ["derive", "rc"] } +serde_json = "1.0.96" +erased-serde = "0.3.28" +strum = { version = "0.25", features = ["derive"] } +thiserror = "1.0.43" + +# First party dependencies +common_enums = { version = "0.1.0", path = "../common_enums" } +euclid_macros = { version = "0.1.0", path = "../euclid_macros" } + +[features] +ast_parser = ["dep:nom"] +valued_jit = [] +connector_choice_bcompat = [] +connector_choice_mca_id = [] +dummy_connector = [] +backwards_compatibility = ["connector_choice_bcompat"] + +[dev-dependencies] +criterion = "0.5" + +[[bench]] +name = "backends" +harness = false +required-features = ["ast_parser", "valued_jit"] diff --git a/crates/euclid/benches/backends.rs b/crates/euclid/benches/backends.rs new file mode 100644 index 000000000000..9d29c41d34c6 --- /dev/null +++ b/crates/euclid/benches/backends.rs @@ -0,0 +1,93 @@ +#![allow(unused, clippy::expect_used)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use euclid::{ + backend::{inputs, EuclidBackend, InterpreterBackend, VirInterpreterBackend}, + enums, + frontend::ast::{self, parser}, + types::DummyOutput, +}; + +fn get_program_data() -> (ast::Program, inputs::BackendInput) { + let code1 = r#" + default: ["stripe", "adyen", "checkout"] + + stripe_first: ["stripe", "aci"] + { + payment_method = card & amount = 40 { + payment_method = (card, bank_redirect) + amount = (40, 50) + } + } + + adyen_first: ["adyen", "checkout"] + { + payment_method = bank_redirect & amount > 60 { + payment_method = (card, bank_redirect) + amount = (40, 50) + } + } + + auth_first: ["authorizedotnet", "adyen"] + { + payment_method = wallet + } + "#; + + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + card_bin: None, + currency: enums::Currency::USD, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Sofort), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + let (_, program) = parser::program(code1).expect("Parser"); + + (program, inp) +} + +fn interpreter_vs_jit_vs_vir_interpreter(c: &mut Criterion) { + let (program, binputs) = get_program_data(); + + let interp_b = InterpreterBackend::with_program(program.clone()).expect("Interpreter backend"); + + let vir_interp_b = + VirInterpreterBackend::with_program(program).expect("Vir Interpreter Backend"); + + c.bench_function("Raw Interpreter Backend", |b| { + b.iter(|| { + interp_b + .execute(binputs.clone()) + .expect("Interpreter EXECUTION"); + }); + }); + + c.bench_function("Valued Interpreter Backend", |b| { + b.iter(|| { + vir_interp_b + .execute(binputs.clone()) + .expect("Vir Interpreter execution"); + }) + }); +} + +criterion_group!(benches, interpreter_vs_jit_vs_vir_interpreter); +criterion_main!(benches); diff --git a/crates/euclid/src/backend.rs b/crates/euclid/src/backend.rs new file mode 100644 index 000000000000..caf0a87b69cb --- /dev/null +++ b/crates/euclid/src/backend.rs @@ -0,0 +1,25 @@ +pub mod inputs; +pub mod interpreter; +#[cfg(feature = "valued_jit")] +pub mod vir_interpreter; + +pub use inputs::BackendInput; +pub use interpreter::InterpreterBackend; +#[cfg(feature = "valued_jit")] +pub use vir_interpreter::VirInterpreterBackend; + +use crate::frontend::ast; + +#[derive(Debug, Clone, serde::Serialize)] +pub struct BackendOutput { + pub rule_name: Option, + pub connector_selection: O, +} + +pub trait EuclidBackend: Sized { + type Error: serde::Serialize; + + fn with_program(program: ast::Program) -> Result; + + fn execute(&self, input: BackendInput) -> Result, Self::Error>; +} diff --git a/crates/euclid/src/backend/inputs.rs b/crates/euclid/src/backend/inputs.rs new file mode 100644 index 000000000000..18298d4c358d --- /dev/null +++ b/crates/euclid/src/backend/inputs.rs @@ -0,0 +1,39 @@ +use rustc_hash::FxHashMap; +use serde::{Deserialize, Serialize}; + +use crate::enums; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MandateData { + pub mandate_acceptance_type: Option, + pub mandate_type: Option, + pub payment_type: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PaymentMethodInput { + pub payment_method: Option, + pub payment_method_type: Option, + pub card_network: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PaymentInput { + pub amount: i64, + pub currency: enums::Currency, + pub authentication_type: Option, + pub card_bin: Option, + pub capture_method: Option, + pub business_country: Option, + pub billing_country: Option, + pub business_label: Option, + pub setup_future_usage: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BackendInput { + pub metadata: Option>, + pub payment: PaymentInput, + pub payment_method: PaymentMethodInput, + pub mandate: MandateData, +} diff --git a/crates/euclid/src/backend/interpreter.rs b/crates/euclid/src/backend/interpreter.rs new file mode 100644 index 000000000000..bf0a561bf3f3 --- /dev/null +++ b/crates/euclid/src/backend/interpreter.rs @@ -0,0 +1,180 @@ +pub mod types; + +use crate::{ + backend::{self, inputs, EuclidBackend}, + frontend::ast, +}; + +pub struct InterpreterBackend { + program: ast::Program, +} + +impl InterpreterBackend +where + O: Clone, +{ + fn eval_number_comparison_array( + num: i64, + array: &[ast::NumberComparison], + ) -> Result { + for comparison in array { + let other = comparison.number; + let res = match comparison.comparison_type { + ast::ComparisonType::GreaterThan => num > other, + ast::ComparisonType::LessThan => num < other, + ast::ComparisonType::LessThanEqual => num <= other, + ast::ComparisonType::GreaterThanEqual => num >= other, + ast::ComparisonType::Equal => num == other, + ast::ComparisonType::NotEqual => num != other, + }; + + if res { + return Ok(true); + } + } + + Ok(false) + } + + fn eval_comparison( + comparison: &ast::Comparison, + ctx: &types::Context, + ) -> Result { + use ast::{ComparisonType::*, ValueType::*}; + + let value = ctx + .get(&comparison.lhs) + .ok_or_else(|| types::InterpreterError { + error_type: types::InterpreterErrorType::InvalidKey(comparison.lhs.clone()), + metadata: comparison.metadata.clone(), + })?; + + if let Some(val) = value { + match (val, &comparison.comparison, &comparison.value) { + (EnumVariant(e1), Equal, EnumVariant(e2)) => Ok(e1 == e2), + (EnumVariant(e1), NotEqual, EnumVariant(e2)) => Ok(e1 != e2), + (EnumVariant(e), Equal, EnumVariantArray(evec)) => Ok(evec.iter().any(|v| e == v)), + (EnumVariant(e), NotEqual, EnumVariantArray(evec)) => { + Ok(evec.iter().all(|v| e != v)) + } + (Number(n1), Equal, Number(n2)) => Ok(n1 == n2), + (Number(n1), NotEqual, Number(n2)) => Ok(n1 != n2), + (Number(n1), LessThanEqual, Number(n2)) => Ok(n1 <= n2), + (Number(n1), GreaterThanEqual, Number(n2)) => Ok(n1 >= n2), + (Number(n1), LessThan, Number(n2)) => Ok(n1 < n2), + (Number(n1), GreaterThan, Number(n2)) => Ok(n1 > n2), + (Number(n), Equal, NumberArray(nvec)) => Ok(nvec.iter().any(|v| v == n)), + (Number(n), NotEqual, NumberArray(nvec)) => Ok(nvec.iter().all(|v| v != n)), + (Number(n), Equal, NumberComparisonArray(ncvec)) => { + Self::eval_number_comparison_array(*n, ncvec) + } + _ => Err(types::InterpreterError { + error_type: types::InterpreterErrorType::InvalidComparison, + metadata: comparison.metadata.clone(), + }), + } + } else { + Ok(false) + } + } + + fn eval_if_condition( + condition: &ast::IfCondition, + ctx: &types::Context, + ) -> Result { + for comparison in condition { + let res = Self::eval_comparison(comparison, ctx)?; + + if !res { + return Ok(false); + } + } + + Ok(true) + } + + fn eval_if_statement( + stmt: &ast::IfStatement, + ctx: &types::Context, + ) -> Result { + let cond_res = Self::eval_if_condition(&stmt.condition, ctx)?; + + if !cond_res { + return Ok(false); + } + + if let Some(ref nested) = stmt.nested { + for nested_if in nested { + let res = Self::eval_if_statement(nested_if, ctx)?; + + if res { + return Ok(true); + } + } + + return Ok(false); + } + + Ok(true) + } + + fn eval_rule_statements( + statements: &[ast::IfStatement], + ctx: &types::Context, + ) -> Result { + for stmt in statements { + let res = Self::eval_if_statement(stmt, ctx)?; + + if res { + return Ok(true); + } + } + + Ok(false) + } + + #[inline] + fn eval_rule( + rule: &ast::Rule, + ctx: &types::Context, + ) -> Result { + Self::eval_rule_statements(&rule.statements, ctx) + } + + fn eval_program( + program: &ast::Program, + ctx: &types::Context, + ) -> Result, types::InterpreterError> { + for rule in &program.rules { + let res = Self::eval_rule(rule, ctx)?; + + if res { + return Ok(backend::BackendOutput { + connector_selection: rule.connector_selection.clone(), + rule_name: Some(rule.name.clone()), + }); + } + } + + Ok(backend::BackendOutput { + connector_selection: program.default_selection.clone(), + rule_name: None, + }) + } +} + +impl EuclidBackend for InterpreterBackend +where + O: Clone, +{ + type Error = types::InterpreterError; + + fn with_program(program: ast::Program) -> Result { + Ok(Self { program }) + } + + fn execute(&self, input: inputs::BackendInput) -> Result, Self::Error> { + let ctx: types::Context = input.into(); + Self::eval_program(&self.program, &ctx) + } +} diff --git a/crates/euclid/src/backend/interpreter/types.rs b/crates/euclid/src/backend/interpreter/types.rs new file mode 100644 index 000000000000..a6384dbdf3ce --- /dev/null +++ b/crates/euclid/src/backend/interpreter/types.rs @@ -0,0 +1,81 @@ +use std::{collections::HashMap, fmt, ops::Deref, string::ToString}; + +use serde::Serialize; + +use crate::{backend::inputs, frontend::ast::ValueType, types::EuclidKey}; + +#[derive(Debug, Clone, Serialize, thiserror::Error)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +pub enum InterpreterErrorType { + #[error("Invalid key received '{0}'")] + InvalidKey(String), + #[error("Invalid Comparison")] + InvalidComparison, +} + +#[derive(Debug, Clone, Serialize, thiserror::Error)] +pub struct InterpreterError { + pub error_type: InterpreterErrorType, + pub metadata: HashMap, +} + +impl fmt::Display for InterpreterError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + InterpreterErrorType::fmt(&self.error_type, f) + } +} + +pub struct Context(HashMap>); + +impl Deref for Context { + type Target = HashMap>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for Context { + fn from(input: inputs::BackendInput) -> Self { + let ctx = HashMap::>::from_iter([ + ( + EuclidKey::PaymentMethod.to_string(), + input + .payment_method + .payment_method + .map(|pm| ValueType::EnumVariant(pm.to_string())), + ), + ( + EuclidKey::PaymentMethodType.to_string(), + input + .payment_method + .payment_method_type + .map(|pt| ValueType::EnumVariant(pt.to_string())), + ), + ( + EuclidKey::AuthenticationType.to_string(), + input + .payment + .authentication_type + .map(|at| ValueType::EnumVariant(at.to_string())), + ), + ( + EuclidKey::CaptureMethod.to_string(), + input + .payment + .capture_method + .map(|cm| ValueType::EnumVariant(cm.to_string())), + ), + ( + EuclidKey::PaymentAmount.to_string(), + Some(ValueType::Number(input.payment.amount)), + ), + ( + EuclidKey::PaymentCurrency.to_string(), + Some(ValueType::EnumVariant(input.payment.currency.to_string())), + ), + ]); + + Self(ctx) + } +} diff --git a/crates/euclid/src/backend/vir_interpreter.rs b/crates/euclid/src/backend/vir_interpreter.rs new file mode 100644 index 000000000000..b7be62cf6740 --- /dev/null +++ b/crates/euclid/src/backend/vir_interpreter.rs @@ -0,0 +1,583 @@ +pub mod types; + +use crate::{ + backend::{self, inputs, EuclidBackend}, + frontend::{ + ast, + dir::{self, EuclidDirFilter}, + vir, + }, +}; + +pub struct VirInterpreterBackend { + program: vir::ValuedProgram, +} + +impl VirInterpreterBackend +where + O: Clone, +{ + #[inline] + fn eval_comparison(comp: &vir::ValuedComparison, ctx: &types::Context) -> bool { + match &comp.logic { + vir::ValuedComparisonLogic::PositiveDisjunction => { + comp.values.iter().any(|v| ctx.check_presence(v)) + } + vir::ValuedComparisonLogic::NegativeConjunction => { + comp.values.iter().all(|v| !ctx.check_presence(v)) + } + } + } + + #[inline] + fn eval_condition(cond: &vir::ValuedIfCondition, ctx: &types::Context) -> bool { + cond.iter().all(|comp| Self::eval_comparison(comp, ctx)) + } + + fn eval_statement(stmt: &vir::ValuedIfStatement, ctx: &types::Context) -> bool { + Self::eval_condition(&stmt.condition, ctx) + .then(|| { + stmt.nested.as_ref().map_or(true, |nested_stmts| { + nested_stmts.iter().any(|s| Self::eval_statement(s, ctx)) + }) + }) + .unwrap_or(false) + } + + fn eval_rule(rule: &vir::ValuedRule, ctx: &types::Context) -> bool { + rule.statements + .iter() + .any(|stmt| Self::eval_statement(stmt, ctx)) + } + + fn eval_program( + program: &vir::ValuedProgram, + ctx: &types::Context, + ) -> backend::BackendOutput { + program + .rules + .iter() + .find(|rule| Self::eval_rule(rule, ctx)) + .map_or_else( + || backend::BackendOutput { + connector_selection: program.default_selection.clone(), + rule_name: None, + }, + |rule| backend::BackendOutput { + connector_selection: rule.connector_selection.clone(), + rule_name: Some(rule.name.clone()), + }, + ) + } +} + +impl EuclidBackend for VirInterpreterBackend +where + O: Clone + EuclidDirFilter, +{ + type Error = types::VirInterpreterError; + + fn with_program(program: ast::Program) -> Result { + let dir_program = ast::lowering::lower_program(program) + .map_err(types::VirInterpreterError::LoweringError)?; + + let vir_program = dir::lowering::lower_program(dir_program) + .map_err(types::VirInterpreterError::LoweringError)?; + + Ok(Self { + program: vir_program, + }) + } + + fn execute( + &self, + input: inputs::BackendInput, + ) -> Result, Self::Error> { + let ctx = types::Context::from_input(input); + Ok(Self::eval_program(&self.program, &ctx)) + } +} +#[cfg(all(test, feature = "ast_parser"))] +mod test { + #![allow(clippy::expect_used)] + use rustc_hash::FxHashMap; + + use super::*; + use crate::{enums, types::DummyOutput}; + + #[test] + fn test_execution() { + let program_str = r#" + default: [ "stripe", "adyen"] + + rule_1: ["stripe"] + { + pay_later = klarna + } + + rule_2: ["adyen"] + { + pay_later = affirm + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + card_bin: None, + currency: enums::Currency::USD, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_2"); + } + #[test] + fn test_payment_type() { + let program_str = r#" + default: ["stripe", "adyen"] + rule_1: ["stripe"] + { + payment_type = setup_mandate + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + currency: enums::Currency::USD, + card_bin: Some("123456".to_string()), + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: Some(enums::PaymentType::SetupMandate), + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + + #[test] + fn test_mandate_type() { + let program_str = r#" + default: ["stripe", "adyen"] + rule_1: ["stripe"] + { + mandate_type = single_use + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + currency: enums::Currency::USD, + card_bin: Some("123456".to_string()), + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: Some(enums::MandateType::SingleUse), + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + + #[test] + fn test_mandate_acceptance_type() { + let program_str = r#" + default: ["stripe","adyen"] + rule_1: ["stripe"] + { + mandate_acceptance_type = online + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + currency: enums::Currency::USD, + card_bin: Some("123456".to_string()), + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: Some(enums::MandateAcceptanceType::Online), + mandate_type: None, + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + #[test] + fn test_card_bin() { + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + card_bin="123456" + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + currency: enums::Currency::USD, + card_bin: Some("123456".to_string()), + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + #[test] + fn test_payment_amount() { + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + amount = 32 + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + currency: enums::Currency::USD, + card_bin: None, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + #[test] + fn test_payment_method() { + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + payment_method = pay_later + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + currency: enums::Currency::USD, + card_bin: None, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + #[test] + fn test_future_usage() { + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + setup_future_usage = off_session + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 32, + currency: enums::Currency::USD, + card_bin: None, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: Some(enums::SetupFutureUsage::OffSession), + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + + #[test] + fn test_metadata_execution() { + let program_str = r#" + default: ["stripe"," adyen"] + + rule_1: ["stripe"] + { + "metadata_key" = "arbitrary meta" + } + "#; + let mut meta_map = FxHashMap::default(); + meta_map.insert("metadata_key".to_string(), "arbitrary meta".to_string()); + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp = inputs::BackendInput { + metadata: Some(meta_map), + payment: inputs::PaymentInput { + amount: 32, + card_bin: None, + currency: enums::Currency::USD, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result = backend.execute(inp).expect("Execution"); + assert_eq!(result.rule_name.expect("Rule Name").as_str(), "rule_1"); + } + + #[test] + fn test_less_than_operator() { + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + amount>=123 + } + "#; + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp_greater = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 150, + card_bin: None, + currency: enums::Currency::USD, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + let mut inp_equal = inp_greater.clone(); + inp_equal.payment.amount = 123; + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result_greater = backend.execute(inp_greater).expect("Execution"); + let result_equal = backend.execute(inp_equal).expect("Execution"); + assert_eq!( + result_equal.rule_name.expect("Rule Name").as_str(), + "rule_1" + ); + assert_eq!( + result_greater.rule_name.expect("Rule Name").as_str(), + "rule_1" + ); + } + + #[test] + fn test_greater_than_operator() { + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + amount<=123 + } + "#; + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let inp_lower = inputs::BackendInput { + metadata: None, + payment: inputs::PaymentInput { + amount: 120, + card_bin: None, + currency: enums::Currency::USD, + authentication_type: Some(enums::AuthenticationType::NoThreeDs), + capture_method: Some(enums::CaptureMethod::Automatic), + business_country: Some(enums::Country::UnitedStatesOfAmerica), + billing_country: Some(enums::Country::France), + business_label: None, + setup_future_usage: None, + }, + payment_method: inputs::PaymentMethodInput { + payment_method: Some(enums::PaymentMethod::PayLater), + payment_method_type: Some(enums::PaymentMethodType::Affirm), + card_network: None, + }, + mandate: inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + let mut inp_equal = inp_lower.clone(); + inp_equal.payment.amount = 123; + let backend = VirInterpreterBackend::::with_program(program).expect("Program"); + let result_equal = backend.execute(inp_equal).expect("Execution"); + let result_lower = backend.execute(inp_lower).expect("Execution"); + assert_eq!( + result_equal.rule_name.expect("Rule Name").as_str(), + "rule_1" + ); + assert_eq!( + result_lower.rule_name.expect("Rule Name").as_str(), + "rule_1" + ); + } +} diff --git a/crates/euclid/src/backend/vir_interpreter/types.rs b/crates/euclid/src/backend/vir_interpreter/types.rs new file mode 100644 index 000000000000..a144cdaafd08 --- /dev/null +++ b/crates/euclid/src/backend/vir_interpreter/types.rs @@ -0,0 +1,126 @@ +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{ + backend::inputs::BackendInput, + dssa, + types::{self, EuclidKey, EuclidValue, MetadataValue, NumValueRefinement, StrValue}, +}; + +#[derive(Debug, Clone, serde::Serialize, thiserror::Error)] +pub enum VirInterpreterError { + #[error("Error when lowering the program: {0:?}")] + LoweringError(dssa::types::AnalysisError), +} + +pub struct Context { + atomic_values: FxHashSet, + numeric_values: FxHashMap, +} + +impl Context { + pub fn check_presence(&self, value: &EuclidValue) -> bool { + let key = value.get_key(); + + match key.key_type() { + types::DataType::MetadataValue => self.atomic_values.contains(value), + types::DataType::StrValue => self.atomic_values.contains(value), + types::DataType::EnumVariant => self.atomic_values.contains(value), + types::DataType::Number => { + let ctx_num_value = self + .numeric_values + .get(&key) + .and_then(|value| value.get_num_value()); + + value.get_num_value().zip(ctx_num_value).map_or( + false, + |(program_value, ctx_value)| { + let program_num = program_value.number; + let ctx_num = ctx_value.number; + + match &program_value.refinement { + None => program_num == ctx_num, + Some(NumValueRefinement::NotEqual) => ctx_num != program_num, + Some(NumValueRefinement::GreaterThan) => ctx_num > program_num, + Some(NumValueRefinement::GreaterThanEqual) => ctx_num >= program_num, + Some(NumValueRefinement::LessThanEqual) => ctx_num <= program_num, + Some(NumValueRefinement::LessThan) => ctx_num < program_num, + } + }, + ) + } + } + } + + pub fn from_input(input: BackendInput) -> Self { + let payment = input.payment; + let payment_method = input.payment_method; + let meta_data = input.metadata; + let payment_mandate = input.mandate; + + let mut enum_values: FxHashSet = + FxHashSet::from_iter([EuclidValue::PaymentCurrency(payment.currency)]); + + if let Some(pm) = payment_method.payment_method { + enum_values.insert(EuclidValue::PaymentMethod(pm)); + } + + if let Some(pmt) = payment_method.payment_method_type { + enum_values.insert(EuclidValue::PaymentMethodType(pmt)); + } + + if let Some(met) = meta_data { + for (key, value) in met.into_iter() { + enum_values.insert(EuclidValue::Metadata(MetadataValue { key, value })); + } + } + + if let Some(at) = payment.authentication_type { + enum_values.insert(EuclidValue::AuthenticationType(at)); + } + + if let Some(capture_method) = payment.capture_method { + enum_values.insert(EuclidValue::CaptureMethod(capture_method)); + } + + if let Some(country) = payment.business_country { + enum_values.insert(EuclidValue::BusinessCountry(country)); + } + + if let Some(country) = payment.billing_country { + enum_values.insert(EuclidValue::BillingCountry(country)); + } + if let Some(card_bin) = payment.card_bin { + enum_values.insert(EuclidValue::CardBin(StrValue { value: card_bin })); + } + if let Some(business_label) = payment.business_label { + enum_values.insert(EuclidValue::BusinessLabel(StrValue { + value: business_label, + })); + } + if let Some(setup_future_usage) = payment.setup_future_usage { + enum_values.insert(EuclidValue::SetupFutureUsage(setup_future_usage)); + } + if let Some(payment_type) = payment_mandate.payment_type { + enum_values.insert(EuclidValue::PaymentType(payment_type)); + } + if let Some(mandate_type) = payment_mandate.mandate_type { + enum_values.insert(EuclidValue::MandateType(mandate_type)); + } + if let Some(mandate_acceptance_type) = payment_mandate.mandate_acceptance_type { + enum_values.insert(EuclidValue::MandateAcceptanceType(mandate_acceptance_type)); + } + + let numeric_values: FxHashMap = FxHashMap::from_iter([( + EuclidKey::PaymentAmount, + EuclidValue::PaymentAmount(types::NumValue { + number: payment.amount, + refinement: None, + }), + )]); + + Self { + atomic_values: enum_values, + numeric_values, + } + } +} diff --git a/crates/euclid/src/dssa.rs b/crates/euclid/src/dssa.rs new file mode 100644 index 000000000000..2f6f35dfb27c --- /dev/null +++ b/crates/euclid/src/dssa.rs @@ -0,0 +1,7 @@ +//! Domain Specific Static Analyzer +pub mod analyzer; +pub mod graph; +pub mod state_machine; +pub mod truth; +pub mod types; +pub mod utils; diff --git a/crates/euclid/src/dssa/analyzer.rs b/crates/euclid/src/dssa/analyzer.rs new file mode 100644 index 000000000000..149ed1fd79cd --- /dev/null +++ b/crates/euclid/src/dssa/analyzer.rs @@ -0,0 +1,447 @@ +//! Static Analysis for the Euclid Rule DSL +//! +//! Exposes certain functions that can be used to perform static analysis over programs +//! in the Euclid Rule DSL. These include standard control flow analyses like testing +//! conflicting assertions, to Domain Specific Analyses making use of the +//! [`Knowledge Graph Framework`](crate::dssa::graph). +use rustc_hash::{FxHashMap, FxHashSet}; + +use super::{graph::Memoization, types::EuclidAnalysable}; +use crate::{ + dssa::{graph, state_machine, truth, types}, + frontend::{ + ast, + dir::{self, EuclidDirFilter}, + vir, + }, + types::{DataType, Metadata}, +}; + +/// Analyses conflicting assertions on the same key in a conjunctive context. +/// +/// For example, +/// ```notrust +/// payment_method = card && ... && payment_method = bank_debit +/// ```notrust +/// This is a condition that will never evaluate to `true` given a single +/// payment method and needs to be caught in analysis. +pub fn analyze_conflicting_assertions( + keywise_assertions: &FxHashMap>, + assertion_metadata: &FxHashMap<&dir::DirValue, &Metadata>, +) -> Result<(), types::AnalysisError> { + for (key, value_set) in keywise_assertions { + if value_set.len() > 1 { + let err_type = types::AnalysisErrorType::ConflictingAssertions { + key: key.clone(), + values: value_set + .iter() + .map(|val| types::ValueData { + value: (*val).clone(), + metadata: assertion_metadata + .get(val) + .map(|meta| (*meta).clone()) + .unwrap_or_default(), + }) + .collect(), + }; + + Err(types::AnalysisError { + error_type: err_type, + metadata: Default::default(), + })?; + } + } + Ok(()) +} + +/// Analyses exhaustive negations on the same key in a conjunctive context. +/// +/// For example, +/// ```notrust +/// authentication_type /= three_ds && ... && authentication_type /= no_three_ds +/// ```notrust +/// This is a condition that will never evaluate to `true` given any authentication_type +/// since all the possible values authentication_type can take have been negated. +pub fn analyze_exhaustive_negations( + keywise_negations: &FxHashMap>, + keywise_negation_metadata: &FxHashMap>, +) -> Result<(), types::AnalysisError> { + for (key, negation_set) in keywise_negations { + let mut value_set = if let Some(set) = key.kind.get_value_set() { + set + } else { + continue; + }; + + value_set.retain(|val| !negation_set.contains(val)); + + if value_set.is_empty() { + let error_type = types::AnalysisErrorType::ExhaustiveNegation { + key: key.clone(), + metadata: keywise_negation_metadata + .get(key) + .cloned() + .unwrap_or_default() + .iter() + .cloned() + .cloned() + .collect(), + }; + + Err(types::AnalysisError { + error_type, + metadata: Default::default(), + })?; + } + } + Ok(()) +} + +fn analyze_negated_assertions( + keywise_assertions: &FxHashMap>, + assertion_metadata: &FxHashMap<&dir::DirValue, &Metadata>, + keywise_negations: &FxHashMap>, + negation_metadata: &FxHashMap<&dir::DirValue, &Metadata>, +) -> Result<(), types::AnalysisError> { + for (key, negation_set) in keywise_negations { + let assertion_set = if let Some(set) = keywise_assertions.get(key) { + set + } else { + continue; + }; + + let intersection = negation_set & assertion_set; + + intersection.iter().next().map_or(Ok(()), |val| { + let error_type = types::AnalysisErrorType::NegatedAssertion { + value: (*val).clone(), + assertion_metadata: assertion_metadata + .get(*val) + .cloned() + .cloned() + .unwrap_or_default(), + negation_metadata: negation_metadata + .get(*val) + .cloned() + .cloned() + .unwrap_or_default(), + }; + + Err(types::AnalysisError { + error_type, + metadata: Default::default(), + }) + })?; + } + Ok(()) +} + +fn perform_condition_analyses( + context: &types::ConjunctiveContext<'_>, +) -> Result<(), types::AnalysisError> { + let mut assertion_metadata: FxHashMap<&dir::DirValue, &Metadata> = FxHashMap::default(); + let mut keywise_assertions: FxHashMap> = + FxHashMap::default(); + let mut negation_metadata: FxHashMap<&dir::DirValue, &Metadata> = FxHashMap::default(); + let mut keywise_negation_metadata: FxHashMap> = + FxHashMap::default(); + let mut keywise_negations: FxHashMap> = + FxHashMap::default(); + + for ctx_val in context { + let key = if let Some(k) = ctx_val.value.get_key() { + k + } else { + continue; + }; + + if let dir::DirKeyKind::Connector = key.kind { + continue; + } + + if !matches!(key.kind.get_type(), DataType::EnumVariant) { + continue; + } + + match ctx_val.value { + types::CtxValueKind::Assertion(val) => { + keywise_assertions + .entry(key.clone()) + .or_default() + .insert(val); + + assertion_metadata.insert(val, ctx_val.metadata); + } + + types::CtxValueKind::Negation(vals) => { + let negation_set = keywise_negations.entry(key.clone()).or_default(); + + for val in vals { + negation_set.insert(val); + negation_metadata.insert(val, ctx_val.metadata); + } + + keywise_negation_metadata + .entry(key.clone()) + .or_default() + .push(ctx_val.metadata); + } + } + } + + analyze_conflicting_assertions(&keywise_assertions, &assertion_metadata)?; + analyze_exhaustive_negations(&keywise_negations, &keywise_negation_metadata)?; + analyze_negated_assertions( + &keywise_assertions, + &assertion_metadata, + &keywise_negations, + &negation_metadata, + )?; + + Ok(()) +} + +fn perform_context_analyses( + context: &types::ConjunctiveContext<'_>, + knowledge_graph: &graph::KnowledgeGraph<'_>, +) -> Result<(), types::AnalysisError> { + perform_condition_analyses(context)?; + let mut memo = Memoization::new(); + knowledge_graph + .perform_context_analysis(context, &mut memo) + .map_err(|err| types::AnalysisError { + error_type: types::AnalysisErrorType::GraphAnalysis(err, memo), + metadata: Default::default(), + })?; + Ok(()) +} + +pub fn analyze( + program: ast::Program, + knowledge_graph: Option<&graph::KnowledgeGraph<'_>>, +) -> Result, types::AnalysisError> { + let dir_program = ast::lowering::lower_program(program)?; + + let selection_data = state_machine::make_connector_selection_data(&dir_program); + let mut ctx_manager = state_machine::AnalysisContextManager::new(&dir_program, &selection_data); + while let Some(ctx) = ctx_manager.advance().map_err(|err| types::AnalysisError { + metadata: Default::default(), + error_type: types::AnalysisErrorType::StateMachine(err), + })? { + perform_context_analyses(ctx, knowledge_graph.unwrap_or(&truth::ANALYSIS_GRAPH))?; + } + + dir::lowering::lower_program(dir_program) +} + +#[cfg(all(test, feature = "ast_parser"))] +mod tests { + #![allow(clippy::panic, clippy::expect_used)] + + use std::{ops::Deref, sync::Weak}; + + use euclid_macros::knowledge; + + use super::*; + use crate::{dirval, types::DummyOutput}; + + #[test] + fn test_conflicting_assertion_detection() { + let program_str = r#" + default: ["stripe", "adyen"] + + stripe_first: ["stripe", "adyen"] + { + payment_method = wallet { + amount > 500 & capture_method = automatic + amount < 500 & payment_method = card + } + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let analysis_result = analyze(program, None); + + if let Err(types::AnalysisError { + error_type: types::AnalysisErrorType::ConflictingAssertions { key, values }, + .. + }) = analysis_result + { + assert!( + matches!(key.kind, dir::DirKeyKind::PaymentMethod), + "Key should be payment_method" + ); + let values: Vec = values.into_iter().map(|v| v.value).collect(); + assert_eq!(values.len(), 2, "There should be 2 conflicting conditions"); + assert!( + values.contains(&dirval!(PaymentMethod = Wallet)), + "Condition should include payment_method = wallet" + ); + assert!( + values.contains(&dirval!(PaymentMethod = Card)), + "Condition should include payment_method = card" + ); + } else { + panic!("Did not receive conflicting assertions error"); + } + } + + #[test] + fn test_exhaustive_negation_detection() { + let program_str = r#" + default: ["stripe"] + + rule_1: ["adyen"] + { + payment_method /= wallet { + capture_method = manual & payment_method /= card { + authentication_type = three_ds & payment_method /= pay_later { + amount > 1000 & payment_method /= bank_redirect { + payment_method /= crypto + & payment_method /= bank_debit + & payment_method /= bank_transfer + & payment_method /= upi + & payment_method /= reward + & payment_method /= voucher + & payment_method /= gift_card + + } + } + } + } + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let analysis_result = analyze(program, None); + + if let Err(types::AnalysisError { + error_type: types::AnalysisErrorType::ExhaustiveNegation { key, .. }, + .. + }) = analysis_result + { + assert!( + matches!(key.kind, dir::DirKeyKind::PaymentMethod), + "Expected key to be payment_method" + ); + } else { + panic!("Expected exhaustive negation error"); + } + } + + #[test] + fn test_negated_assertions_detection() { + let program_str = r#" + default: ["stripe"] + + rule_1: ["adyen"] + { + payment_method = wallet { + amount > 500 { + capture_method = automatic + } + + amount < 501 { + payment_method /= wallet + } + } + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let analysis_result = analyze(program, None); + + if let Err(types::AnalysisError { + error_type: types::AnalysisErrorType::NegatedAssertion { value, .. }, + .. + }) = analysis_result + { + assert_eq!( + value, + dirval!(PaymentMethod = Wallet), + "Expected to catch payment_method = wallet as conflict" + ); + } else { + panic!("Expected negated assertion error"); + } + } + + #[test] + fn test_negation_graph_analysis() { + let graph = knowledge! {crate + CaptureMethod(Automatic) ->> PaymentMethod(Card); + }; + + let program_str = r#" + default: ["stripe"] + + rule_1: ["adyen"] + { + amount > 500 { + payment_method = pay_later + } + + amount < 500 { + payment_method /= wallet & payment_method /= pay_later + } + } + "#; + + let (_, program) = ast::parser::program::(program_str).expect("Graph"); + let analysis_result = analyze(program, Some(&graph)); + + let error_type = match analysis_result { + Err(types::AnalysisError { error_type, .. }) => error_type, + _ => panic!("Error_type not found"), + }; + + let a_err = match error_type { + types::AnalysisErrorType::GraphAnalysis(trace, memo) => (trace, memo), + _ => panic!("Graph Analysis not found"), + }; + + let (trace, metadata) = match a_err.0 { + graph::AnalysisError::NegationTrace { trace, metadata } => (trace, metadata), + _ => panic!("Negation Trace not found"), + }; + + let predecessor = match Weak::upgrade(&trace) + .expect("Expected Arc not found") + .deref() + .clone() + { + graph::AnalysisTrace::Value { predecessors, .. } => { + let _value = graph::NodeValue::Value(dir::DirValue::PaymentMethod( + dir::enums::PaymentMethod::Card, + )); + let _relation = graph::Relation::Positive; + predecessors + } + _ => panic!("Expected Negation Trace for payment method = card"), + }; + + let pred = match predecessor { + Some(graph::ValueTracePredecessor::Mandatory(predecessor)) => predecessor, + _ => panic!("No predecessor found"), + }; + assert_eq!( + metadata.len(), + 2, + "Expected two metadats for wallet and pay_later" + ); + assert!(matches!( + *Weak::upgrade(&pred) + .expect("Expected Arc not found") + .deref(), + graph::AnalysisTrace::Value { + value: graph::NodeValue::Value(dir::DirValue::CaptureMethod( + dir::enums::CaptureMethod::Automatic + )), + relation: graph::Relation::Positive, + info: None, + metadata: None, + predecessors: None, + } + )); + } +} diff --git a/crates/euclid/src/dssa/graph.rs b/crates/euclid/src/dssa/graph.rs new file mode 100644 index 000000000000..bd23ae385226 --- /dev/null +++ b/crates/euclid/src/dssa/graph.rs @@ -0,0 +1,1478 @@ +use std::{ + fmt::Debug, + hash::Hash, + ops::{Deref, DerefMut}, + sync::{Arc, Weak}, +}; + +use erased_serde::{self, Serialize as ErasedSerialize}; +use rustc_hash::{FxHashMap, FxHashSet}; +use serde::Serialize; + +use crate::{ + dssa::types, + frontend::dir, + types::{DataType, Metadata}, + utils, +}; + +#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash, strum::Display)] +pub enum Strength { + Weak, + Normal, + Strong, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::Display, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum Relation { + Positive, + Negative, +} + +impl From for bool { + fn from(value: Relation) -> Self { + matches!(value, Relation::Positive) + } +} + +#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Hash)] +pub struct NodeId(usize); + +impl utils::EntityId for NodeId { + #[inline] + fn get_id(&self) -> usize { + self.0 + } + + #[inline] + fn with_id(id: usize) -> Self { + Self(id) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DomainInfo<'a> { + pub domain_identifier: DomainIdentifier<'a>, + pub domain_description: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DomainIdentifier<'a>(&'a str); + +impl<'a> DomainIdentifier<'a> { + pub fn new(domain_identifier: &'a str) -> Self { + Self(domain_identifier) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct DomainId(usize); + +impl utils::EntityId for DomainId { + #[inline] + fn get_id(&self) -> usize { + self.0 + } + + #[inline] + fn with_id(id: usize) -> Self { + Self(id) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct EdgeId(usize); + +impl utils::EntityId for EdgeId { + #[inline] + fn get_id(&self) -> usize { + self.0 + } + + #[inline] + fn with_id(id: usize) -> Self { + Self(id) + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct Memoization(FxHashMap<(NodeId, Relation, Strength), Result<(), Arc>>); + +impl Memoization { + pub fn new() -> Self { + Self(FxHashMap::default()) + } +} + +impl Default for Memoization { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Deref for Memoization { + type Target = FxHashMap<(NodeId, Relation, Strength), Result<(), Arc>>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for Memoization { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} +#[derive(Debug, Clone)] +pub struct Edge { + pub strength: Strength, + pub relation: Relation, + pub pred: NodeId, + pub succ: NodeId, +} + +#[derive(Debug)] +pub struct Node { + pub node_type: NodeType, + pub preds: Vec, + pub succs: Vec, + pub domain_ids: Vec, +} + +impl Node { + fn new(node_type: NodeType, domain_ids: Vec) -> Self { + Self { + node_type, + preds: Vec::new(), + succs: Vec::new(), + domain_ids, + } + } +} + +pub trait KgraphMetadata: ErasedSerialize + std::any::Any + Sync + Send + Debug {} +erased_serde::serialize_trait_object!(KgraphMetadata); + +impl KgraphMetadata for M where M: ErasedSerialize + std::any::Any + Sync + Send + Debug {} + +#[derive(Debug)] +pub struct KnowledgeGraph<'a> { + domain: utils::DenseMap>, + nodes: utils::DenseMap, + edges: utils::DenseMap, + value_map: FxHashMap, + node_info: utils::DenseMap>, + node_metadata: utils::DenseMap>>, +} + +pub struct KnowledgeGraphBuilder<'a> { + domain: utils::DenseMap>, + nodes: utils::DenseMap, + edges: utils::DenseMap, + domain_identifier_map: FxHashMap, DomainId>, + value_map: FxHashMap, + edges_map: FxHashMap<(NodeId, NodeId), EdgeId>, + node_info: utils::DenseMap>, + node_metadata: utils::DenseMap>>, +} + +impl<'a> Default for KnowledgeGraphBuilder<'a> { + #[inline] + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum NodeType { + AllAggregator, + AnyAggregator, + InAggregator(FxHashSet), + Value(NodeValue), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize)] +#[serde(tag = "type", content = "value", rename_all = "snake_case")] +pub enum NodeValue { + Key(dir::DirKey), + Value(dir::DirValue), +} + +impl From for NodeValue { + fn from(value: dir::DirValue) -> Self { + Self::Value(value) + } +} + +impl From for NodeValue { + fn from(key: dir::DirKey) -> Self { + Self::Key(key) + } +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", content = "predecessor", rename_all = "snake_case")] +pub enum ValueTracePredecessor { + Mandatory(Box>), + OneOf(Vec>), +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", content = "trace", rename_all = "snake_case")] +pub enum AnalysisTrace { + Value { + value: NodeValue, + relation: Relation, + predecessors: Option, + info: Option<&'static str>, + metadata: Option>, + }, + + AllAggregation { + unsatisfied: Vec>, + info: Option<&'static str>, + metadata: Option>, + }, + + AnyAggregation { + unsatisfied: Vec>, + info: Option<&'static str>, + metadata: Option>, + }, + + InAggregation { + expected: Vec, + found: Option, + relation: Relation, + info: Option<&'static str>, + metadata: Option>, + }, +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", content = "details", rename_all = "snake_case")] +pub enum AnalysisError { + Graph(GraphError), + AssertionTrace { + trace: Weak, + metadata: Metadata, + }, + NegationTrace { + trace: Weak, + metadata: Vec, + }, +} + +impl AnalysisError { + fn assertion_from_graph_error(metadata: &Metadata, graph_error: GraphError) -> Self { + match graph_error { + GraphError::AnalysisError(trace) => Self::AssertionTrace { + trace, + metadata: metadata.clone(), + }, + + other => Self::Graph(other), + } + } + + fn negation_from_graph_error(metadata: Vec<&Metadata>, graph_error: GraphError) -> Self { + match graph_error { + GraphError::AnalysisError(trace) => Self::NegationTrace { + trace, + metadata: metadata.iter().map(|m| (*m).clone()).collect(), + }, + + other => Self::Graph(other), + } + } +} + +#[derive(Debug, Clone, serde::Serialize, thiserror::Error)] +#[serde(tag = "type", content = "info", rename_all = "snake_case")] +pub enum GraphError { + #[error("An edge was not found in the graph")] + EdgeNotFound, + #[error("Attempted to create a conflicting edge between two nodes")] + ConflictingEdgeCreated, + #[error("Cycle detected in graph")] + CycleDetected, + #[error("Domain wasn't found in the Graph")] + DomainNotFound, + #[error("Malformed Graph: {reason}")] + MalformedGraph { reason: String }, + #[error("A node was not found in the graph")] + NodeNotFound, + #[error("A value node was not found: {0:#?}")] + ValueNodeNotFound(dir::DirValue), + #[error("No values provided for an 'in' aggregator node")] + NoInAggregatorValues, + #[error("Error during analysis: {0:#?}")] + AnalysisError(Weak), +} + +impl GraphError { + fn get_analysis_trace(self) -> Result, Self> { + match self { + Self::AnalysisError(trace) => Ok(trace), + _ => Err(self), + } + } +} + +impl PartialEq for NodeValue { + fn eq(&self, other: &dir::DirValue) -> bool { + match self { + Self::Key(dir_key) => *dir_key == other.get_key(), + Self::Value(dir_value) if dir_value.get_key() == other.get_key() => { + if let (Some(left), Some(right)) = + (dir_value.get_num_value(), other.get_num_value()) + { + left.fits(&right) + } else { + dir::DirValue::check_equality(dir_value, other) + } + } + Self::Value(_) => false, + } + } +} + +pub struct AnalysisContext { + keywise_values: FxHashMap>, +} + +impl AnalysisContext { + pub fn from_dir_values(vals: impl IntoIterator) -> Self { + let mut keywise_values: FxHashMap> = + FxHashMap::default(); + + for dir_val in vals { + let key = dir_val.get_key(); + let set = keywise_values.entry(key).or_default(); + set.insert(dir_val); + } + + Self { keywise_values } + } + + fn check_presence(&self, value: &NodeValue, weak: bool) -> bool { + match value { + NodeValue::Key(k) => self.keywise_values.contains_key(k) || weak, + NodeValue::Value(val) => { + let key = val.get_key(); + let value_set = if let Some(set) = self.keywise_values.get(&key) { + set + } else { + return weak; + }; + + match key.kind.get_type() { + DataType::EnumVariant | DataType::StrValue | DataType::MetadataValue => { + value_set.contains(val) + } + DataType::Number => val.get_num_value().map_or(false, |num_val| { + value_set.iter().any(|ctx_val| { + ctx_val + .get_num_value() + .map_or(false, |ctx_num_val| num_val.fits(&ctx_num_val)) + }) + }), + } + } + } + } + + pub fn insert(&mut self, value: dir::DirValue) { + self.keywise_values + .entry(value.get_key()) + .or_default() + .insert(value); + } + + pub fn remove(&mut self, value: dir::DirValue) { + let set = self.keywise_values.entry(value.get_key()).or_default(); + + set.remove(&value); + + if set.is_empty() { + self.keywise_values.remove(&value.get_key()); + } + } +} + +impl<'a> KnowledgeGraphBuilder<'a> { + pub fn new() -> Self { + Self { + domain: utils::DenseMap::new(), + nodes: utils::DenseMap::new(), + edges: utils::DenseMap::new(), + domain_identifier_map: FxHashMap::default(), + value_map: FxHashMap::default(), + edges_map: FxHashMap::default(), + node_info: utils::DenseMap::new(), + node_metadata: utils::DenseMap::new(), + } + } + + pub fn build(self) -> KnowledgeGraph<'a> { + KnowledgeGraph { + domain: self.domain, + nodes: self.nodes, + edges: self.edges, + value_map: self.value_map, + node_info: self.node_info, + node_metadata: self.node_metadata, + } + } + + pub fn make_domain( + &mut self, + domain_identifier: DomainIdentifier<'a>, + domain_description: String, + ) -> Result { + Ok(self + .domain_identifier_map + .clone() + .get(&domain_identifier) + .map_or_else( + || { + let domain_id = self.domain.push(DomainInfo { + domain_identifier: domain_identifier.clone(), + domain_description, + }); + self.domain_identifier_map + .insert(domain_identifier.clone(), domain_id); + domain_id + }, + |domain_id| *domain_id, + )) + } + + pub fn make_value_node( + &mut self, + value: NodeValue, + info: Option<&'static str>, + domain_identifiers: Vec>, + metadata: Option, + ) -> Result { + match self.value_map.get(&value).copied() { + Some(node_id) => Ok(node_id), + None => { + let mut domain_ids: Vec = Vec::new(); + domain_identifiers + .iter() + .try_for_each(|ident| { + self.domain_identifier_map + .get(ident) + .map(|id| domain_ids.push(*id)) + }) + .ok_or(GraphError::DomainNotFound)?; + + let node_id = self + .nodes + .push(Node::new(NodeType::Value(value.clone()), domain_ids)); + let _node_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + self.value_map.insert(value, node_id); + Ok(node_id) + } + } + } + + pub fn make_edge( + &mut self, + pred_id: NodeId, + succ_id: NodeId, + strength: Strength, + relation: Relation, + ) -> Result { + self.ensure_node_exists(pred_id)?; + self.ensure_node_exists(succ_id)?; + self.edges_map + .get(&(pred_id, succ_id)) + .copied() + .and_then(|edge_id| self.edges.get(edge_id).cloned().map(|edge| (edge_id, edge))) + .map_or_else( + || { + let edge_id = self.edges.push(Edge { + strength, + relation, + pred: pred_id, + succ: succ_id, + }); + self.edges_map.insert((pred_id, succ_id), edge_id); + + let pred = self + .nodes + .get_mut(pred_id) + .ok_or(GraphError::NodeNotFound)?; + pred.succs.push(edge_id); + + let succ = self + .nodes + .get_mut(succ_id) + .ok_or(GraphError::NodeNotFound)?; + succ.preds.push(edge_id); + + Ok(edge_id) + }, + |(edge_id, edge)| { + if edge.strength == strength && edge.relation == relation { + Ok(edge_id) + } else { + Err(GraphError::ConflictingEdgeCreated) + } + }, + ) + } + + pub fn make_all_aggregator( + &mut self, + nodes: &[(NodeId, Relation, Strength)], + info: Option<&'static str>, + metadata: Option, + domain: Vec>, + ) -> Result { + nodes + .iter() + .try_for_each(|(node_id, _, _)| self.ensure_node_exists(*node_id))?; + + let mut domain_ids: Vec = Vec::new(); + domain + .iter() + .try_for_each(|ident| { + self.domain_identifier_map + .get(ident) + .map(|id| domain_ids.push(*id)) + }) + .ok_or(GraphError::DomainNotFound)?; + + let aggregator_id = self + .nodes + .push(Node::new(NodeType::AllAggregator, domain_ids)); + let _aggregator_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + for (node_id, relation, strength) in nodes { + self.make_edge(*node_id, aggregator_id, *strength, *relation)?; + } + + Ok(aggregator_id) + } + + pub fn make_any_aggregator( + &mut self, + nodes: &[(NodeId, Relation)], + info: Option<&'static str>, + metadata: Option, + domain: Vec>, + ) -> Result { + nodes + .iter() + .try_for_each(|(node_id, _)| self.ensure_node_exists(*node_id))?; + + let mut domain_ids: Vec = Vec::new(); + domain + .iter() + .try_for_each(|ident| { + self.domain_identifier_map + .get(ident) + .map(|id| domain_ids.push(*id)) + }) + .ok_or(GraphError::DomainNotFound)?; + + let aggregator_id = self + .nodes + .push(Node::new(NodeType::AnyAggregator, domain_ids)); + let _aggregator_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + for (node_id, relation) in nodes { + self.make_edge(*node_id, aggregator_id, Strength::Strong, *relation)?; + } + + Ok(aggregator_id) + } + + pub fn make_in_aggregator( + &mut self, + values: Vec, + info: Option<&'static str>, + metadata: Option, + domain: Vec>, + ) -> Result { + let key = values + .first() + .ok_or(GraphError::NoInAggregatorValues)? + .get_key(); + + for val in &values { + if val.get_key() != key { + Err(GraphError::MalformedGraph { + reason: "Values for 'In' aggregator not of same key".to_string(), + })?; + } + } + + let mut domain_ids: Vec = Vec::new(); + domain + .iter() + .try_for_each(|ident| { + self.domain_identifier_map + .get(ident) + .map(|id| domain_ids.push(*id)) + }) + .ok_or(GraphError::DomainNotFound)?; + + let node_id = self.nodes.push(Node::new( + NodeType::InAggregator(FxHashSet::from_iter(values)), + domain_ids, + )); + let _aggregator_info_id = self.node_info.push(info); + + let _node_metadata_id = self + .node_metadata + .push(metadata.map(|meta| -> Arc { Arc::new(meta) })); + + Ok(node_id) + } + + fn ensure_node_exists(&self, id: NodeId) -> Result<(), GraphError> { + if self.nodes.contains_key(id) { + Ok(()) + } else { + Err(GraphError::NodeNotFound) + } + } +} + +impl<'a> KnowledgeGraph<'a> { + fn check_node( + &self, + ctx: &AnalysisContext, + node_id: NodeId, + relation: Relation, + strength: Strength, + memo: &mut Memoization, + ) -> Result<(), GraphError> { + let node = self.nodes.get(node_id).ok_or(GraphError::NodeNotFound)?; + if let Some(already_memo) = memo.get(&(node_id, relation, strength)) { + already_memo + .clone() + .map_err(|err| GraphError::AnalysisError(Arc::downgrade(&err))) + } else { + match &node.node_type { + NodeType::AllAggregator => { + let mut unsatisfied = Vec::>::new(); + + for edge_id in node.preds.iter().copied() { + let edge = self.edges.get(edge_id).ok_or(GraphError::EdgeNotFound)?; + + if let Err(e) = + self.check_node(ctx, edge.pred, edge.relation, edge.strength, memo) + { + unsatisfied.push(e.get_analysis_trace()?); + } + } + + if !unsatisfied.is_empty() { + let err = Arc::new(AnalysisTrace::AllAggregation { + unsatisfied, + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + }); + + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + } else { + memo.insert((node_id, relation, strength), Ok(())); + Ok(()) + } + } + + NodeType::AnyAggregator => { + let mut unsatisfied = Vec::>::new(); + let mut matched_one = false; + + for edge_id in node.preds.iter().copied() { + let edge = self.edges.get(edge_id).ok_or(GraphError::EdgeNotFound)?; + + if let Err(e) = + self.check_node(ctx, edge.pred, edge.relation, edge.strength, memo) + { + unsatisfied.push(e.get_analysis_trace()?); + } else { + matched_one = true; + } + } + + if matched_one || node.preds.is_empty() { + memo.insert((node_id, relation, strength), Ok(())); + Ok(()) + } else { + let err = Arc::new(AnalysisTrace::AnyAggregation { + unsatisfied: unsatisfied.clone(), + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + }); + + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + } + } + + NodeType::InAggregator(expected) => { + let the_key = expected + .iter() + .next() + .ok_or_else(|| GraphError::MalformedGraph { + reason: + "An OnlyIn aggregator node must have at least one expected value" + .to_string(), + })? + .get_key(); + + let ctx_vals = if let Some(vals) = ctx.keywise_values.get(&the_key) { + vals + } else { + return if let Strength::Weak = strength { + memo.insert((node_id, relation, strength), Ok(())); + Ok(()) + } else { + let err = Arc::new(AnalysisTrace::InAggregation { + expected: expected.iter().cloned().collect(), + found: None, + relation, + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + }); + + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + }; + }; + + let relation_bool: bool = relation.into(); + for ctx_value in ctx_vals { + if expected.contains(ctx_value) != relation_bool { + let err = Arc::new(AnalysisTrace::InAggregation { + expected: expected.iter().cloned().collect(), + found: Some(ctx_value.clone()), + relation, + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + }); + + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; + } + } + + memo.insert((node_id, relation, strength), Ok(())); + Ok(()) + } + + NodeType::Value(val) => { + let in_context = ctx.check_presence(val, matches!(strength, Strength::Weak)); + let relation_bool: bool = relation.into(); + + if in_context != relation_bool { + let err = Arc::new(AnalysisTrace::Value { + value: val.clone(), + relation, + predecessors: None, + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + }); + + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; + } + + if !relation_bool { + memo.insert((node_id, relation, strength), Ok(())); + return Ok(()); + } + + let mut errors = Vec::>::new(); + let mut matched_one = false; + + for edge_id in node.preds.iter().copied() { + let edge = self.edges.get(edge_id).ok_or(GraphError::EdgeNotFound)?; + let result = + self.check_node(ctx, edge.pred, edge.relation, edge.strength, memo); + + match (edge.strength, result) { + (Strength::Strong, Err(trace)) => { + let err = Arc::new(AnalysisTrace::Value { + value: val.clone(), + relation, + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + predecessors: Some(ValueTracePredecessor::Mandatory(Box::new( + trace.get_analysis_trace()?, + ))), + }); + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err)))?; + } + + (Strength::Strong, Ok(_)) => { + matched_one = true; + } + + (Strength::Normal | Strength::Weak, Err(trace)) => { + errors.push(trace.get_analysis_trace()?); + } + + (Strength::Normal | Strength::Weak, Ok(_)) => { + matched_one = true; + } + } + } + + if matched_one || node.preds.is_empty() { + memo.insert((node_id, relation, strength), Ok(())); + Ok(()) + } else { + let err = Arc::new(AnalysisTrace::Value { + value: val.clone(), + relation, + info: self.node_info.get(node_id).cloned().flatten(), + metadata: self.node_metadata.get(node_id).cloned().flatten(), + predecessors: Some(ValueTracePredecessor::OneOf(errors.clone())), + }); + + memo.insert((node_id, relation, strength), Err(Arc::clone(&err))); + Err(GraphError::AnalysisError(Arc::downgrade(&err))) + } + } + } + } + } + + fn key_analysis( + &self, + key: dir::DirKey, + ctx: &AnalysisContext, + memo: &mut Memoization, + ) -> Result<(), GraphError> { + self.value_map + .get(&NodeValue::Key(key)) + .map_or(Ok(()), |node_id| { + self.check_node(ctx, *node_id, Relation::Positive, Strength::Strong, memo) + }) + } + + fn value_analysis( + &self, + val: dir::DirValue, + ctx: &AnalysisContext, + memo: &mut Memoization, + ) -> Result<(), GraphError> { + self.value_map + .get(&NodeValue::Value(val)) + .map_or(Ok(()), |node_id| { + self.check_node(ctx, *node_id, Relation::Positive, Strength::Strong, memo) + }) + } + + pub fn check_value_validity( + &self, + val: dir::DirValue, + analysis_ctx: &AnalysisContext, + memo: &mut Memoization, + ) -> Result { + let maybe_node_id = self.value_map.get(&NodeValue::Value(val)); + + let node_id = if let Some(nid) = maybe_node_id { + nid + } else { + return Ok(false); + }; + + let result = self.check_node( + analysis_ctx, + *node_id, + Relation::Positive, + Strength::Weak, + memo, + ); + + match result { + Ok(_) => Ok(true), + Err(e) => { + e.get_analysis_trace()?; + Ok(false) + } + } + } + + pub fn key_value_analysis( + &self, + val: dir::DirValue, + ctx: &AnalysisContext, + memo: &mut Memoization, + ) -> Result<(), GraphError> { + self.key_analysis(val.get_key(), ctx, memo) + .and_then(|_| self.value_analysis(val, ctx, memo)) + } + + fn assertion_analysis( + &self, + positive_ctx: &[(&dir::DirValue, &Metadata)], + analysis_ctx: &AnalysisContext, + memo: &mut Memoization, + ) -> Result<(), AnalysisError> { + positive_ctx.iter().try_for_each(|(value, metadata)| { + self.key_value_analysis((*value).clone(), analysis_ctx, memo) + .map_err(|e| AnalysisError::assertion_from_graph_error(metadata, e)) + }) + } + + fn negation_analysis( + &self, + negative_ctx: &[(&[dir::DirValue], &Metadata)], + analysis_ctx: &mut AnalysisContext, + memo: &mut Memoization, + ) -> Result<(), AnalysisError> { + let mut keywise_metadata: FxHashMap> = FxHashMap::default(); + let mut keywise_negation: FxHashMap> = + FxHashMap::default(); + + for (values, metadata) in negative_ctx { + let mut metadata_added = false; + + for dir_value in *values { + if !metadata_added { + keywise_metadata + .entry(dir_value.get_key()) + .or_default() + .push(metadata); + + metadata_added = true; + } + + keywise_negation + .entry(dir_value.get_key()) + .or_default() + .insert(dir_value); + } + } + + for (key, negation_set) in keywise_negation { + let all_metadata = keywise_metadata.remove(&key).unwrap_or_default(); + let first_metadata = all_metadata.first().cloned().cloned().unwrap_or_default(); + + self.key_analysis(key.clone(), analysis_ctx, memo) + .map_err(|e| AnalysisError::assertion_from_graph_error(&first_metadata, e))?; + + let mut value_set = if let Some(set) = key.kind.get_value_set() { + set + } else { + continue; + }; + + value_set.retain(|v| !negation_set.contains(v)); + + for value in value_set { + analysis_ctx.insert(value.clone()); + self.value_analysis(value.clone(), analysis_ctx, memo) + .map_err(|e| { + AnalysisError::negation_from_graph_error(all_metadata.clone(), e) + })?; + analysis_ctx.remove(value); + } + } + + Ok(()) + } + + pub fn perform_context_analysis( + &self, + ctx: &types::ConjunctiveContext<'_>, + memo: &mut Memoization, + ) -> Result<(), AnalysisError> { + let mut analysis_ctx = AnalysisContext::from_dir_values( + ctx.iter() + .filter_map(|ctx_val| ctx_val.value.get_assertion().cloned()), + ); + + let positive_ctx = ctx + .iter() + .filter_map(|ctx_val| { + ctx_val + .value + .get_assertion() + .map(|val| (val, ctx_val.metadata)) + }) + .collect::>(); + self.assertion_analysis(&positive_ctx, &analysis_ctx, memo)?; + + let negative_ctx = ctx + .iter() + .filter_map(|ctx_val| { + ctx_val + .value + .get_negation() + .map(|vals| (vals, ctx_val.metadata)) + }) + .collect::>(); + self.negation_analysis(&negative_ctx, &mut analysis_ctx, memo)?; + + Ok(()) + } + + pub fn combine<'b>(g1: &'b Self, g2: &'b Self) -> Result { + let mut node_builder = KnowledgeGraphBuilder::new(); + let mut g1_old2new_id = utils::DenseMap::::new(); + let mut g2_old2new_id = utils::DenseMap::::new(); + let mut g1_old2new_domain_id = utils::DenseMap::::new(); + let mut g2_old2new_domain_id = utils::DenseMap::::new(); + + let add_domain = |node_builder: &mut KnowledgeGraphBuilder<'a>, + domain: DomainInfo<'a>| + -> Result { + node_builder.make_domain(domain.domain_identifier, domain.domain_description) + }; + + let add_node = |node_builder: &mut KnowledgeGraphBuilder<'a>, + node: &Node, + domains: Vec>| + -> Result { + match &node.node_type { + NodeType::Value(node_value) => { + node_builder.make_value_node(node_value.clone(), None, domains, None::<()>) + } + + NodeType::AllAggregator => { + Ok(node_builder.make_all_aggregator(&[], None, None::<()>, domains)?) + } + + NodeType::AnyAggregator => { + Ok(node_builder.make_any_aggregator(&[], None, None::<()>, Vec::new())?) + } + + NodeType::InAggregator(expected) => Ok(node_builder.make_in_aggregator( + expected.iter().cloned().collect(), + None, + None::<()>, + Vec::new(), + )?), + } + }; + + for (_old_domain_id, domain) in g1.domain.iter() { + let new_domain_id = add_domain(&mut node_builder, domain.clone())?; + g1_old2new_domain_id.push(new_domain_id); + } + + for (_old_domain_id, domain) in g2.domain.iter() { + let new_domain_id = add_domain(&mut node_builder, domain.clone())?; + g2_old2new_domain_id.push(new_domain_id); + } + + for (_old_node_id, node) in g1.nodes.iter() { + let mut domain_identifiers: Vec> = Vec::new(); + for domain_id in &node.domain_ids { + match g1.domain.get(*domain_id) { + Some(domain) => domain_identifiers.push(domain.domain_identifier.clone()), + None => return Err(GraphError::DomainNotFound), + } + } + let new_node_id = add_node(&mut node_builder, node, domain_identifiers.clone())?; + g1_old2new_id.push(new_node_id); + } + + for (_old_node_id, node) in g2.nodes.iter() { + let mut domain_identifiers: Vec> = Vec::new(); + for domain_id in &node.domain_ids { + match g2.domain.get(*domain_id) { + Some(domain) => domain_identifiers.push(domain.domain_identifier.clone()), + None => return Err(GraphError::DomainNotFound), + } + } + let new_node_id = add_node(&mut node_builder, node, domain_identifiers.clone())?; + g2_old2new_id.push(new_node_id); + } + + for edge in g1.edges.values() { + let new_pred_id = g1_old2new_id + .get(edge.pred) + .ok_or(GraphError::NodeNotFound)?; + let new_succ_id = g1_old2new_id + .get(edge.succ) + .ok_or(GraphError::NodeNotFound)?; + + node_builder.make_edge(*new_pred_id, *new_succ_id, edge.strength, edge.relation)?; + } + + for edge in g2.edges.values() { + let new_pred_id = g2_old2new_id + .get(edge.pred) + .ok_or(GraphError::NodeNotFound)?; + let new_succ_id = g2_old2new_id + .get(edge.succ) + .ok_or(GraphError::NodeNotFound)?; + + node_builder.make_edge(*new_pred_id, *new_succ_id, edge.strength, edge.relation)?; + } + + Ok(node_builder.build()) + } +} + +#[cfg(test)] +mod test { + #![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)] + + use euclid_macros::knowledge; + + use super::*; + use crate::{dirval, frontend::dir::enums}; + + #[test] + fn test_strong_positive_relation_success() { + let graph = knowledge! {crate + PaymentMethod(Card) ->> CaptureMethod(Automatic); + PaymentMethod(not Wallet) + & PaymentMethod(not PayLater) -> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = Card), + ]), + memo, + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_strong_positive_relation_failure() { + let graph = knowledge! {crate + PaymentMethod(Card) ->> CaptureMethod(Automatic); + PaymentMethod(not Wallet) -> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([dirval!(CaptureMethod = Automatic)]), + memo, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_strong_negative_relation_success() { + let graph = knowledge! {crate + PaymentMethod(Card) -> CaptureMethod(Automatic); + PaymentMethod(not Wallet) ->> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = Card), + ]), + memo, + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_strong_negative_relation_failure() { + let graph = knowledge! {crate + PaymentMethod(Card) -> CaptureMethod(Automatic); + PaymentMethod(not Wallet) ->> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = Wallet), + ]), + memo, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_normal_one_of_failure() { + let graph = knowledge! {crate + PaymentMethod(Card) -> CaptureMethod(Automatic); + PaymentMethod(Wallet) -> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = PayLater), + ]), + memo, + ); + assert!(matches!( + *Weak::upgrade(&result.unwrap_err().get_analysis_trace().unwrap()) + .expect("Expected Arc"), + AnalysisTrace::Value { + predecessors: Some(ValueTracePredecessor::OneOf(_)), + .. + } + )); + } + + #[test] + fn test_all_aggregator_success() { + let graph = knowledge! {crate + PaymentMethod(Card) & PaymentMethod(not Wallet) -> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(PaymentMethod = Card), + dirval!(CaptureMethod = Automatic), + ]), + memo, + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_all_aggregator_failure() { + let graph = knowledge! {crate + PaymentMethod(Card) & PaymentMethod(not Wallet) -> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = PayLater), + ]), + memo, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_all_aggregator_mandatory_failure() { + let graph = knowledge! {crate + PaymentMethod(Card) & PaymentMethod(not Wallet) ->> CaptureMethod(Automatic); + }; + let mut memo = Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = PayLater), + ]), + &mut memo, + ); + + assert!(matches!( + *Weak::upgrade(&result.unwrap_err().get_analysis_trace().unwrap()) + .expect("Expected Arc"), + AnalysisTrace::Value { + predecessors: Some(ValueTracePredecessor::Mandatory(_)), + .. + } + )); + } + + #[test] + fn test_in_aggregator_success() { + let graph = knowledge! {crate + PaymentMethod(in [Card, Wallet]) -> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Wallet), + ]), + memo, + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_in_aggregator_failure() { + let graph = knowledge! {crate + PaymentMethod(in [Card, Wallet]) -> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = PayLater), + ]), + memo, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_not_in_aggregator_success() { + let graph = knowledge! {crate + PaymentMethod(not in [Card, Wallet]) ->> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = PayLater), + dirval!(PaymentMethod = BankRedirect), + ]), + memo, + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_not_in_aggregator_failure() { + let graph = knowledge! {crate + PaymentMethod(not in [Card, Wallet]) ->> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = PayLater), + dirval!(PaymentMethod = BankRedirect), + dirval!(PaymentMethod = Card), + ]), + memo, + ); + + assert!(result.is_err()); + } + + #[test] + fn test_in_aggregator_failure_trace() { + let graph = knowledge! {crate + PaymentMethod(in [Card, Wallet]) ->> CaptureMethod(Automatic); + }; + let memo = &mut Memoization::new(); + let result = graph.key_value_analysis( + dirval!(CaptureMethod = Automatic), + &AnalysisContext::from_dir_values([ + dirval!(CaptureMethod = Automatic), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = PayLater), + ]), + memo, + ); + + if let AnalysisTrace::Value { + predecessors: Some(ValueTracePredecessor::Mandatory(agg_error)), + .. + } = Weak::upgrade(&result.unwrap_err().get_analysis_trace().unwrap()) + .expect("Expected arc") + .deref() + { + assert!(matches!( + *Weak::upgrade(agg_error.deref()).expect("Expected Arc"), + AnalysisTrace::InAggregation { + found: Some(dir::DirValue::PaymentMethod(enums::PaymentMethod::PayLater)), + .. + } + )); + } else { + panic!("Failed unwrapping OnlyInAggregation trace from AnalysisTrace"); + } + } + + #[test] + fn _test_memoization_in_kgraph() { + let mut builder = KnowledgeGraphBuilder::new(); + let _node_1 = builder.make_value_node( + NodeValue::Value(dir::DirValue::PaymentMethod(enums::PaymentMethod::Wallet)), + None, + Vec::new(), + None::<()>, + ); + let _node_2 = builder.make_value_node( + NodeValue::Value(dir::DirValue::BillingCountry(enums::BillingCountry::India)), + None, + Vec::new(), + None::<()>, + ); + let _node_3 = builder.make_value_node( + NodeValue::Value(dir::DirValue::BusinessCountry( + enums::BusinessCountry::UnitedStatesOfAmerica, + )), + None, + Vec::new(), + None::<()>, + ); + let mut memo = Memoization::new(); + let _edge_1 = builder + .make_edge( + _node_1.expect("node1 constructtion failed"), + _node_2.clone().expect("node2 construction failed"), + Strength::Strong, + Relation::Positive, + ) + .expect("Failed to make an edge"); + let _edge_2 = builder + .make_edge( + _node_2.expect("node2 construction failed"), + _node_3.clone().expect("node3 construction failed"), + Strength::Strong, + Relation::Positive, + ) + .expect("Failed to an edge"); + let graph = builder.build(); + let _result = graph.key_value_analysis( + dirval!(BusinessCountry = UnitedStatesOfAmerica), + &AnalysisContext::from_dir_values([ + dirval!(PaymentMethod = Wallet), + dirval!(BillingCountry = India), + dirval!(BusinessCountry = UnitedStatesOfAmerica), + ]), + &mut memo, + ); + let _ans = memo + .0 + .get(&( + _node_3.expect("node3 construction failed"), + Relation::Positive, + Strength::Strong, + )) + .expect("Memoization not workng"); + matches!(_ans, Ok(())); + } +} diff --git a/crates/euclid/src/dssa/state_machine.rs b/crates/euclid/src/dssa/state_machine.rs new file mode 100644 index 000000000000..4cd53911dfe4 --- /dev/null +++ b/crates/euclid/src/dssa/state_machine.rs @@ -0,0 +1,714 @@ +use super::types::EuclidAnalysable; +use crate::{dssa::types, frontend::dir, types::Metadata}; + +#[derive(Debug, Clone, serde::Serialize, thiserror::Error)] +#[serde(tag = "type", content = "info", rename_all = "snake_case")] +pub enum StateMachineError { + #[error("Index out of bounds: {0}")] + IndexOutOfBounds(&'static str), +} + +#[derive(Debug)] +struct ComparisonStateMachine<'a> { + values: &'a [dir::DirValue], + logic: &'a dir::DirComparisonLogic, + metadata: &'a Metadata, + count: usize, + ctx_idx: usize, +} + +impl<'a> ComparisonStateMachine<'a> { + #[inline] + fn is_finished(&self) -> bool { + self.count + 1 >= self.values.len() + || matches!(self.logic, dir::DirComparisonLogic::NegativeConjunction) + } + + #[inline] + fn advance(&mut self) { + if let dir::DirComparisonLogic::PositiveDisjunction = self.logic { + self.count = (self.count + 1) % self.values.len(); + } + } + + #[inline] + fn reset(&mut self) { + self.count = 0; + } + + #[inline] + fn put(&self, context: &mut types::ConjunctiveContext<'a>) -> Result<(), StateMachineError> { + if let dir::DirComparisonLogic::PositiveDisjunction = self.logic { + *context + .get_mut(self.ctx_idx) + .ok_or(StateMachineError::IndexOutOfBounds( + "in ComparisonStateMachine while indexing into context", + ))? = types::ContextValue::assertion( + self.values + .get(self.count) + .ok_or(StateMachineError::IndexOutOfBounds( + "in ComparisonStateMachine while indexing into values", + ))?, + self.metadata, + ); + } + Ok(()) + } + + #[inline] + fn push(&self, context: &mut types::ConjunctiveContext<'a>) -> Result<(), StateMachineError> { + match self.logic { + dir::DirComparisonLogic::PositiveDisjunction => { + context.push(types::ContextValue::assertion( + self.values + .get(self.count) + .ok_or(StateMachineError::IndexOutOfBounds( + "in ComparisonStateMachine while pushing", + ))?, + self.metadata, + )); + } + + dir::DirComparisonLogic::NegativeConjunction => { + context.push(types::ContextValue::negation(self.values, self.metadata)); + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct ConditionStateMachine<'a> { + state_machines: Vec>, + start_ctx_idx: usize, +} + +impl<'a> ConditionStateMachine<'a> { + fn new(condition: &'a [dir::DirComparison], start_idx: usize) -> Self { + let mut machines = Vec::>::with_capacity(condition.len()); + + let mut machine_idx = start_idx; + for cond in condition { + let machine = ComparisonStateMachine { + values: &cond.values, + logic: &cond.logic, + metadata: &cond.metadata, + count: 0, + ctx_idx: machine_idx, + }; + machines.push(machine); + machine_idx += 1; + } + + Self { + state_machines: machines, + start_ctx_idx: start_idx, + } + } + + fn init(&self, context: &mut types::ConjunctiveContext<'a>) -> Result<(), StateMachineError> { + for machine in &self.state_machines { + machine.push(context)?; + } + Ok(()) + } + + #[inline] + fn destroy(&self, context: &mut types::ConjunctiveContext<'a>) { + context.truncate(self.start_ctx_idx); + } + + #[inline] + fn is_finished(&self) -> bool { + !self + .state_machines + .iter() + .any(|machine| !machine.is_finished()) + } + + #[inline] + fn get_next_ctx_idx(&self) -> usize { + self.start_ctx_idx + self.state_machines.len() + } + + fn advance( + &mut self, + context: &mut types::ConjunctiveContext<'a>, + ) -> Result<(), StateMachineError> { + for machine in self.state_machines.iter_mut().rev() { + if machine.is_finished() { + machine.reset(); + machine.put(context)?; + } else { + machine.advance(); + machine.put(context)?; + break; + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct IfStmtStateMachine<'a> { + condition_machine: ConditionStateMachine<'a>, + nested: Vec<&'a dir::DirIfStatement>, + nested_idx: usize, +} + +impl<'a> IfStmtStateMachine<'a> { + fn new(stmt: &'a dir::DirIfStatement, ctx_start_idx: usize) -> Self { + let condition_machine = ConditionStateMachine::new(&stmt.condition, ctx_start_idx); + let nested: Vec<&'a dir::DirIfStatement> = match &stmt.nested { + None => Vec::new(), + Some(nested_stmts) => nested_stmts.iter().collect(), + }; + + Self { + condition_machine, + nested, + nested_idx: 0, + } + } + + fn init( + &self, + context: &mut types::ConjunctiveContext<'a>, + ) -> Result, StateMachineError> { + self.condition_machine.init(context)?; + Ok(self + .nested + .first() + .map(|nested| Self::new(nested, self.condition_machine.get_next_ctx_idx()))) + } + + #[inline] + fn is_finished(&self) -> bool { + self.nested_idx + 1 >= self.nested.len() + } + + #[inline] + fn is_condition_machine_finished(&self) -> bool { + self.condition_machine.is_finished() + } + + #[inline] + fn destroy(&self, context: &mut types::ConjunctiveContext<'a>) { + self.condition_machine.destroy(context); + } + + #[inline] + fn advance_condition_machine( + &mut self, + context: &mut types::ConjunctiveContext<'a>, + ) -> Result<(), StateMachineError> { + self.condition_machine.advance(context)?; + Ok(()) + } + + fn advance(&mut self) -> Result, StateMachineError> { + if self.nested.is_empty() { + Ok(None) + } else { + self.nested_idx = (self.nested_idx + 1) % self.nested.len(); + Ok(Some(Self::new( + self.nested + .get(self.nested_idx) + .ok_or(StateMachineError::IndexOutOfBounds( + "in IfStmtStateMachine while advancing", + ))?, + self.condition_machine.get_next_ctx_idx(), + ))) + } + } +} + +#[derive(Debug)] +struct RuleStateMachine<'a> { + connector_selection_data: &'a [(dir::DirValue, Metadata)], + connectors_added: bool, + if_stmt_machines: Vec>, + running_stack: Vec>, +} + +impl<'a> RuleStateMachine<'a> { + fn new( + rule: &'a dir::DirRule, + connector_selection_data: &'a [(dir::DirValue, Metadata)], + ) -> Self { + let mut if_stmt_machines: Vec> = + Vec::with_capacity(rule.statements.len()); + + for stmt in rule.statements.iter().rev() { + if_stmt_machines.push(IfStmtStateMachine::new( + stmt, + connector_selection_data.len(), + )); + } + + Self { + connector_selection_data, + connectors_added: false, + if_stmt_machines, + running_stack: Vec::new(), + } + } + + fn is_finished(&self) -> bool { + self.if_stmt_machines.is_empty() && self.running_stack.is_empty() + } + + fn init_next( + &mut self, + context: &mut types::ConjunctiveContext<'a>, + ) -> Result<(), StateMachineError> { + if self.if_stmt_machines.is_empty() || !self.running_stack.is_empty() { + return Ok(()); + } + + if !self.connectors_added { + for (dir_val, metadata) in self.connector_selection_data { + context.push(types::ContextValue::assertion(dir_val, metadata)); + } + self.connectors_added = true; + } + + context.truncate(self.connector_selection_data.len()); + + if let Some(mut next_running) = self.if_stmt_machines.pop() { + while let Some(nested_running) = next_running.init(context)? { + self.running_stack.push(next_running); + next_running = nested_running; + } + + self.running_stack.push(next_running); + } + + Ok(()) + } + + fn advance( + &mut self, + context: &mut types::ConjunctiveContext<'a>, + ) -> Result<(), StateMachineError> { + let mut condition_machines_finished = true; + + for stmt_machine in self.running_stack.iter_mut().rev() { + if !stmt_machine.is_condition_machine_finished() { + condition_machines_finished = false; + stmt_machine.advance_condition_machine(context)?; + break; + } else { + stmt_machine.advance_condition_machine(context)?; + } + } + + if !condition_machines_finished { + return Ok(()); + } + + let mut maybe_next_running: Option> = None; + + while let Some(last) = self.running_stack.last_mut() { + if !last.is_finished() { + maybe_next_running = last.advance()?; + break; + } else { + last.destroy(context); + self.running_stack.pop(); + } + } + + if let Some(mut next_running) = maybe_next_running { + while let Some(nested_running) = next_running.init(context)? { + self.running_stack.push(next_running); + next_running = nested_running; + } + + self.running_stack.push(next_running); + } else { + self.init_next(context)?; + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct RuleContextManager<'a> { + context: types::ConjunctiveContext<'a>, + machine: RuleStateMachine<'a>, + init: bool, +} + +impl<'a> RuleContextManager<'a> { + pub fn new( + rule: &'a dir::DirRule, + connector_selection_data: &'a [(dir::DirValue, Metadata)], + ) -> Self { + Self { + context: Vec::new(), + machine: RuleStateMachine::new(rule, connector_selection_data), + init: false, + } + } + + pub fn advance(&mut self) -> Result>, StateMachineError> { + if !self.init { + self.init = true; + self.machine.init_next(&mut self.context)?; + Ok(Some(&self.context)) + } else if self.machine.is_finished() { + Ok(None) + } else { + self.machine.advance(&mut self.context)?; + + if self.machine.is_finished() { + Ok(None) + } else { + Ok(Some(&self.context)) + } + } + } + + pub fn advance_mut( + &mut self, + ) -> Result>, StateMachineError> { + if !self.init { + self.init = true; + self.machine.init_next(&mut self.context)?; + Ok(Some(&mut self.context)) + } else if self.machine.is_finished() { + Ok(None) + } else { + self.machine.advance(&mut self.context)?; + + if self.machine.is_finished() { + Ok(None) + } else { + Ok(Some(&mut self.context)) + } + } + } +} + +#[derive(Debug)] +pub struct ProgramStateMachine<'a> { + rule_machines: Vec>, + current_rule_machine: Option>, + is_init: bool, +} + +impl<'a> ProgramStateMachine<'a> { + pub fn new( + program: &'a dir::DirProgram, + connector_selection_data: &'a [Vec<(dir::DirValue, Metadata)>], + ) -> Self { + let mut rule_machines: Vec> = program + .rules + .iter() + .zip(connector_selection_data.iter()) + .rev() + .map(|(rule, connector_selection_data)| { + RuleStateMachine::new(rule, connector_selection_data) + }) + .collect(); + + Self { + current_rule_machine: rule_machines.pop(), + rule_machines, + is_init: false, + } + } + + pub fn is_finished(&self) -> bool { + self.current_rule_machine + .as_ref() + .map_or(true, |rsm| rsm.is_finished()) + && self.rule_machines.is_empty() + } + + pub fn init( + &mut self, + context: &mut types::ConjunctiveContext<'a>, + ) -> Result<(), StateMachineError> { + if !self.is_init { + if let Some(rsm) = self.current_rule_machine.as_mut() { + rsm.init_next(context)?; + } + self.is_init = true; + } + + Ok(()) + } + + pub fn advance( + &mut self, + context: &mut types::ConjunctiveContext<'a>, + ) -> Result<(), StateMachineError> { + if self + .current_rule_machine + .as_ref() + .map_or(true, |rsm| rsm.is_finished()) + { + self.current_rule_machine = self.rule_machines.pop(); + context.clear(); + if let Some(rsm) = self.current_rule_machine.as_mut() { + rsm.init_next(context)?; + } + } else if let Some(rsm) = self.current_rule_machine.as_mut() { + rsm.advance(context)?; + } + + Ok(()) + } +} + +pub struct AnalysisContextManager<'a> { + context: types::ConjunctiveContext<'a>, + machine: ProgramStateMachine<'a>, + init: bool, +} + +impl<'a> AnalysisContextManager<'a> { + pub fn new( + program: &'a dir::DirProgram, + connector_selection_data: &'a [Vec<(dir::DirValue, Metadata)>], + ) -> Self { + let machine = ProgramStateMachine::new(program, connector_selection_data); + let context: types::ConjunctiveContext<'a> = Vec::new(); + + Self { + context, + machine, + init: false, + } + } + + pub fn advance(&mut self) -> Result>, StateMachineError> { + if !self.init { + self.init = true; + self.machine.init(&mut self.context)?; + Ok(Some(&self.context)) + } else if self.machine.is_finished() { + Ok(None) + } else { + self.machine.advance(&mut self.context)?; + + if self.machine.is_finished() { + Ok(None) + } else { + Ok(Some(&self.context)) + } + } + } +} + +pub fn make_connector_selection_data( + program: &dir::DirProgram, +) -> Vec> { + program + .rules + .iter() + .map(|rule| { + rule.connector_selection + .get_dir_value_for_analysis(rule.name.clone()) + }) + .collect() +} + +#[cfg(all(test, feature = "ast_parser"))] +mod tests { + #![allow(clippy::expect_used)] + + use super::*; + use crate::{dirval, frontend::ast, types::DummyOutput}; + + #[test] + fn test_correct_contexts() { + let program_str = r#" + default: ["stripe", "adyen"] + + stripe_first: ["stripe", "adyen"] + { + payment_method = wallet { + payment_method = (card, bank_redirect) { + currency = USD + currency = GBP + } + + payment_method = pay_later { + capture_method = automatic + capture_method = manual + } + } + + payment_method = card { + payment_method = (card, bank_redirect) & capture_method = (automatic, manual) { + currency = (USD, GBP) + } + } + } + "#; + let (_, program) = ast::parser::program::(program_str).expect("Program"); + let lowered = ast::lowering::lower_program(program).expect("Lowering"); + + let selection_data = make_connector_selection_data(&lowered); + let mut state_machine = ProgramStateMachine::new(&lowered, &selection_data); + let mut ctx: types::ConjunctiveContext<'_> = Vec::new(); + state_machine.init(&mut ctx).expect("State machine init"); + + let expected_contexts: Vec> = vec![ + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = Card), + dirval!(PaymentCurrency = USD), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = BankRedirect), + dirval!(PaymentCurrency = USD), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = Card), + dirval!(PaymentCurrency = GBP), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = BankRedirect), + dirval!(PaymentCurrency = GBP), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = PayLater), + dirval!(CaptureMethod = Automatic), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Wallet), + dirval!(PaymentMethod = PayLater), + dirval!(CaptureMethod = Manual), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Card), + dirval!(CaptureMethod = Automatic), + dirval!(PaymentCurrency = USD), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Card), + dirval!(CaptureMethod = Automatic), + dirval!(PaymentCurrency = GBP), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Card), + dirval!(CaptureMethod = Manual), + dirval!(PaymentCurrency = USD), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = Card), + dirval!(CaptureMethod = Manual), + dirval!(PaymentCurrency = GBP), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = BankRedirect), + dirval!(CaptureMethod = Automatic), + dirval!(PaymentCurrency = USD), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = BankRedirect), + dirval!(CaptureMethod = Automatic), + dirval!(PaymentCurrency = GBP), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = BankRedirect), + dirval!(CaptureMethod = Manual), + dirval!(PaymentCurrency = USD), + ], + vec![ + dirval!("MetadataKey" = "stripe"), + dirval!("MetadataKey" = "adyen"), + dirval!(PaymentMethod = Card), + dirval!(PaymentMethod = BankRedirect), + dirval!(CaptureMethod = Manual), + dirval!(PaymentCurrency = GBP), + ], + ]; + + let mut expected_idx = 0usize; + while !state_machine.is_finished() { + let values = ctx + .iter() + .flat_map(|c| match c.value { + types::CtxValueKind::Assertion(val) => vec![val], + types::CtxValueKind::Negation(vals) => vals.iter().collect(), + }) + .collect::>(); + assert_eq!( + values, + expected_contexts[expected_idx] + .iter() + .collect::>() + ); + expected_idx += 1; + state_machine + .advance(&mut ctx) + .expect("State Machine advance"); + } + + assert_eq!(expected_idx, 14); + + let mut ctx_manager = AnalysisContextManager::new(&lowered, &selection_data); + expected_idx = 0; + while let Some(ctx) = ctx_manager.advance().expect("Context Manager Context") { + let values = ctx + .iter() + .flat_map(|c| match c.value { + types::CtxValueKind::Assertion(val) => vec![val], + types::CtxValueKind::Negation(vals) => vals.iter().collect(), + }) + .collect::>(); + assert_eq!( + values, + expected_contexts[expected_idx] + .iter() + .collect::>() + ); + expected_idx += 1; + } + + assert_eq!(expected_idx, 14); + } +} diff --git a/crates/euclid/src/dssa/truth.rs b/crates/euclid/src/dssa/truth.rs new file mode 100644 index 000000000000..17e6e728e68f --- /dev/null +++ b/crates/euclid/src/dssa/truth.rs @@ -0,0 +1,29 @@ +use euclid_macros::knowledge; +use once_cell::sync::Lazy; + +use crate::dssa::graph; + +pub static ANALYSIS_GRAPH: Lazy> = Lazy::new(|| { + knowledge! {crate + // Payment Method should be `Card` for a CardType to be present + PaymentMethod(Card) ->> CardType(any); + + // Payment Method should be `PayLater` for a PayLaterType to be present + PaymentMethod(PayLater) ->> PayLaterType(any); + + // Payment Method should be `Wallet` for a WalletType to be present + PaymentMethod(Wallet) ->> WalletType(any); + + // Payment Method should be `BankRedirect` for a BankRedirectType to + // be present + PaymentMethod(BankRedirect) ->> BankRedirectType(any); + + // Payment Method should be `BankTransfer` for a BankTransferType to + // be present + PaymentMethod(BankTransfer) ->> BankTransferType(any); + + // Payment Method should be `GiftCard` for a GiftCardType to + // be present + PaymentMethod(GiftCard) ->> GiftCardType(any); + } +}); diff --git a/crates/euclid/src/dssa/types.rs b/crates/euclid/src/dssa/types.rs new file mode 100644 index 000000000000..4070e0825ef7 --- /dev/null +++ b/crates/euclid/src/dssa/types.rs @@ -0,0 +1,158 @@ +use std::fmt; + +use serde::Serialize; + +use crate::{ + dssa::{self, graph}, + frontend::{ast, dir}, + types::{DataType, EuclidValue, Metadata}, +}; + +pub trait EuclidAnalysable: Sized { + fn get_dir_value_for_analysis(&self, rule_name: String) -> Vec<(dir::DirValue, Metadata)>; +} + +#[derive(Debug, Clone)] +pub enum CtxValueKind<'a> { + Assertion(&'a dir::DirValue), + Negation(&'a [dir::DirValue]), +} + +impl<'a> CtxValueKind<'a> { + pub fn get_assertion(&self) -> Option<&dir::DirValue> { + if let Self::Assertion(val) = self { + Some(val) + } else { + None + } + } + + pub fn get_negation(&self) -> Option<&[dir::DirValue]> { + if let Self::Negation(vals) = self { + Some(vals) + } else { + None + } + } + + pub fn get_key(&self) -> Option { + match self { + Self::Assertion(val) => Some(val.get_key()), + Self::Negation(vals) => vals.first().map(|v| (*v).get_key()), + } + } +} + +#[derive(Debug, Clone)] +pub struct ContextValue<'a> { + pub value: CtxValueKind<'a>, + pub metadata: &'a Metadata, +} + +impl<'a> ContextValue<'a> { + #[inline] + pub fn assertion(value: &'a dir::DirValue, metadata: &'a Metadata) -> Self { + Self { + value: CtxValueKind::Assertion(value), + metadata, + } + } + + #[inline] + pub fn negation(values: &'a [dir::DirValue], metadata: &'a Metadata) -> Self { + Self { + value: CtxValueKind::Negation(values), + metadata, + } + } +} + +pub type ConjunctiveContext<'a> = Vec>; + +#[derive(Clone, Serialize)] +pub enum AnalyzeResult { + AllOk, +} + +#[derive(Debug, Clone, Serialize, thiserror::Error)] +pub struct AnalysisError { + #[serde(flatten)] + pub error_type: AnalysisErrorType, + pub metadata: Metadata, +} +impl fmt::Display for AnalysisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error_type.fmt(f) + } +} +#[derive(Debug, Clone, Serialize)] +pub struct ValueData { + pub value: dir::DirValue, + pub metadata: Metadata, +} + +#[derive(Debug, Clone, Serialize, thiserror::Error)] +#[serde(tag = "type", content = "info", rename_all = "snake_case")] +pub enum AnalysisErrorType { + #[error("Invalid program key given: '{0}'")] + InvalidKey(String), + #[error("Invalid variant '{got}' received for key '{key}'")] + InvalidVariant { + key: String, + expected: Vec, + got: String, + }, + #[error( + "Invalid data type for value '{}' (expected {expected}, got {got})", + key + )] + InvalidType { + key: String, + expected: DataType, + got: DataType, + }, + #[error("Invalid comparison '{operator:?}' for value type {value_type}")] + InvalidComparison { + operator: ast::ComparisonType, + value_type: DataType, + }, + #[error("Invalid value received for length as '{value}: {:?}'", message)] + InvalidValue { + key: dir::DirKeyKind, + value: String, + message: Option, + }, + #[error("Conflicting assertions received for key '{}'", .key.kind)] + ConflictingAssertions { + key: dir::DirKey, + values: Vec, + }, + + #[error("Key '{}' exhaustively negated", .key.kind)] + ExhaustiveNegation { + key: dir::DirKey, + metadata: Vec, + }, + #[error("The condition '{value}' was asserted and negated in the same condition")] + NegatedAssertion { + value: dir::DirValue, + assertion_metadata: Metadata, + negation_metadata: Metadata, + }, + #[error("Graph analysis error: {0:#?}")] + GraphAnalysis(graph::AnalysisError, graph::Memoization), + #[error("State machine error")] + StateMachine(dssa::state_machine::StateMachineError), + #[error("Unsupported program key '{0}'")] + UnsupportedProgramKey(dir::DirKeyKind), + #[error("Ran into an unimplemented feature")] + NotImplemented, + #[error("The payment method type is not supported under the payment method")] + NotSupported, +} + +#[derive(Debug, Clone)] +pub enum ValueType { + EnumVariants(Vec), + Number, +} diff --git a/crates/euclid/src/dssa/utils.rs b/crates/euclid/src/dssa/utils.rs new file mode 100644 index 000000000000..df4ff82cbdb7 --- /dev/null +++ b/crates/euclid/src/dssa/utils.rs @@ -0,0 +1 @@ +pub struct Unpacker; diff --git a/crates/euclid/src/enums.rs b/crates/euclid/src/enums.rs new file mode 100644 index 000000000000..4188860ab90f --- /dev/null +++ b/crates/euclid/src/enums.rs @@ -0,0 +1,191 @@ +pub use common_enums::{ + AuthenticationType, CaptureMethod, CardNetwork, Country, Currency, + FutureUsage as SetupFutureUsage, PaymentMethod, PaymentMethodType, +}; +use serde::{Deserialize, Serialize}; +use strum::VariantNames; + +pub trait CollectVariants { + fn variants>() -> T; +} +macro_rules! collect_variants { + ($the_enum:ident) => { + impl $crate::enums::CollectVariants for $the_enum { + fn variants() -> T + where + T: FromIterator, + { + Self::VARIANTS.iter().map(|s| String::from(*s)).collect() + } + } + }; +} + +pub(crate) use collect_variants; + +collect_variants!(PaymentMethod); +collect_variants!(PaymentType); +collect_variants!(MandateType); +collect_variants!(MandateAcceptanceType); +collect_variants!(PaymentMethodType); +collect_variants!(CardNetwork); +collect_variants!(AuthenticationType); +collect_variants!(CaptureMethod); +collect_variants!(Currency); +collect_variants!(Country); +collect_variants!(Connector); +collect_variants!(SetupFutureUsage); + +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + Serialize, + Deserialize, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + frunk::LabelledGeneric, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum Connector { + #[cfg(feature = "dummy_connector")] + #[serde(rename = "phonypay")] + #[strum(serialize = "phonypay")] + DummyConnector1, + #[cfg(feature = "dummy_connector")] + #[serde(rename = "fauxpay")] + #[strum(serialize = "fauxpay")] + DummyConnector2, + #[cfg(feature = "dummy_connector")] + #[serde(rename = "pretendpay")] + #[strum(serialize = "pretendpay")] + DummyConnector3, + #[cfg(feature = "dummy_connector")] + #[serde(rename = "stripe_test")] + #[strum(serialize = "stripe_test")] + DummyConnector4, + #[cfg(feature = "dummy_connector")] + #[serde(rename = "adyen_test")] + #[strum(serialize = "adyen_test")] + DummyConnector5, + #[cfg(feature = "dummy_connector")] + #[serde(rename = "checkout_test")] + #[strum(serialize = "checkout_test")] + DummyConnector6, + #[cfg(feature = "dummy_connector")] + #[serde(rename = "paypal_test")] + #[strum(serialize = "paypal_test")] + DummyConnector7, + Aci, + Adyen, + Airwallex, + Authorizedotnet, + Bitpay, + Bambora, + Bluesnap, + Boku, + Braintree, + Cashtocode, + Checkout, + Coinbase, + Cryptopay, + Cybersource, + Dlocal, + Fiserv, + Forte, + Globalpay, + Globepay, + Gocardless, + Helcim, + Iatapay, + Klarna, + Mollie, + Multisafepay, + Nexinets, + Nmi, + Noon, + Nuvei, + Opennode, + Payme, + Paypal, + Payu, + Powertranz, + Rapyd, + Shift4, + Square, + Stax, + Stripe, + Trustpay, + Tsys, + Volt, + Wise, + Worldline, + Worldpay, + Zen, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum MandateAcceptanceType { + Online, + Offline, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum PaymentType { + SetupMandate, + NonMandate, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum MandateType { + SingleUse, + MultiUse, +} diff --git a/crates/euclid/src/frontend.rs b/crates/euclid/src/frontend.rs new file mode 100644 index 000000000000..17fc8f3502e2 --- /dev/null +++ b/crates/euclid/src/frontend.rs @@ -0,0 +1,3 @@ +pub mod ast; +pub mod dir; +pub mod vir; diff --git a/crates/euclid/src/frontend/ast.rs b/crates/euclid/src/frontend/ast.rs new file mode 100644 index 000000000000..3adb06ab1873 --- /dev/null +++ b/crates/euclid/src/frontend/ast.rs @@ -0,0 +1,156 @@ +pub mod lowering; +#[cfg(feature = "ast_parser")] +pub mod parser; + +use serde::{Deserialize, Serialize}; + +use crate::{ + enums::Connector, + types::{DataType, Metadata}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct ConnectorChoice { + pub connector: Connector, + #[cfg(not(feature = "connector_choice_mca_id"))] + pub sub_label: Option, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct MetadataValue { + pub key: String, + pub value: String, +} + +/// Represents a value in the DSL +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", content = "value", rename_all = "snake_case")] +pub enum ValueType { + /// Represents a number literal + Number(i64), + /// Represents an enum variant + EnumVariant(String), + /// Represents a Metadata variant + MetadataVariant(MetadataValue), + /// Represents a arbitrary String value + StrValue(String), + /// Represents an array of numbers. This is basically used for + /// "one of the given numbers" operations + /// eg: payment.method.amount = (1, 2, 3) + NumberArray(Vec), + /// Similar to NumberArray but for enum variants + /// eg: payment.method.cardtype = (debit, credit) + EnumVariantArray(Vec), + /// Like a number array but can include comparisons. Useful for + /// conditions like "500 < amount < 1000" + /// eg: payment.amount = (> 500, < 1000) + NumberComparisonArray(Vec), +} + +impl ValueType { + pub fn get_type(&self) -> DataType { + match self { + Self::Number(_) => DataType::Number, + Self::StrValue(_) => DataType::StrValue, + Self::MetadataVariant(_) => DataType::MetadataValue, + Self::EnumVariant(_) => DataType::EnumVariant, + Self::NumberComparisonArray(_) => DataType::Number, + Self::NumberArray(_) => DataType::Number, + Self::EnumVariantArray(_) => DataType::EnumVariant, + } + } +} + +/// Represents a number comparison for "NumberComparisonArrayValue" +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NumberComparison { + pub comparison_type: ComparisonType, + pub number: i64, +} + +/// Conditional comparison type +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ComparisonType { + Equal, + NotEqual, + LessThan, + LessThanEqual, + GreaterThan, + GreaterThanEqual, +} + +/// Represents a single comparison condition. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Comparison { + /// The left hand side which will always be a domain input identifier like "payment.method.cardtype" + pub lhs: String, + /// The comparison operator + pub comparison: ComparisonType, + /// The value to compare against + pub value: ValueType, + /// Additional metadata that the Static Analyzer and Backend does not touch. + /// This can be used to store useful information for the frontend and is required for communication + /// between the static analyzer and the frontend. + pub metadata: Metadata, +} + +/// Represents all the conditions of an IF statement +/// eg: +/// +/// ```text +/// payment.method = card & payment.method.cardtype = debit & payment.method.network = diners +/// ``` +pub type IfCondition = Vec; + +/// Represents an IF statement with conditions and optional nested IF statements +/// +/// ```text +/// payment.method = card { +/// payment.method.cardtype = (credit, debit) { +/// payment.method.network = (amex, rupay, diners) +/// } +/// } +/// ``` +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct IfStatement { + pub condition: IfCondition, + pub nested: Option>, +} + +/// Represents a rule +/// +/// ```text +/// rule_name: [stripe, adyen, checkout] +/// { +/// payment.method = card { +/// payment.method.cardtype = (credit, debit) { +/// payment.method.network = (amex, rupay, diners) +/// } +/// +/// payment.method.cardtype = credit +/// } +/// } +/// ``` + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Rule { + pub name: String, + #[serde(alias = "routingOutput")] + pub connector_selection: O, + pub statements: Vec, +} + +/// The program, having a default connector selection and +/// a bunch of rules. Also can hold arbitrary metadata. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Program { + pub default_selection: O, + pub rules: Vec>, + pub metadata: Metadata, +} diff --git a/crates/euclid/src/frontend/ast/lowering.rs b/crates/euclid/src/frontend/ast/lowering.rs new file mode 100644 index 000000000000..ffce88a35db6 --- /dev/null +++ b/crates/euclid/src/frontend/ast/lowering.rs @@ -0,0 +1,377 @@ +//! Analysis for the Lowering logic in ast +//! +//!Certain functions that can be used to perform the complete lowering of ast to dir. +//!This includes lowering of enums, numbers, strings as well as Comparison logics. + +use std::str::FromStr; + +use crate::{ + dssa::types::{AnalysisError, AnalysisErrorType}, + enums::CollectVariants, + frontend::{ + ast, + dir::{self, enums as dir_enums, EuclidDirFilter}, + }, + types::{self, DataType}, +}; + +/// lowers the provided key (enum variant) & value to the respective DirValue +/// +/// For example +/// ```notrust +/// CardType = Visa +/// ```notrust +/// +/// This serves for the purpose were we have the DirKey as an explicit Enum type and value as one +/// of the member of the same Enum. +/// So particularly it lowers a predefined Enum from DirKey to an Enum of DirValue. + +macro_rules! lower_enum { + ($key:ident, $value:ident) => { + match $value { + ast::ValueType::EnumVariant(ev) => Ok(vec![dir::DirValue::$key( + dir_enums::$key::from_str(&ev).map_err(|_| AnalysisErrorType::InvalidVariant { + key: dir::DirKeyKind::$key.to_string(), + got: ev, + expected: dir_enums::$key::variants(), + })?, + )]), + + ast::ValueType::EnumVariantArray(eva) => eva + .into_iter() + .map(|ev| { + Ok(dir::DirValue::$key( + dir_enums::$key::from_str(&ev).map_err(|_| { + AnalysisErrorType::InvalidVariant { + key: dir::DirKeyKind::$key.to_string(), + got: ev, + expected: dir_enums::$key::variants(), + } + })?, + )) + }) + .collect(), + + _ => Err(AnalysisErrorType::InvalidType { + key: dir::DirKeyKind::$key.to_string(), + expected: DataType::EnumVariant, + got: $value.get_type(), + }), + } + }; +} + +/// lowers the provided key for a numerical value +/// +/// For example +/// ```notrust +/// payment_amount = 17052001 +/// ```notrust +/// This is for the cases in which there are numerical values involved and they are lowered +/// accordingly on basis of the supplied key, currently payment_amount is the only key having this +/// use case + +macro_rules! lower_number { + ($key:ident, $value:ident, $comp:ident) => { + match $value { + ast::ValueType::Number(num) => Ok(vec![dir::DirValue::$key(types::NumValue { + number: num, + refinement: $comp.into(), + })]), + + ast::ValueType::NumberArray(na) => na + .into_iter() + .map(|num| { + Ok(dir::DirValue::$key(types::NumValue { + number: num, + refinement: $comp.clone().into(), + })) + }) + .collect(), + + ast::ValueType::NumberComparisonArray(nca) => nca + .into_iter() + .map(|nc| { + Ok(dir::DirValue::$key(types::NumValue { + number: nc.number, + refinement: nc.comparison_type.into(), + })) + }) + .collect(), + + _ => Err(AnalysisErrorType::InvalidType { + key: dir::DirKeyKind::$key.to_string(), + expected: DataType::Number, + got: $value.get_type(), + }), + } + }; +} + +/// lowers the provided key & value to the respective DirValue +/// +/// For example +/// ```notrust +/// card_bin = "123456" +/// ```notrust +/// +/// This serves for the purpose were we have the DirKey as Card_bin and value as an arbitrary string +/// So particularly it lowers an arbitrary value to a predefined key. + +macro_rules! lower_str { + ($key:ident, $value:ident $(, $validation_closure:expr)?) => { + match $value { + ast::ValueType::StrValue(st) => { + $($validation_closure(&st)?;)? + Ok(vec![dir::DirValue::$key(types::StrValue { value: st })]) + } + _ => Err(AnalysisErrorType::InvalidType { + key: dir::DirKeyKind::$key.to_string(), + expected: DataType::StrValue, + got: $value.get_type(), + }), + } + }; +} + +macro_rules! lower_metadata { + ($key:ident, $value:ident) => { + match $value { + ast::ValueType::MetadataVariant(md) => { + Ok(vec![dir::DirValue::$key(types::MetadataValue { + key: md.key, + value: md.value, + })]) + } + _ => Err(AnalysisErrorType::InvalidType { + key: dir::DirKeyKind::$key.to_string(), + expected: DataType::MetadataValue, + got: $value.get_type(), + }), + } + }; +} +/// lowers the comparison operators for different subtle value types present +/// by throwing required errors for comparisons that can't be performed for a certain value type +/// for example +/// can't have greater/less than operations on enum types + +fn lower_comparison_inner( + comp: ast::Comparison, +) -> Result, AnalysisErrorType> { + let key_enum = dir::DirKeyKind::from_str(comp.lhs.as_str()) + .map_err(|_| AnalysisErrorType::InvalidKey(comp.lhs.clone()))?; + + if !O::is_key_allowed(&key_enum) { + return Err(AnalysisErrorType::InvalidKey(key_enum.to_string())); + } + + match (&comp.comparison, &comp.value) { + ( + ast::ComparisonType::LessThan + | ast::ComparisonType::GreaterThan + | ast::ComparisonType::GreaterThanEqual + | ast::ComparisonType::LessThanEqual, + ast::ValueType::EnumVariant(_), + ) => { + Err(AnalysisErrorType::InvalidComparison { + operator: comp.comparison.clone(), + value_type: DataType::EnumVariant, + })?; + } + + ( + ast::ComparisonType::LessThan + | ast::ComparisonType::GreaterThan + | ast::ComparisonType::GreaterThanEqual + | ast::ComparisonType::LessThanEqual, + ast::ValueType::NumberArray(_), + ) => { + Err(AnalysisErrorType::InvalidComparison { + operator: comp.comparison.clone(), + value_type: DataType::Number, + })?; + } + + ( + ast::ComparisonType::LessThan + | ast::ComparisonType::GreaterThan + | ast::ComparisonType::GreaterThanEqual + | ast::ComparisonType::LessThanEqual, + ast::ValueType::EnumVariantArray(_), + ) => { + Err(AnalysisErrorType::InvalidComparison { + operator: comp.comparison.clone(), + value_type: DataType::EnumVariant, + })?; + } + + ( + ast::ComparisonType::LessThan + | ast::ComparisonType::GreaterThan + | ast::ComparisonType::GreaterThanEqual + | ast::ComparisonType::LessThanEqual, + ast::ValueType::NumberComparisonArray(_), + ) => { + Err(AnalysisErrorType::InvalidComparison { + operator: comp.comparison.clone(), + value_type: DataType::Number, + })?; + } + + _ => {} + } + + let value = comp.value; + let comparison = comp.comparison; + + match key_enum { + dir::DirKeyKind::PaymentMethod => lower_enum!(PaymentMethod, value), + + dir::DirKeyKind::CardType => lower_enum!(CardType, value), + + dir::DirKeyKind::CardNetwork => lower_enum!(CardNetwork, value), + + dir::DirKeyKind::PayLaterType => lower_enum!(PayLaterType, value), + + dir::DirKeyKind::WalletType => lower_enum!(WalletType, value), + + dir::DirKeyKind::BankDebitType => lower_enum!(BankDebitType, value), + + dir::DirKeyKind::BankRedirectType => lower_enum!(BankRedirectType, value), + + dir::DirKeyKind::CryptoType => lower_enum!(CryptoType, value), + + dir::DirKeyKind::PaymentType => lower_enum!(PaymentType, value), + + dir::DirKeyKind::MandateType => lower_enum!(MandateType, value), + + dir::DirKeyKind::MandateAcceptanceType => lower_enum!(MandateAcceptanceType, value), + + dir::DirKeyKind::RewardType => lower_enum!(RewardType, value), + + dir::DirKeyKind::PaymentCurrency => lower_enum!(PaymentCurrency, value), + + dir::DirKeyKind::AuthenticationType => lower_enum!(AuthenticationType, value), + + dir::DirKeyKind::CaptureMethod => lower_enum!(CaptureMethod, value), + + dir::DirKeyKind::BusinessCountry => lower_enum!(BusinessCountry, value), + + dir::DirKeyKind::BillingCountry => lower_enum!(BillingCountry, value), + + dir::DirKeyKind::SetupFutureUsage => lower_enum!(SetupFutureUsage, value), + + dir::DirKeyKind::UpiType => lower_enum!(UpiType, value), + + dir::DirKeyKind::VoucherType => lower_enum!(VoucherType, value), + + dir::DirKeyKind::GiftCardType => lower_enum!(GiftCardType, value), + + dir::DirKeyKind::BankTransferType => lower_enum!(BankTransferType, value), + + dir::DirKeyKind::CardRedirectType => lower_enum!(CardRedirectType, value), + + dir::DirKeyKind::CardBin => { + let validation_closure = |st: &String| -> Result<(), AnalysisErrorType> { + if st.len() == 6 && st.chars().all(|x| x.is_ascii_digit()) { + Ok(()) + } else { + Err(AnalysisErrorType::InvalidValue { + key: dir::DirKeyKind::CardBin, + value: st.clone(), + message: Some("Expected 6 digits".to_string()), + }) + } + }; + lower_str!(CardBin, value, validation_closure) + } + + dir::DirKeyKind::BusinessLabel => lower_str!(BusinessLabel, value), + + dir::DirKeyKind::MetaData => lower_metadata!(MetaData, value), + + dir::DirKeyKind::PaymentAmount => lower_number!(PaymentAmount, value, comparison), + + dir::DirKeyKind::Connector => Err(AnalysisErrorType::InvalidKey( + dir::DirKeyKind::Connector.to_string(), + )), + } +} + +/// returns all the comparison values by matching them appropriately to ComparisonTypes and in turn +/// calls the lower_comparison_inner function +fn lower_comparison( + comp: ast::Comparison, +) -> Result { + let metadata = comp.metadata.clone(); + let logic = match &comp.comparison { + ast::ComparisonType::Equal => dir::DirComparisonLogic::PositiveDisjunction, + ast::ComparisonType::NotEqual => dir::DirComparisonLogic::NegativeConjunction, + ast::ComparisonType::LessThan => dir::DirComparisonLogic::PositiveDisjunction, + ast::ComparisonType::LessThanEqual => dir::DirComparisonLogic::PositiveDisjunction, + ast::ComparisonType::GreaterThanEqual => dir::DirComparisonLogic::PositiveDisjunction, + ast::ComparisonType::GreaterThan => dir::DirComparisonLogic::PositiveDisjunction, + }; + let values = lower_comparison_inner::(comp).map_err(|etype| AnalysisError { + error_type: etype, + metadata: metadata.clone(), + })?; + + Ok(dir::DirComparison { + values, + logic, + metadata, + }) +} + +/// lowers the if statement accordingly with a condition and following nested if statements (if +/// present) +fn lower_if_statement( + stmt: ast::IfStatement, +) -> Result { + Ok(dir::DirIfStatement { + condition: stmt + .condition + .into_iter() + .map(lower_comparison::) + .collect::>()?, + nested: stmt + .nested + .map(|n| n.into_iter().map(lower_if_statement::).collect()) + .transpose()?, + }) +} + +/// lowers the rules supplied accordingly to DirRule struct by specifying the rule_name, +/// connector_selection and statements that are a bunch of if statements +pub fn lower_rule( + rule: ast::Rule, +) -> Result, AnalysisError> { + Ok(dir::DirRule { + name: rule.name, + connector_selection: rule.connector_selection, + statements: rule + .statements + .into_iter() + .map(lower_if_statement::) + .collect::>()?, + }) +} + +/// uses the above rules and lowers the whole ast Program into DirProgram by specifying +/// default_selection that is ast ConnectorSelection, a vector of DirRules and clones the metadata +/// whatever comes in the ast_program +pub fn lower_program( + program: ast::Program, +) -> Result, AnalysisError> { + Ok(dir::DirProgram { + default_selection: program.default_selection, + rules: program + .rules + .into_iter() + .map(lower_rule) + .collect::>()?, + metadata: program.metadata, + }) +} diff --git a/crates/euclid/src/frontend/ast/parser.rs b/crates/euclid/src/frontend/ast/parser.rs new file mode 100644 index 000000000000..8b2f717a8688 --- /dev/null +++ b/crates/euclid/src/frontend/ast/parser.rs @@ -0,0 +1,441 @@ +use nom::{ + branch, bytes::complete, character::complete as pchar, combinator, error, multi, sequence, +}; + +use crate::{frontend::ast, types::DummyOutput}; +pub type ParseResult = nom::IResult>; + +pub enum EuclidError { + InvalidPercentage(String), + InvalidConnector(String), + InvalidOperator(String), + InvalidNumber(String), +} + +pub trait EuclidParsable: Sized { + fn parse_output(input: &str) -> ParseResult<&str, Self>; +} + +impl EuclidParsable for DummyOutput { + fn parse_output(input: &str) -> ParseResult<&str, Self> { + let string_w = sequence::delimited( + skip_ws(complete::tag("\"")), + complete::take_while(|c| c != '"'), + skip_ws(complete::tag("\"")), + ); + let full_sequence = multi::many0(sequence::preceded( + skip_ws(complete::tag(",")), + sequence::delimited( + skip_ws(complete::tag("\"")), + complete::take_while(|c| c != '"'), + skip_ws(complete::tag("\"")), + ), + )); + let sequence = sequence::pair(string_w, full_sequence); + error::context( + "dummy_strings", + combinator::map( + sequence::delimited( + skip_ws(complete::tag("[")), + sequence, + skip_ws(complete::tag("]")), + ), + |out: (&str, Vec<&str>)| { + let mut first = out.1; + first.insert(0, out.0); + let v = first.iter().map(|s| s.to_string()).collect(); + Self { outputs: v } + }, + ), + )(input) + } +} +pub fn skip_ws<'a, F: 'a, O>(inner: F) -> impl FnMut(&'a str) -> ParseResult<&str, O> +where + F: FnMut(&'a str) -> ParseResult<&str, O>, +{ + sequence::preceded(pchar::multispace0, inner) +} + +pub fn num_i64(input: &str) -> ParseResult<&str, i64> { + error::context( + "num_i32", + combinator::map_res( + complete::take_while1(|c: char| c.is_ascii_digit()), + |o: &str| { + o.parse::() + .map_err(|_| EuclidError::InvalidNumber(o.to_string())) + }, + ), + )(input) +} + +pub fn string_str(input: &str) -> ParseResult<&str, String> { + error::context( + "String", + combinator::map( + sequence::delimited( + complete::tag("\""), + complete::take_while1(|c: char| c != '"'), + complete::tag("\""), + ), + |val: &str| val.to_string(), + ), + )(input) +} + +pub fn identifier(input: &str) -> ParseResult<&str, String> { + error::context( + "identifier", + combinator::map( + sequence::pair( + complete::take_while1(|c: char| c.is_ascii_alphabetic() || c == '_'), + complete::take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'), + ), + |out: (&str, &str)| out.0.to_string() + out.1, + ), + )(input) +} +pub fn percentage(input: &str) -> ParseResult<&str, u8> { + error::context( + "volume_split_percentage", + combinator::map_res( + sequence::terminated( + complete::take_while_m_n(1, 2, |c: char| c.is_ascii_digit()), + complete::tag("%"), + ), + |o: &str| { + o.parse::() + .map_err(|_| EuclidError::InvalidPercentage(o.to_string())) + }, + ), + )(input) +} + +pub fn number_value(input: &str) -> ParseResult<&str, ast::ValueType> { + error::context( + "number_value", + combinator::map(num_i64, ast::ValueType::Number), + )(input) +} + +pub fn str_value(input: &str) -> ParseResult<&str, ast::ValueType> { + error::context( + "str_value", + combinator::map(string_str, ast::ValueType::StrValue), + )(input) +} +pub fn enum_value_string(input: &str) -> ParseResult<&str, String> { + combinator::map( + sequence::pair( + complete::take_while1(|c: char| c.is_ascii_alphabetic() || c == '_'), + complete::take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'), + ), + |out: (&str, &str)| out.0.to_string() + out.1, + )(input) +} + +pub fn enum_variant_value(input: &str) -> ParseResult<&str, ast::ValueType> { + error::context( + "enum_variant_value", + combinator::map(enum_value_string, ast::ValueType::EnumVariant), + )(input) +} + +pub fn number_array_value(input: &str) -> ParseResult<&str, ast::ValueType> { + let many_with_comma = multi::many0(sequence::preceded( + skip_ws(complete::tag(",")), + skip_ws(num_i64), + )); + + let full_sequence = sequence::pair(skip_ws(num_i64), many_with_comma); + + error::context( + "number_array_value", + combinator::map( + sequence::delimited( + skip_ws(complete::tag("(")), + full_sequence, + skip_ws(complete::tag(")")), + ), + |tup: (i64, Vec)| { + let mut rest = tup.1; + rest.insert(0, tup.0); + ast::ValueType::NumberArray(rest) + }, + ), + )(input) +} + +pub fn enum_variant_array_value(input: &str) -> ParseResult<&str, ast::ValueType> { + let many_with_comma = multi::many0(sequence::preceded( + skip_ws(complete::tag(",")), + skip_ws(enum_value_string), + )); + + let full_sequence = sequence::pair(skip_ws(enum_value_string), many_with_comma); + + error::context( + "enum_variant_array_value", + combinator::map( + sequence::delimited( + skip_ws(complete::tag("(")), + full_sequence, + skip_ws(complete::tag(")")), + ), + |tup: (String, Vec)| { + let mut rest = tup.1; + rest.insert(0, tup.0); + ast::ValueType::EnumVariantArray(rest) + }, + ), + )(input) +} + +pub fn number_comparison(input: &str) -> ParseResult<&str, ast::NumberComparison> { + let operator = combinator::map_res( + branch::alt(( + complete::tag(">="), + complete::tag("<="), + complete::tag(">"), + complete::tag("<"), + )), + |s: &str| match s { + ">=" => Ok(ast::ComparisonType::GreaterThanEqual), + "<=" => Ok(ast::ComparisonType::LessThanEqual), + ">" => Ok(ast::ComparisonType::GreaterThan), + "<" => Ok(ast::ComparisonType::LessThan), + _ => Err(EuclidError::InvalidOperator(s.to_string())), + }, + ); + + error::context( + "number_comparison", + combinator::map( + sequence::pair(operator, num_i64), + |tup: (ast::ComparisonType, i64)| ast::NumberComparison { + comparison_type: tup.0, + number: tup.1, + }, + ), + )(input) +} + +pub fn number_comparison_array_value(input: &str) -> ParseResult<&str, ast::ValueType> { + let many_with_comma = multi::many0(sequence::preceded( + skip_ws(complete::tag(",")), + skip_ws(number_comparison), + )); + + let full_sequence = sequence::pair(skip_ws(number_comparison), many_with_comma); + + error::context( + "number_comparison_array_value", + combinator::map( + sequence::delimited( + skip_ws(complete::tag("(")), + full_sequence, + skip_ws(complete::tag(")")), + ), + |tup: (ast::NumberComparison, Vec)| { + let mut rest = tup.1; + rest.insert(0, tup.0); + ast::ValueType::NumberComparisonArray(rest) + }, + ), + )(input) +} + +pub fn value_type(input: &str) -> ParseResult<&str, ast::ValueType> { + error::context( + "value_type", + branch::alt(( + number_value, + enum_variant_value, + enum_variant_array_value, + number_array_value, + number_comparison_array_value, + str_value, + )), + )(input) +} + +pub fn comparison_type(input: &str) -> ParseResult<&str, ast::ComparisonType> { + error::context( + "comparison_operator", + combinator::map_res( + branch::alt(( + complete::tag("/="), + complete::tag(">="), + complete::tag("<="), + complete::tag("="), + complete::tag(">"), + complete::tag("<"), + )), + |s: &str| match s { + "/=" => Ok(ast::ComparisonType::NotEqual), + ">=" => Ok(ast::ComparisonType::GreaterThanEqual), + "<=" => Ok(ast::ComparisonType::LessThanEqual), + "=" => Ok(ast::ComparisonType::Equal), + ">" => Ok(ast::ComparisonType::GreaterThan), + "<" => Ok(ast::ComparisonType::LessThan), + _ => Err(EuclidError::InvalidOperator(s.to_string())), + }, + ), + )(input) +} + +pub fn comparison(input: &str) -> ParseResult<&str, ast::Comparison> { + error::context( + "condition", + combinator::map( + sequence::tuple(( + skip_ws(complete::take_while1(|c: char| { + c.is_ascii_alphabetic() || c == '.' || c == '_' + })), + skip_ws(comparison_type), + skip_ws(value_type), + )), + |tup: (&str, ast::ComparisonType, ast::ValueType)| ast::Comparison { + lhs: tup.0.to_string(), + comparison: tup.1, + value: tup.2, + metadata: std::collections::HashMap::new(), + }, + ), + )(input) +} + +pub fn arbitrary_comparison(input: &str) -> ParseResult<&str, ast::Comparison> { + error::context( + "condition", + combinator::map( + sequence::tuple(( + skip_ws(string_str), + skip_ws(comparison_type), + skip_ws(string_str), + )), + |tup: (String, ast::ComparisonType, String)| ast::Comparison { + lhs: "metadata".to_string(), + comparison: tup.1, + value: ast::ValueType::MetadataVariant(ast::MetadataValue { + key: tup.0, + value: tup.2, + }), + metadata: std::collections::HashMap::new(), + }, + ), + )(input) +} + +pub fn comparison_array(input: &str) -> ParseResult<&str, Vec> { + let many_with_ampersand = error::context( + "many_with_amp", + multi::many0(sequence::preceded(skip_ws(complete::tag("&")), comparison)), + ); + + let full_sequence = sequence::pair( + skip_ws(branch::alt((comparison, arbitrary_comparison))), + many_with_ampersand, + ); + + error::context( + "comparison_array", + combinator::map( + full_sequence, + |tup: (ast::Comparison, Vec)| { + let mut rest = tup.1; + rest.insert(0, tup.0); + rest + }, + ), + )(input) +} + +pub fn if_statement(input: &str) -> ParseResult<&str, ast::IfStatement> { + let nested_block = sequence::delimited( + skip_ws(complete::tag("{")), + multi::many0(if_statement), + skip_ws(complete::tag("}")), + ); + + error::context( + "if_statement", + combinator::map( + sequence::pair(comparison_array, combinator::opt(nested_block)), + |tup: (ast::IfCondition, Option>)| ast::IfStatement { + condition: tup.0, + nested: tup.1, + }, + ), + )(input) +} + +pub fn rule_conditions_array(input: &str) -> ParseResult<&str, Vec> { + error::context( + "rules_array", + sequence::delimited( + skip_ws(complete::tag("{")), + multi::many1(if_statement), + skip_ws(complete::tag("}")), + ), + )(input) +} + +pub fn rule(input: &str) -> ParseResult<&str, ast::Rule> { + let rule_name = error::context( + "rule_name", + combinator::map( + skip_ws(sequence::pair( + complete::take_while1(|c: char| c.is_ascii_alphabetic() || c == '_'), + complete::take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'), + )), + |out: (&str, &str)| out.0.to_string() + out.1, + ), + ); + + let connector_selection = error::context( + "parse_output", + sequence::preceded(skip_ws(complete::tag(":")), output), + ); + + error::context( + "rule", + combinator::map( + sequence::tuple((rule_name, connector_selection, rule_conditions_array)), + |tup: (String, O, Vec)| ast::Rule { + name: tup.0, + connector_selection: tup.1, + statements: tup.2, + }, + ), + )(input) +} + +pub fn output(input: &str) -> ParseResult<&str, O> { + O::parse_output(input) +} + +pub fn default_output(input: &str) -> ParseResult<&str, O> { + error::context( + "default_output", + sequence::preceded( + sequence::pair(skip_ws(complete::tag("default")), skip_ws(pchar::char(':'))), + skip_ws(output), + ), + )(input) +} + +pub fn program(input: &str) -> ParseResult<&str, ast::Program> { + error::context( + "program", + combinator::map( + sequence::pair(default_output, multi::many1(skip_ws(rule::))), + |tup: (O, Vec>)| ast::Program { + default_selection: tup.0, + rules: tup.1, + metadata: std::collections::HashMap::new(), + }, + ), + )(input) +} diff --git a/crates/euclid/src/frontend/dir.rs b/crates/euclid/src/frontend/dir.rs new file mode 100644 index 000000000000..7f2fc252d232 --- /dev/null +++ b/crates/euclid/src/frontend/dir.rs @@ -0,0 +1,803 @@ +//! Domain Intermediate Representation +pub mod enums; +pub mod lowering; +pub mod transformers; + +use strum::IntoEnumIterator; + +use crate::{enums as euclid_enums, frontend::ast, types}; + +#[macro_export] +#[cfg(feature = "connector_choice_mca_id")] +macro_rules! dirval { + (Connector = $name:ident) => { + $crate::frontend::dir::DirValue::Connector(Box::new( + $crate::frontend::ast::ConnectorChoice { + connector: $crate::frontend::dir::enums::Connector::$name, + }, + )) + }; + + ($key:ident = $val:ident) => {{ + pub use $crate::frontend::dir::enums::*; + + $crate::frontend::dir::DirValue::$key($key::$val) + }}; + + ($key:ident = $num:literal) => {{ + $crate::frontend::dir::DirValue::$key($crate::types::NumValue { + number: $num, + refinement: None, + }) + }}; + + ($key:ident s= $str:literal) => {{ + $crate::frontend::dir::DirValue::$key($crate::types::StrValue { + value: $str.to_string(), + }) + }}; + + ($key:literal = $str:literal) => {{ + $crate::frontend::dir::DirValue::MetaData($crate::types::MetadataValue { + key: $key.to_string(), + value: $str.to_string(), + }) + }}; +} + +#[macro_export] +#[cfg(not(feature = "connector_choice_mca_id"))] +macro_rules! dirval { + (Connector = $name:ident) => { + $crate::frontend::dir::DirValue::Connector(Box::new( + $crate::frontend::ast::ConnectorChoice { + connector: $crate::frontend::dir::enums::Connector::$name, + sub_label: None, + }, + )) + }; + + (Connector = ($name:ident, $sub_label:literal)) => { + $crate::frontend::dir::DirValue::Connector(Box::new( + $crate::frontend::ast::ConnectorChoice { + connector: $crate::frontend::dir::enums::Connector::$name, + sub_label: Some($sub_label.to_string()), + }, + )) + }; + + ($key:ident = $val:ident) => {{ + pub use $crate::frontend::dir::enums::*; + + $crate::frontend::dir::DirValue::$key($key::$val) + }}; + + ($key:ident = $num:literal) => {{ + $crate::frontend::dir::DirValue::$key($crate::types::NumValue { + number: $num, + refinement: None, + }) + }}; + + ($key:ident s= $str:literal) => {{ + $crate::frontend::dir::DirValue::$key($crate::types::StrValue { + value: $str.to_string(), + }) + }}; + ($key:literal = $str:literal) => {{ + $crate::frontend::dir::DirValue::MetaData($crate::types::MetadataValue { + key: $key.to_string(), + value: $str.to_string(), + }) + }}; +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Serialize)] +pub struct DirKey { + pub kind: DirKeyKind, + pub value: Option, +} + +impl DirKey { + pub fn new(kind: DirKeyKind, value: Option) -> Self { + Self { kind, value } + } +} + +#[derive( + Debug, + Clone, + Hash, + PartialEq, + Eq, + serde::Serialize, + strum::Display, + strum::EnumIter, + strum::EnumVariantNames, + strum::EnumString, + strum::EnumMessage, + strum::EnumProperty, +)] +pub enum DirKeyKind { + #[strum( + serialize = "payment_method", + detailed_message = "Different modes of payment - eg. cards, wallets, banks", + props(Category = "Payment Methods") + )] + #[serde(rename = "payment_method")] + PaymentMethod, + #[strum( + serialize = "card_bin", + detailed_message = "First 4 to 6 digits of a payment card number", + props(Category = "Payment Methods") + )] + #[serde(rename = "card_bin")] + CardBin, + #[strum( + serialize = "card_type", + detailed_message = "Type of the payment card - eg. credit, debit", + props(Category = "Payment Methods") + )] + #[serde(rename = "card_type")] + CardType, + #[strum( + serialize = "card_network", + detailed_message = "Network that facilitates payment card transactions", + props(Category = "Payment Methods") + )] + #[serde(rename = "card_network")] + CardNetwork, + #[strum( + serialize = "pay_later", + detailed_message = "Supported types of Pay Later payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "pay_later")] + PayLaterType, + #[strum( + serialize = "gift_card", + detailed_message = "Supported types of Gift Card payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "gift_card")] + GiftCardType, + #[strum( + serialize = "mandate_acceptance_type", + detailed_message = "Mode of customer acceptance for mandates - online and offline", + props(Category = "Payments") + )] + #[serde(rename = "mandate_acceptance_type")] + MandateAcceptanceType, + #[strum( + serialize = "mandate_type", + detailed_message = "Type of mandate acceptance - single use and multi use", + props(Category = "Payments") + )] + #[serde(rename = "mandate_type")] + MandateType, + #[strum( + serialize = "payment_type", + detailed_message = "Indicates if a payment is mandate or non-mandate", + props(Category = "Payments") + )] + #[serde(rename = "payment_type")] + PaymentType, + #[strum( + serialize = "wallet", + detailed_message = "Supported types of Wallet payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "wallet")] + WalletType, + #[strum( + serialize = "upi", + detailed_message = "Supported types of UPI payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "upi")] + UpiType, + #[strum( + serialize = "voucher", + detailed_message = "Supported types of Voucher payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "voucher")] + VoucherType, + #[strum( + serialize = "bank_transfer", + detailed_message = "Supported types of Bank Transfer payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "bank_transfer")] + BankTransferType, + #[strum( + serialize = "bank_redirect", + detailed_message = "Supported types of Bank Redirect payment methods", + props(Category = "Payment Method Types") + )] + #[serde(rename = "bank_redirect")] + BankRedirectType, + #[strum( + serialize = "bank_debit", + detailed_message = "Supported types of Bank Debit payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "bank_debit")] + BankDebitType, + #[strum( + serialize = "crypto", + detailed_message = "Supported types of Crypto payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "crypto")] + CryptoType, + #[strum( + serialize = "metadata", + detailed_message = "Aribitrary Key and value pair", + props(Category = "Metadata") + )] + #[serde(rename = "metadata")] + MetaData, + #[strum( + serialize = "reward", + detailed_message = "Supported types of Reward payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "reward")] + RewardType, + #[strum( + serialize = "amount", + detailed_message = "Value of the transaction", + props(Category = "Payments") + )] + #[serde(rename = "amount")] + PaymentAmount, + #[strum( + serialize = "currency", + detailed_message = "Currency used for the payment", + props(Category = "Payments") + )] + #[serde(rename = "currency")] + PaymentCurrency, + #[strum( + serialize = "authentication_type", + detailed_message = "Type of authentication for the payment", + props(Category = "Payments") + )] + #[serde(rename = "authentication_type")] + AuthenticationType, + #[strum( + serialize = "capture_method", + detailed_message = "Modes of capturing a payment", + props(Category = "Payments") + )] + #[serde(rename = "capture_method")] + CaptureMethod, + #[strum( + serialize = "country", + serialize = "business_country", + detailed_message = "Country of the business unit", + props(Category = "Merchant") + )] + #[serde(rename = "business_country", alias = "country")] + BusinessCountry, + #[strum( + serialize = "billing_country", + detailed_message = "Country of the billing address of the customer", + props(Category = "Customer") + )] + #[serde(rename = "billing_country")] + BillingCountry, + #[serde(skip_deserializing, rename = "connector")] + #[strum(disabled)] + Connector, + #[strum( + serialize = "business_label", + detailed_message = "Identifier for business unit", + props(Category = "Merchant") + )] + #[serde(rename = "business_label")] + BusinessLabel, + #[strum( + serialize = "setup_future_usage", + detailed_message = "Identifier for recurring payments", + props(Category = "Payments") + )] + #[serde(rename = "setup_future_usage")] + SetupFutureUsage, + #[strum( + serialize = "card_redirect_type", + detailed_message = "Supported types of Card Redirect payment method", + props(Category = "Payment Method Types") + )] + #[serde(rename = "card_redirect")] + CardRedirectType, +} + +pub trait EuclidDirFilter: Sized +where + Self: 'static, +{ + const ALLOWED: &'static [DirKeyKind]; + fn get_allowed_keys() -> &'static [DirKeyKind] { + Self::ALLOWED + } + + fn is_key_allowed(key: &DirKeyKind) -> bool { + Self::ALLOWED.contains(key) + } +} + +impl DirKeyKind { + pub fn get_type(&self) -> types::DataType { + match self { + Self::PaymentMethod => types::DataType::EnumVariant, + Self::CardBin => types::DataType::StrValue, + Self::CardType => types::DataType::EnumVariant, + Self::CardNetwork => types::DataType::EnumVariant, + Self::MetaData => types::DataType::MetadataValue, + Self::MandateType => types::DataType::EnumVariant, + Self::PaymentType => types::DataType::EnumVariant, + Self::MandateAcceptanceType => types::DataType::EnumVariant, + Self::PayLaterType => types::DataType::EnumVariant, + Self::WalletType => types::DataType::EnumVariant, + Self::UpiType => types::DataType::EnumVariant, + Self::VoucherType => types::DataType::EnumVariant, + Self::BankTransferType => types::DataType::EnumVariant, + Self::GiftCardType => types::DataType::EnumVariant, + Self::BankRedirectType => types::DataType::EnumVariant, + Self::CryptoType => types::DataType::EnumVariant, + Self::RewardType => types::DataType::EnumVariant, + Self::PaymentAmount => types::DataType::Number, + Self::PaymentCurrency => types::DataType::EnumVariant, + Self::AuthenticationType => types::DataType::EnumVariant, + Self::CaptureMethod => types::DataType::EnumVariant, + Self::BusinessCountry => types::DataType::EnumVariant, + Self::BillingCountry => types::DataType::EnumVariant, + Self::Connector => types::DataType::EnumVariant, + Self::BankDebitType => types::DataType::EnumVariant, + Self::BusinessLabel => types::DataType::StrValue, + Self::SetupFutureUsage => types::DataType::EnumVariant, + Self::CardRedirectType => types::DataType::EnumVariant, + } + } + pub fn get_value_set(&self) -> Option> { + match self { + Self::PaymentMethod => Some( + enums::PaymentMethod::iter() + .map(DirValue::PaymentMethod) + .collect(), + ), + Self::CardBin => None, + Self::CardType => Some(enums::CardType::iter().map(DirValue::CardType).collect()), + Self::MandateAcceptanceType => Some( + euclid_enums::MandateAcceptanceType::iter() + .map(DirValue::MandateAcceptanceType) + .collect(), + ), + Self::PaymentType => Some( + euclid_enums::PaymentType::iter() + .map(DirValue::PaymentType) + .collect(), + ), + Self::MandateType => Some( + euclid_enums::MandateType::iter() + .map(DirValue::MandateType) + .collect(), + ), + Self::CardNetwork => Some( + enums::CardNetwork::iter() + .map(DirValue::CardNetwork) + .collect(), + ), + Self::PayLaterType => Some( + enums::PayLaterType::iter() + .map(DirValue::PayLaterType) + .collect(), + ), + Self::MetaData => None, + Self::WalletType => Some( + enums::WalletType::iter() + .map(DirValue::WalletType) + .collect(), + ), + Self::UpiType => Some(enums::UpiType::iter().map(DirValue::UpiType).collect()), + Self::VoucherType => Some( + enums::VoucherType::iter() + .map(DirValue::VoucherType) + .collect(), + ), + Self::BankTransferType => Some( + enums::BankTransferType::iter() + .map(DirValue::BankTransferType) + .collect(), + ), + Self::GiftCardType => Some( + enums::GiftCardType::iter() + .map(DirValue::GiftCardType) + .collect(), + ), + Self::BankRedirectType => Some( + enums::BankRedirectType::iter() + .map(DirValue::BankRedirectType) + .collect(), + ), + Self::CryptoType => Some( + enums::CryptoType::iter() + .map(DirValue::CryptoType) + .collect(), + ), + Self::RewardType => Some( + enums::RewardType::iter() + .map(DirValue::RewardType) + .collect(), + ), + Self::PaymentAmount => None, + Self::PaymentCurrency => Some( + enums::PaymentCurrency::iter() + .map(DirValue::PaymentCurrency) + .collect(), + ), + Self::AuthenticationType => Some( + enums::AuthenticationType::iter() + .map(DirValue::AuthenticationType) + .collect(), + ), + Self::CaptureMethod => Some( + enums::CaptureMethod::iter() + .map(DirValue::CaptureMethod) + .collect(), + ), + Self::BankDebitType => Some( + enums::BankDebitType::iter() + .map(DirValue::BankDebitType) + .collect(), + ), + Self::BusinessCountry => Some( + enums::Country::iter() + .map(DirValue::BusinessCountry) + .collect(), + ), + Self::BillingCountry => Some( + enums::Country::iter() + .map(DirValue::BillingCountry) + .collect(), + ), + Self::Connector => Some( + enums::Connector::iter() + .map(|connector| { + DirValue::Connector(Box::new(ast::ConnectorChoice { + connector, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: None, + })) + }) + .collect(), + ), + Self::BusinessLabel => None, + Self::SetupFutureUsage => Some( + enums::SetupFutureUsage::iter() + .map(DirValue::SetupFutureUsage) + .collect(), + ), + Self::CardRedirectType => Some( + enums::CardRedirectType::iter() + .map(DirValue::CardRedirectType) + .collect(), + ), + } + } +} + +#[derive( + Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, strum::Display, strum::EnumVariantNames, +)] +#[serde(tag = "key", content = "value")] +pub enum DirValue { + #[serde(rename = "payment_method")] + PaymentMethod(enums::PaymentMethod), + #[serde(rename = "card_bin")] + CardBin(types::StrValue), + #[serde(rename = "card_type")] + CardType(enums::CardType), + #[serde(rename = "card_network")] + CardNetwork(enums::CardNetwork), + #[serde(rename = "metadata")] + MetaData(types::MetadataValue), + #[serde(rename = "pay_later")] + PayLaterType(enums::PayLaterType), + #[serde(rename = "wallet")] + WalletType(enums::WalletType), + #[serde(rename = "acceptance_type")] + MandateAcceptanceType(euclid_enums::MandateAcceptanceType), + #[serde(rename = "mandate_type")] + MandateType(euclid_enums::MandateType), + #[serde(rename = "payment_type")] + PaymentType(euclid_enums::PaymentType), + #[serde(rename = "upi")] + UpiType(enums::UpiType), + #[serde(rename = "voucher")] + VoucherType(enums::VoucherType), + #[serde(rename = "bank_transfer")] + BankTransferType(enums::BankTransferType), + #[serde(rename = "bank_redirect")] + BankRedirectType(enums::BankRedirectType), + #[serde(rename = "bank_debit")] + BankDebitType(enums::BankDebitType), + #[serde(rename = "crypto")] + CryptoType(enums::CryptoType), + #[serde(rename = "reward")] + RewardType(enums::RewardType), + #[serde(rename = "gift_card")] + GiftCardType(enums::GiftCardType), + #[serde(rename = "amount")] + PaymentAmount(types::NumValue), + #[serde(rename = "currency")] + PaymentCurrency(enums::PaymentCurrency), + #[serde(rename = "authentication_type")] + AuthenticationType(enums::AuthenticationType), + #[serde(rename = "capture_method")] + CaptureMethod(enums::CaptureMethod), + #[serde(rename = "business_country", alias = "country")] + BusinessCountry(enums::Country), + #[serde(rename = "billing_country")] + BillingCountry(enums::Country), + #[serde(skip_deserializing, rename = "connector")] + Connector(Box), + #[serde(rename = "business_label")] + BusinessLabel(types::StrValue), + #[serde(rename = "setup_future_usage")] + SetupFutureUsage(enums::SetupFutureUsage), + #[serde(rename = "card_redirect")] + CardRedirectType(enums::CardRedirectType), +} + +impl DirValue { + pub fn get_key(&self) -> DirKey { + let (kind, data) = match self { + Self::PaymentMethod(_) => (DirKeyKind::PaymentMethod, None), + Self::CardBin(_) => (DirKeyKind::CardBin, None), + Self::RewardType(_) => (DirKeyKind::RewardType, None), + Self::BusinessCountry(_) => (DirKeyKind::BusinessCountry, None), + Self::BillingCountry(_) => (DirKeyKind::CardBin, None), + Self::BankTransferType(_) => (DirKeyKind::BankTransferType, None), + Self::UpiType(_) => (DirKeyKind::UpiType, None), + Self::CardType(_) => (DirKeyKind::CardType, None), + Self::CardNetwork(_) => (DirKeyKind::CardNetwork, None), + Self::MetaData(met) => (DirKeyKind::MetaData, Some(met.key.clone())), + Self::PayLaterType(_) => (DirKeyKind::PayLaterType, None), + Self::WalletType(_) => (DirKeyKind::WalletType, None), + Self::BankRedirectType(_) => (DirKeyKind::BankRedirectType, None), + Self::CryptoType(_) => (DirKeyKind::CryptoType, None), + Self::AuthenticationType(_) => (DirKeyKind::AuthenticationType, None), + Self::CaptureMethod(_) => (DirKeyKind::CaptureMethod, None), + Self::PaymentAmount(_) => (DirKeyKind::PaymentAmount, None), + Self::PaymentCurrency(_) => (DirKeyKind::PaymentCurrency, None), + Self::Connector(_) => (DirKeyKind::Connector, None), + Self::BankDebitType(_) => (DirKeyKind::BankDebitType, None), + Self::MandateAcceptanceType(_) => (DirKeyKind::MandateAcceptanceType, None), + Self::MandateType(_) => (DirKeyKind::MandateType, None), + Self::PaymentType(_) => (DirKeyKind::PaymentType, None), + Self::BusinessLabel(_) => (DirKeyKind::BusinessLabel, None), + Self::SetupFutureUsage(_) => (DirKeyKind::SetupFutureUsage, None), + Self::CardRedirectType(_) => (DirKeyKind::CardRedirectType, None), + Self::VoucherType(_) => (DirKeyKind::VoucherType, None), + Self::GiftCardType(_) => (DirKeyKind::GiftCardType, None), + }; + + DirKey::new(kind, data) + } + pub fn get_metadata_val(&self) -> Option { + match self { + Self::MetaData(val) => Some(val.clone()), + Self::PaymentMethod(_) => None, + Self::CardBin(_) => None, + Self::CardType(_) => None, + Self::CardNetwork(_) => None, + Self::PayLaterType(_) => None, + Self::WalletType(_) => None, + Self::BankRedirectType(_) => None, + Self::CryptoType(_) => None, + Self::AuthenticationType(_) => None, + Self::CaptureMethod(_) => None, + Self::GiftCardType(_) => None, + Self::PaymentAmount(_) => None, + Self::PaymentCurrency(_) => None, + Self::BusinessCountry(_) => None, + Self::BillingCountry(_) => None, + Self::Connector(_) => None, + Self::BankTransferType(_) => None, + Self::UpiType(_) => None, + Self::BankDebitType(_) => None, + Self::RewardType(_) => None, + Self::VoucherType(_) => None, + Self::MandateAcceptanceType(_) => None, + Self::MandateType(_) => None, + Self::PaymentType(_) => None, + Self::BusinessLabel(_) => None, + Self::SetupFutureUsage(_) => None, + Self::CardRedirectType(_) => None, + } + } + + pub fn get_str_val(&self) -> Option { + match self { + Self::CardBin(val) => Some(val.clone()), + _ => None, + } + } + + pub fn get_num_value(&self) -> Option { + match self { + Self::PaymentAmount(val) => Some(val.clone()), + _ => None, + } + } + + pub fn check_equality(v1: &Self, v2: &Self) -> bool { + match (v1, v2) { + (Self::PaymentMethod(pm1), Self::PaymentMethod(pm2)) => pm1 == pm2, + (Self::CardType(ct1), Self::CardType(ct2)) => ct1 == ct2, + (Self::CardNetwork(cn1), Self::CardNetwork(cn2)) => cn1 == cn2, + (Self::MetaData(md1), Self::MetaData(md2)) => md1 == md2, + (Self::PayLaterType(plt1), Self::PayLaterType(plt2)) => plt1 == plt2, + (Self::WalletType(wt1), Self::WalletType(wt2)) => wt1 == wt2, + (Self::BankDebitType(bdt1), Self::BankDebitType(bdt2)) => bdt1 == bdt2, + (Self::BankRedirectType(brt1), Self::BankRedirectType(brt2)) => brt1 == brt2, + (Self::BankTransferType(btt1), Self::BankTransferType(btt2)) => btt1 == btt2, + (Self::GiftCardType(gct1), Self::GiftCardType(gct2)) => gct1 == gct2, + (Self::CryptoType(ct1), Self::CryptoType(ct2)) => ct1 == ct2, + (Self::AuthenticationType(at1), Self::AuthenticationType(at2)) => at1 == at2, + (Self::CaptureMethod(cm1), Self::CaptureMethod(cm2)) => cm1 == cm2, + (Self::PaymentCurrency(pc1), Self::PaymentCurrency(pc2)) => pc1 == pc2, + (Self::BusinessCountry(c1), Self::BusinessCountry(c2)) => c1 == c2, + (Self::BillingCountry(c1), Self::BillingCountry(c2)) => c1 == c2, + (Self::PaymentType(pt1), Self::PaymentType(pt2)) => pt1 == pt2, + (Self::MandateType(mt1), Self::MandateType(mt2)) => mt1 == mt2, + (Self::MandateAcceptanceType(mat1), Self::MandateAcceptanceType(mat2)) => mat1 == mat2, + (Self::RewardType(rt1), Self::RewardType(rt2)) => rt1 == rt2, + (Self::Connector(c1), Self::Connector(c2)) => c1 == c2, + (Self::BusinessLabel(bl1), Self::BusinessLabel(bl2)) => bl1 == bl2, + (Self::SetupFutureUsage(sfu1), Self::SetupFutureUsage(sfu2)) => sfu1 == sfu2, + (Self::UpiType(ut1), Self::UpiType(ut2)) => ut1 == ut2, + (Self::VoucherType(vt1), Self::VoucherType(vt2)) => vt1 == vt2, + (Self::CardRedirectType(crt1), Self::CardRedirectType(crt2)) => crt1 == crt2, + _ => false, + } + } +} + +#[derive(Debug, Clone)] +pub enum DirComparisonLogic { + NegativeConjunction, + PositiveDisjunction, +} + +#[derive(Debug, Clone)] +pub struct DirComparison { + pub values: Vec, + pub logic: DirComparisonLogic, + pub metadata: types::Metadata, +} + +pub type DirIfCondition = Vec; + +#[derive(Debug, Clone)] +pub struct DirIfStatement { + pub condition: DirIfCondition, + pub nested: Option>, +} + +#[derive(Debug, Clone)] +pub struct DirRule { + pub name: String, + pub connector_selection: O, + pub statements: Vec, +} + +#[derive(Debug, Clone)] +pub struct DirProgram { + pub default_selection: O, + pub rules: Vec>, + pub metadata: types::Metadata, +} + +#[cfg(test)] +mod test { + #![allow(clippy::expect_used)] + use rustc_hash::FxHashMap; + use strum::IntoEnumIterator; + + use super::*; + + #[test] + fn test_consistent_dir_key_naming() { + let mut key_names: FxHashMap = FxHashMap::default(); + + for key in DirKeyKind::iter() { + let json_str = if let DirKeyKind::MetaData = key { + r#""metadata""#.to_string() + } else { + serde_json::to_string(&key).expect("JSON Serialization") + }; + let display_str = key.to_string(); + + assert_eq!(&json_str[1..json_str.len() - 1], display_str); + key_names.insert(key, display_str); + } + + let values = vec![ + dirval!(PaymentMethod = Card), + dirval!(CardBin s= "123456"), + dirval!(CardType = Credit), + dirval!(CardNetwork = Visa), + dirval!(PayLaterType = Klarna), + dirval!(WalletType = Paypal), + dirval!(BankRedirectType = Sofort), + dirval!(BankDebitType = Bacs), + dirval!(CryptoType = CryptoCurrency), + dirval!("" = "metadata"), + dirval!(PaymentAmount = 100), + dirval!(PaymentCurrency = USD), + dirval!(CardRedirectType = Benefit), + dirval!(AuthenticationType = ThreeDs), + dirval!(CaptureMethod = Manual), + dirval!(BillingCountry = UnitedStatesOfAmerica), + dirval!(BusinessCountry = France), + ]; + + for val in values { + let json_val = serde_json::to_value(&val).expect("JSON Value Serialization"); + + let json_key = json_val + .as_object() + .expect("Serialized Object") + .get("key") + .expect("Object Key"); + + let value_str = json_key.as_str().expect("Value string"); + let dir_key = val.get_key(); + + let key_name = key_names.get(&dir_key.kind).expect("Key name"); + + assert_eq!(key_name, value_str); + } + } + + #[cfg(feature = "ast_parser")] + #[test] + fn test_allowed_dir_keys() { + use crate::types::DummyOutput; + + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + payment_method = card + } + "#; + let (_, program) = ast::parser::program::(program_str).expect("Program"); + + let out = ast::lowering::lower_program::(program); + assert!(out.is_ok()) + } + #[cfg(feature = "ast_parser")] + #[test] + fn test_not_allowed_dir_keys() { + use crate::types::DummyOutput; + + let program_str = r#" + default: ["stripe", "adyen"] + + rule_1: ["stripe"] + { + bank_debit = ach + } + "#; + let (_, program) = ast::parser::program::(program_str).expect("Program"); + + let out = ast::lowering::lower_program::(program); + assert!(out.is_err()) + } +} diff --git a/crates/euclid/src/frontend/dir/enums.rs b/crates/euclid/src/frontend/dir/enums.rs new file mode 100644 index 000000000000..17699940363f --- /dev/null +++ b/crates/euclid/src/frontend/dir/enums.rs @@ -0,0 +1,321 @@ +use strum::VariantNames; + +use crate::enums::collect_variants; +pub use crate::enums::{ + AuthenticationType, CaptureMethod, CardNetwork, Connector, Country, Country as BusinessCountry, + Country as BillingCountry, Currency as PaymentCurrency, MandateAcceptanceType, MandateType, + PaymentMethod, PaymentType, SetupFutureUsage, +}; + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum CardType { + Credit, + Debit, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum PayLaterType { + Affirm, + AfterpayClearpay, + Alma, + Klarna, + PayBright, + Walley, + Atome, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum WalletType { + GooglePay, + ApplePay, + Paypal, + AliPay, + AliPayHk, + MbWay, + MobilePay, + WeChatPay, + SamsungPay, + GoPay, + KakaoPay, + Twint, + Gcash, + Vipps, + Momo, + Dana, + TouchNGo, + Swish, + Cashapp, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum VoucherType { + Boleto, + Efecty, + PagoEfectivo, + RedCompra, + RedPagos, + Alfamart, + Indomaret, + SevenEleven, + Lawson, + MiniStop, + FamilyMart, + Seicomart, + PayEasy, + Oxxo, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum BankRedirectType { + Bizum, + Giropay, + Ideal, + Sofort, + Eps, + BancontactCard, + Blik, + Interac, + OnlineBankingCzechRepublic, + OnlineBankingFinland, + OnlineBankingPoland, + OnlineBankingSlovakia, + OnlineBankingFpx, + OnlineBankingThailand, + OpenBankingUk, + Przelewy24, + Trustly, +} +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum BankTransferType { + Multibanco, + Ach, + Sepa, + Bacs, + BcaBankTransfer, + BniVa, + BriVa, + CimbVa, + DanamonVa, + MandiriVa, + PermataBankTransfer, + Pix, + Pse, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum GiftCardType { + PaySafeCard, + Givex, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum CardRedirectType { + Benefit, + Knet, + MomoAtm, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum CryptoType { + CryptoCurrency, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum UpiType { + UpiCollect, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum BankDebitType { + Ach, + Sepa, + Bacs, + Becs, +} + +#[derive( + Clone, + Debug, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumIter, + strum::EnumString, + serde::Serialize, + serde::Deserialize, +)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum RewardType { + ClassicReward, + Evoucher, +} + +collect_variants!(CardType); +collect_variants!(PayLaterType); +collect_variants!(WalletType); +collect_variants!(BankRedirectType); +collect_variants!(BankDebitType); +collect_variants!(CryptoType); +collect_variants!(RewardType); +collect_variants!(UpiType); +collect_variants!(VoucherType); +collect_variants!(GiftCardType); +collect_variants!(BankTransferType); +collect_variants!(CardRedirectType); diff --git a/crates/euclid/src/frontend/dir/lowering.rs b/crates/euclid/src/frontend/dir/lowering.rs new file mode 100644 index 000000000000..516e10e0389e --- /dev/null +++ b/crates/euclid/src/frontend/dir/lowering.rs @@ -0,0 +1,295 @@ +//! Analysis of the lowering logic for the DIR +//! +//! Consists of certain functions that supports the lowering logic from DIR to VIR. +//! These includes the lowering of the DIR program and vector of rules , and the lowering of ifstatements +//! ,and comparisonsLogic and also the lowering of the enums of value variants from DIR to VIR. +use super::enums; +use crate::{ + dssa::types::{AnalysisError, AnalysisErrorType}, + enums as global_enums, + frontend::{dir, vir}, + types::EuclidValue, +}; + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::CardType) -> Self { + match value { + enums::CardType::Credit => Self::Credit, + enums::CardType::Debit => Self::Debit, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::PayLaterType) -> Self { + match value { + enums::PayLaterType::Affirm => Self::Affirm, + enums::PayLaterType::AfterpayClearpay => Self::AfterpayClearpay, + enums::PayLaterType::Alma => Self::Alma, + enums::PayLaterType::Klarna => Self::Klarna, + enums::PayLaterType::PayBright => Self::PayBright, + enums::PayLaterType::Walley => Self::Walley, + enums::PayLaterType::Atome => Self::Atome, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::WalletType) -> Self { + match value { + enums::WalletType::GooglePay => Self::GooglePay, + enums::WalletType::ApplePay => Self::ApplePay, + enums::WalletType::Paypal => Self::Paypal, + enums::WalletType::AliPay => Self::AliPay, + enums::WalletType::AliPayHk => Self::AliPayHk, + enums::WalletType::MbWay => Self::MbWay, + enums::WalletType::MobilePay => Self::MobilePay, + enums::WalletType::WeChatPay => Self::WeChatPay, + enums::WalletType::SamsungPay => Self::SamsungPay, + enums::WalletType::GoPay => Self::GoPay, + enums::WalletType::KakaoPay => Self::KakaoPay, + enums::WalletType::Twint => Self::Twint, + enums::WalletType::Gcash => Self::Gcash, + enums::WalletType::Vipps => Self::Vipps, + enums::WalletType::Momo => Self::Momo, + enums::WalletType::Dana => Self::Dana, + enums::WalletType::TouchNGo => Self::TouchNGo, + enums::WalletType::Swish => Self::Swish, + enums::WalletType::Cashapp => Self::Cashapp, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::BankDebitType) -> Self { + match value { + enums::BankDebitType::Ach => Self::Ach, + enums::BankDebitType::Sepa => Self::Sepa, + enums::BankDebitType::Bacs => Self::Bacs, + enums::BankDebitType::Becs => Self::Becs, + } + } +} +impl From for global_enums::PaymentMethodType { + fn from(value: enums::UpiType) -> Self { + match value { + enums::UpiType::UpiCollect => Self::UpiCollect, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::VoucherType) -> Self { + match value { + enums::VoucherType::Boleto => Self::Boleto, + enums::VoucherType::Efecty => Self::Efecty, + enums::VoucherType::PagoEfectivo => Self::PagoEfectivo, + enums::VoucherType::RedCompra => Self::RedCompra, + enums::VoucherType::RedPagos => Self::RedPagos, + enums::VoucherType::Alfamart => Self::Alfamart, + enums::VoucherType::Indomaret => Self::Indomaret, + enums::VoucherType::SevenEleven => Self::SevenEleven, + enums::VoucherType::Lawson => Self::Lawson, + enums::VoucherType::MiniStop => Self::MiniStop, + enums::VoucherType::FamilyMart => Self::FamilyMart, + enums::VoucherType::Seicomart => Self::Seicomart, + enums::VoucherType::PayEasy => Self::PayEasy, + enums::VoucherType::Oxxo => Self::Oxxo, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::BankTransferType) -> Self { + match value { + enums::BankTransferType::Multibanco => Self::Multibanco, + enums::BankTransferType::Pix => Self::Pix, + enums::BankTransferType::Pse => Self::Pse, + enums::BankTransferType::Ach => Self::Ach, + enums::BankTransferType::Sepa => Self::Sepa, + enums::BankTransferType::Bacs => Self::Bacs, + enums::BankTransferType::BcaBankTransfer => Self::BcaBankTransfer, + enums::BankTransferType::BniVa => Self::BniVa, + enums::BankTransferType::BriVa => Self::BriVa, + enums::BankTransferType::CimbVa => Self::CimbVa, + enums::BankTransferType::DanamonVa => Self::DanamonVa, + enums::BankTransferType::MandiriVa => Self::MandiriVa, + enums::BankTransferType::PermataBankTransfer => Self::PermataBankTransfer, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::GiftCardType) -> Self { + match value { + enums::GiftCardType::PaySafeCard => Self::PaySafeCard, + enums::GiftCardType::Givex => Self::Givex, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::CardRedirectType) -> Self { + match value { + enums::CardRedirectType::Benefit => Self::Benefit, + enums::CardRedirectType::Knet => Self::Knet, + enums::CardRedirectType::MomoAtm => Self::MomoAtm, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::BankRedirectType) -> Self { + match value { + enums::BankRedirectType::Bizum => Self::Bizum, + enums::BankRedirectType::Giropay => Self::Giropay, + enums::BankRedirectType::Ideal => Self::Ideal, + enums::BankRedirectType::Sofort => Self::Sofort, + enums::BankRedirectType::Eps => Self::Eps, + enums::BankRedirectType::BancontactCard => Self::BancontactCard, + enums::BankRedirectType::Blik => Self::Blik, + enums::BankRedirectType::Interac => Self::Interac, + enums::BankRedirectType::OnlineBankingCzechRepublic => Self::OnlineBankingCzechRepublic, + enums::BankRedirectType::OnlineBankingFinland => Self::OnlineBankingFinland, + enums::BankRedirectType::OnlineBankingPoland => Self::OnlineBankingPoland, + enums::BankRedirectType::OnlineBankingSlovakia => Self::OnlineBankingSlovakia, + enums::BankRedirectType::OnlineBankingFpx => Self::OnlineBankingFpx, + enums::BankRedirectType::OnlineBankingThailand => Self::OnlineBankingThailand, + enums::BankRedirectType::OpenBankingUk => Self::OpenBankingUk, + enums::BankRedirectType::Przelewy24 => Self::Przelewy24, + enums::BankRedirectType::Trustly => Self::Trustly, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::CryptoType) -> Self { + match value { + enums::CryptoType::CryptoCurrency => Self::CryptoCurrency, + } + } +} + +impl From for global_enums::PaymentMethodType { + fn from(value: enums::RewardType) -> Self { + match value { + enums::RewardType::ClassicReward => Self::ClassicReward, + enums::RewardType::Evoucher => Self::Evoucher, + } + } +} + +/// Analyses of the lowering of the DirValues to EuclidValues . +/// +/// For example, +/// ```notrust +/// DirValue::PaymentMethod::Cards -> EuclidValue::PaymentMethod::Cards +/// ```notrust +/// This is a function that lowers the Values of the DIR variants into the Value of the VIR variants. +/// The function for each DirValue variant creates a corresponding EuclidValue variants and if there +/// lacks any direct mapping, it return an Error. +fn lower_value(dir_value: dir::DirValue) -> Result { + Ok(match dir_value { + dir::DirValue::PaymentMethod(pm) => EuclidValue::PaymentMethod(pm), + dir::DirValue::CardBin(ci) => EuclidValue::CardBin(ci), + dir::DirValue::CardType(ct) => EuclidValue::PaymentMethodType(ct.into()), + dir::DirValue::CardNetwork(cn) => EuclidValue::CardNetwork(cn), + dir::DirValue::MetaData(md) => EuclidValue::Metadata(md), + dir::DirValue::PayLaterType(plt) => EuclidValue::PaymentMethodType(plt.into()), + dir::DirValue::WalletType(wt) => EuclidValue::PaymentMethodType(wt.into()), + dir::DirValue::UpiType(ut) => EuclidValue::PaymentMethodType(ut.into()), + dir::DirValue::VoucherType(vt) => EuclidValue::PaymentMethodType(vt.into()), + dir::DirValue::BankTransferType(btt) => EuclidValue::PaymentMethodType(btt.into()), + dir::DirValue::GiftCardType(gct) => EuclidValue::PaymentMethodType(gct.into()), + dir::DirValue::CardRedirectType(crt) => EuclidValue::PaymentMethodType(crt.into()), + dir::DirValue::BankRedirectType(brt) => EuclidValue::PaymentMethodType(brt.into()), + dir::DirValue::CryptoType(ct) => EuclidValue::PaymentMethodType(ct.into()), + dir::DirValue::AuthenticationType(at) => EuclidValue::AuthenticationType(at), + dir::DirValue::CaptureMethod(cm) => EuclidValue::CaptureMethod(cm), + dir::DirValue::PaymentAmount(pa) => EuclidValue::PaymentAmount(pa), + dir::DirValue::PaymentCurrency(pc) => EuclidValue::PaymentCurrency(pc), + dir::DirValue::BusinessCountry(buc) => EuclidValue::BusinessCountry(buc), + dir::DirValue::BillingCountry(bic) => EuclidValue::BillingCountry(bic), + dir::DirValue::MandateAcceptanceType(mat) => EuclidValue::MandateAcceptanceType(mat), + dir::DirValue::MandateType(mt) => EuclidValue::MandateType(mt), + dir::DirValue::PaymentType(pt) => EuclidValue::PaymentType(pt), + dir::DirValue::Connector(_) => Err(AnalysisErrorType::UnsupportedProgramKey( + dir::DirKeyKind::Connector, + ))?, + dir::DirValue::BankDebitType(bdt) => EuclidValue::PaymentMethodType(bdt.into()), + dir::DirValue::RewardType(rt) => EuclidValue::PaymentMethodType(rt.into()), + dir::DirValue::BusinessLabel(bl) => EuclidValue::BusinessLabel(bl), + dir::DirValue::SetupFutureUsage(sfu) => EuclidValue::SetupFutureUsage(sfu), + }) +} + +fn lower_comparison( + dir_comparison: dir::DirComparison, +) -> Result { + Ok(vir::ValuedComparison { + values: dir_comparison + .values + .into_iter() + .map(lower_value) + .collect::>()?, + logic: match dir_comparison.logic { + dir::DirComparisonLogic::NegativeConjunction => { + vir::ValuedComparisonLogic::NegativeConjunction + } + dir::DirComparisonLogic::PositiveDisjunction => { + vir::ValuedComparisonLogic::PositiveDisjunction + } + }, + metadata: dir_comparison.metadata, + }) +} + +fn lower_if_statement( + dir_if_statement: dir::DirIfStatement, +) -> Result { + Ok(vir::ValuedIfStatement { + condition: dir_if_statement + .condition + .into_iter() + .map(lower_comparison) + .collect::>()?, + nested: dir_if_statement + .nested + .map(|v| { + v.into_iter() + .map(lower_if_statement) + .collect::>() + }) + .transpose()?, + }) +} + +fn lower_rule(dir_rule: dir::DirRule) -> Result, AnalysisErrorType> { + Ok(vir::ValuedRule { + name: dir_rule.name, + connector_selection: dir_rule.connector_selection, + statements: dir_rule + .statements + .into_iter() + .map(lower_if_statement) + .collect::>()?, + }) +} + +pub fn lower_program( + dir_program: dir::DirProgram, +) -> Result, AnalysisError> { + Ok(vir::ValuedProgram { + default_selection: dir_program.default_selection, + rules: dir_program + .rules + .into_iter() + .map(lower_rule) + .collect::>() + .map_err(|e| AnalysisError { + error_type: e, + metadata: Default::default(), + })?, + metadata: dir_program.metadata, + }) +} diff --git a/crates/euclid/src/frontend/dir/transformers.rs b/crates/euclid/src/frontend/dir/transformers.rs new file mode 100644 index 000000000000..da413d380c0f --- /dev/null +++ b/crates/euclid/src/frontend/dir/transformers.rs @@ -0,0 +1,166 @@ +use crate::{dirval, dssa::types::AnalysisErrorType, enums as global_enums, frontend::dir}; + +pub trait IntoDirValue { + fn into_dir_value(self) -> Result; +} +impl IntoDirValue for (global_enums::PaymentMethodType, global_enums::PaymentMethod) { + fn into_dir_value(self) -> Result { + match self.0 { + global_enums::PaymentMethodType::Credit => Ok(dirval!(CardType = Credit)), + global_enums::PaymentMethodType::Debit => Ok(dirval!(CardType = Debit)), + global_enums::PaymentMethodType::Giropay => Ok(dirval!(BankRedirectType = Giropay)), + global_enums::PaymentMethodType::Ideal => Ok(dirval!(BankRedirectType = Ideal)), + global_enums::PaymentMethodType::Sofort => Ok(dirval!(BankRedirectType = Sofort)), + global_enums::PaymentMethodType::Eps => Ok(dirval!(BankRedirectType = Eps)), + global_enums::PaymentMethodType::Klarna => Ok(dirval!(PayLaterType = Klarna)), + global_enums::PaymentMethodType::Affirm => Ok(dirval!(PayLaterType = Affirm)), + global_enums::PaymentMethodType::AfterpayClearpay => { + Ok(dirval!(PayLaterType = AfterpayClearpay)) + } + global_enums::PaymentMethodType::GooglePay => Ok(dirval!(WalletType = GooglePay)), + global_enums::PaymentMethodType::ApplePay => Ok(dirval!(WalletType = ApplePay)), + global_enums::PaymentMethodType::Paypal => Ok(dirval!(WalletType = Paypal)), + global_enums::PaymentMethodType::CryptoCurrency => { + Ok(dirval!(CryptoType = CryptoCurrency)) + } + global_enums::PaymentMethodType::Ach => match self.1 { + global_enums::PaymentMethod::BankDebit => Ok(dirval!(BankDebitType = Ach)), + global_enums::PaymentMethod::BankTransfer => Ok(dirval!(BankTransferType = Ach)), + global_enums::PaymentMethod::PayLater + | global_enums::PaymentMethod::Card + | global_enums::PaymentMethod::CardRedirect + | global_enums::PaymentMethod::Wallet + | global_enums::PaymentMethod::BankRedirect + | global_enums::PaymentMethod::Crypto + | global_enums::PaymentMethod::Reward + | global_enums::PaymentMethod::Upi + | global_enums::PaymentMethod::Voucher + | global_enums::PaymentMethod::GiftCard => Err(AnalysisErrorType::NotSupported), + }, + global_enums::PaymentMethodType::Bacs => match self.1 { + global_enums::PaymentMethod::BankDebit => Ok(dirval!(BankDebitType = Bacs)), + global_enums::PaymentMethod::BankTransfer => Ok(dirval!(BankTransferType = Bacs)), + global_enums::PaymentMethod::PayLater + | global_enums::PaymentMethod::Card + | global_enums::PaymentMethod::CardRedirect + | global_enums::PaymentMethod::Wallet + | global_enums::PaymentMethod::BankRedirect + | global_enums::PaymentMethod::Crypto + | global_enums::PaymentMethod::Reward + | global_enums::PaymentMethod::Upi + | global_enums::PaymentMethod::Voucher + | global_enums::PaymentMethod::GiftCard => Err(AnalysisErrorType::NotSupported), + }, + global_enums::PaymentMethodType::Becs => Ok(dirval!(BankDebitType = Becs)), + global_enums::PaymentMethodType::Sepa => match self.1 { + global_enums::PaymentMethod::BankDebit => Ok(dirval!(BankDebitType = Sepa)), + global_enums::PaymentMethod::BankTransfer => Ok(dirval!(BankTransferType = Sepa)), + global_enums::PaymentMethod::PayLater + | global_enums::PaymentMethod::Card + | global_enums::PaymentMethod::CardRedirect + | global_enums::PaymentMethod::Wallet + | global_enums::PaymentMethod::BankRedirect + | global_enums::PaymentMethod::Crypto + | global_enums::PaymentMethod::Reward + | global_enums::PaymentMethod::Upi + | global_enums::PaymentMethod::Voucher + | global_enums::PaymentMethod::GiftCard => Err(AnalysisErrorType::NotSupported), + }, + global_enums::PaymentMethodType::AliPay => Ok(dirval!(WalletType = AliPay)), + global_enums::PaymentMethodType::AliPayHk => Ok(dirval!(WalletType = AliPayHk)), + global_enums::PaymentMethodType::BancontactCard => { + Ok(dirval!(BankRedirectType = BancontactCard)) + } + global_enums::PaymentMethodType::Blik => Ok(dirval!(BankRedirectType = Blik)), + global_enums::PaymentMethodType::MbWay => Ok(dirval!(WalletType = MbWay)), + global_enums::PaymentMethodType::MobilePay => Ok(dirval!(WalletType = MobilePay)), + global_enums::PaymentMethodType::Cashapp => Ok(dirval!(WalletType = Cashapp)), + global_enums::PaymentMethodType::Multibanco => { + Ok(dirval!(BankTransferType = Multibanco)) + } + global_enums::PaymentMethodType::Pix => Ok(dirval!(BankTransferType = Pix)), + global_enums::PaymentMethodType::Pse => Ok(dirval!(BankTransferType = Pse)), + global_enums::PaymentMethodType::Interac => Ok(dirval!(BankRedirectType = Interac)), + global_enums::PaymentMethodType::OnlineBankingCzechRepublic => { + Ok(dirval!(BankRedirectType = OnlineBankingCzechRepublic)) + } + global_enums::PaymentMethodType::OnlineBankingFinland => { + Ok(dirval!(BankRedirectType = OnlineBankingFinland)) + } + global_enums::PaymentMethodType::OnlineBankingPoland => { + Ok(dirval!(BankRedirectType = OnlineBankingPoland)) + } + global_enums::PaymentMethodType::OnlineBankingSlovakia => { + Ok(dirval!(BankRedirectType = OnlineBankingSlovakia)) + } + global_enums::PaymentMethodType::Swish => Ok(dirval!(WalletType = Swish)), + global_enums::PaymentMethodType::Trustly => Ok(dirval!(BankRedirectType = Trustly)), + global_enums::PaymentMethodType::Bizum => Ok(dirval!(BankRedirectType = Bizum)), + + global_enums::PaymentMethodType::PayBright => Ok(dirval!(PayLaterType = PayBright)), + global_enums::PaymentMethodType::Walley => Ok(dirval!(PayLaterType = Walley)), + global_enums::PaymentMethodType::Przelewy24 => { + Ok(dirval!(BankRedirectType = Przelewy24)) + } + global_enums::PaymentMethodType::WeChatPay => Ok(dirval!(WalletType = WeChatPay)), + + global_enums::PaymentMethodType::ClassicReward => { + Ok(dirval!(RewardType = ClassicReward)) + } + global_enums::PaymentMethodType::Evoucher => Ok(dirval!(RewardType = Evoucher)), + global_enums::PaymentMethodType::UpiCollect => Ok(dirval!(UpiType = UpiCollect)), + global_enums::PaymentMethodType::SamsungPay => Ok(dirval!(WalletType = SamsungPay)), + global_enums::PaymentMethodType::GoPay => Ok(dirval!(WalletType = GoPay)), + global_enums::PaymentMethodType::KakaoPay => Ok(dirval!(WalletType = KakaoPay)), + global_enums::PaymentMethodType::Twint => Ok(dirval!(WalletType = Twint)), + global_enums::PaymentMethodType::Gcash => Ok(dirval!(WalletType = Gcash)), + global_enums::PaymentMethodType::Vipps => Ok(dirval!(WalletType = Vipps)), + global_enums::PaymentMethodType::Momo => Ok(dirval!(WalletType = Momo)), + global_enums::PaymentMethodType::Alma => Ok(dirval!(PayLaterType = Alma)), + global_enums::PaymentMethodType::Dana => Ok(dirval!(WalletType = Dana)), + global_enums::PaymentMethodType::OnlineBankingFpx => { + Ok(dirval!(BankRedirectType = OnlineBankingFpx)) + } + global_enums::PaymentMethodType::OnlineBankingThailand => { + Ok(dirval!(BankRedirectType = OnlineBankingThailand)) + } + global_enums::PaymentMethodType::TouchNGo => Ok(dirval!(WalletType = TouchNGo)), + global_enums::PaymentMethodType::Atome => Ok(dirval!(PayLaterType = Atome)), + global_enums::PaymentMethodType::Boleto => Ok(dirval!(VoucherType = Boleto)), + global_enums::PaymentMethodType::Efecty => Ok(dirval!(VoucherType = Efecty)), + global_enums::PaymentMethodType::PagoEfectivo => { + Ok(dirval!(VoucherType = PagoEfectivo)) + } + global_enums::PaymentMethodType::RedCompra => Ok(dirval!(VoucherType = RedCompra)), + global_enums::PaymentMethodType::RedPagos => Ok(dirval!(VoucherType = RedPagos)), + global_enums::PaymentMethodType::Alfamart => Ok(dirval!(VoucherType = Alfamart)), + global_enums::PaymentMethodType::BcaBankTransfer => { + Ok(dirval!(BankTransferType = BcaBankTransfer)) + } + global_enums::PaymentMethodType::BniVa => Ok(dirval!(BankTransferType = BniVa)), + global_enums::PaymentMethodType::BriVa => Ok(dirval!(BankTransferType = BriVa)), + global_enums::PaymentMethodType::CimbVa => Ok(dirval!(BankTransferType = CimbVa)), + global_enums::PaymentMethodType::DanamonVa => Ok(dirval!(BankTransferType = DanamonVa)), + global_enums::PaymentMethodType::Indomaret => Ok(dirval!(VoucherType = Indomaret)), + global_enums::PaymentMethodType::MandiriVa => Ok(dirval!(BankTransferType = MandiriVa)), + global_enums::PaymentMethodType::PermataBankTransfer => { + Ok(dirval!(BankTransferType = PermataBankTransfer)) + } + global_enums::PaymentMethodType::PaySafeCard => Ok(dirval!(GiftCardType = PaySafeCard)), + global_enums::PaymentMethodType::SevenEleven => Ok(dirval!(VoucherType = SevenEleven)), + global_enums::PaymentMethodType::Lawson => Ok(dirval!(VoucherType = Lawson)), + global_enums::PaymentMethodType::MiniStop => Ok(dirval!(VoucherType = MiniStop)), + global_enums::PaymentMethodType::FamilyMart => Ok(dirval!(VoucherType = FamilyMart)), + global_enums::PaymentMethodType::Seicomart => Ok(dirval!(VoucherType = Seicomart)), + global_enums::PaymentMethodType::PayEasy => Ok(dirval!(VoucherType = PayEasy)), + global_enums::PaymentMethodType::Givex => Ok(dirval!(GiftCardType = Givex)), + global_enums::PaymentMethodType::Benefit => Ok(dirval!(CardRedirectType = Benefit)), + global_enums::PaymentMethodType::Knet => Ok(dirval!(CardRedirectType = Knet)), + global_enums::PaymentMethodType::OpenBankingUk => { + Ok(dirval!(BankRedirectType = OpenBankingUk)) + } + global_enums::PaymentMethodType::MomoAtm => Ok(dirval!(CardRedirectType = MomoAtm)), + global_enums::PaymentMethodType::Oxxo => Ok(dirval!(VoucherType = Oxxo)), + } + } +} diff --git a/crates/euclid/src/frontend/vir.rs b/crates/euclid/src/frontend/vir.rs new file mode 100644 index 000000000000..750ff4e61ff8 --- /dev/null +++ b/crates/euclid/src/frontend/vir.rs @@ -0,0 +1,37 @@ +//! Valued Intermediate Representation +use crate::types::{EuclidValue, Metadata}; + +#[derive(Debug, Clone)] +pub enum ValuedComparisonLogic { + NegativeConjunction, + PositiveDisjunction, +} + +#[derive(Clone, Debug)] +pub struct ValuedComparison { + pub values: Vec, + pub logic: ValuedComparisonLogic, + pub metadata: Metadata, +} + +pub type ValuedIfCondition = Vec; + +#[derive(Clone, Debug)] +pub struct ValuedIfStatement { + pub condition: ValuedIfCondition, + pub nested: Option>, +} + +#[derive(Clone, Debug)] +pub struct ValuedRule { + pub name: String, + pub connector_selection: O, + pub statements: Vec, +} + +#[derive(Clone, Debug)] +pub struct ValuedProgram { + pub default_selection: O, + pub rules: Vec>, + pub metadata: Metadata, +} diff --git a/crates/euclid/src/lib.rs b/crates/euclid/src/lib.rs new file mode 100644 index 000000000000..d64297437aeb --- /dev/null +++ b/crates/euclid/src/lib.rs @@ -0,0 +1,7 @@ +#![allow(clippy::result_large_err)] +pub mod backend; +pub mod dssa; +pub mod enums; +pub mod frontend; +pub mod types; +pub mod utils; diff --git a/crates/euclid/src/types.rs b/crates/euclid/src/types.rs new file mode 100644 index 000000000000..59736ae65125 --- /dev/null +++ b/crates/euclid/src/types.rs @@ -0,0 +1,318 @@ +pub mod transformers; + +use euclid_macros::EnumNums; +use serde::Serialize; +use strum::VariantNames; + +use crate::{ + dssa::types::EuclidAnalysable, + enums, + frontend::{ + ast, + dir::{DirKeyKind, DirValue, EuclidDirFilter}, + }, +}; + +pub type Metadata = std::collections::HashMap; + +#[derive( + Debug, + Clone, + EnumNums, + Hash, + PartialEq, + Eq, + strum::Display, + strum::EnumVariantNames, + strum::EnumString, +)] +pub enum EuclidKey { + #[strum(serialize = "payment_method")] + PaymentMethod, + #[strum(serialize = "card_bin")] + CardBin, + #[strum(serialize = "metadata")] + Metadata, + #[strum(serialize = "mandate_type")] + MandateType, + #[strum(serialize = "mandate_acceptance_type")] + MandateAcceptanceType, + #[strum(serialize = "payment_type")] + PaymentType, + #[strum(serialize = "payment_method_type")] + PaymentMethodType, + #[strum(serialize = "card_network")] + CardNetwork, + #[strum(serialize = "authentication_type")] + AuthenticationType, + #[strum(serialize = "capture_method")] + CaptureMethod, + #[strum(serialize = "amount")] + PaymentAmount, + #[strum(serialize = "currency")] + PaymentCurrency, + #[strum(serialize = "country", to_string = "business_country")] + BusinessCountry, + #[strum(serialize = "billing_country")] + BillingCountry, + #[strum(serialize = "business_label")] + BusinessLabel, + #[strum(serialize = "setup_future_usage")] + SetupFutureUsage, +} +impl EuclidDirFilter for DummyOutput { + const ALLOWED: &'static [DirKeyKind] = &[ + DirKeyKind::AuthenticationType, + DirKeyKind::PaymentMethod, + DirKeyKind::CardType, + DirKeyKind::PaymentCurrency, + DirKeyKind::CaptureMethod, + DirKeyKind::AuthenticationType, + DirKeyKind::CardBin, + DirKeyKind::PayLaterType, + DirKeyKind::PaymentAmount, + DirKeyKind::MetaData, + DirKeyKind::MandateAcceptanceType, + DirKeyKind::MandateType, + DirKeyKind::PaymentType, + DirKeyKind::SetupFutureUsage, + ]; +} +impl EuclidAnalysable for DummyOutput { + fn get_dir_value_for_analysis(&self, rule_name: String) -> Vec<(DirValue, Metadata)> { + self.outputs + .iter() + .map(|dummyc| { + let metadata_key = "MetadataKey".to_string(); + let metadata_value = dummyc; + ( + DirValue::MetaData(MetadataValue { + key: metadata_key.clone(), + value: metadata_value.clone(), + }), + std::collections::HashMap::from_iter([( + "DUMMY_OUTPUT".to_string(), + serde_json::json!({ + "rule_name":rule_name, + "Metadata_Key" :metadata_key, + "Metadata_Value" : metadata_value, + }), + )]), + ) + }) + .collect() + } +} +#[derive(Debug, Clone, Serialize)] +pub struct DummyOutput { + pub outputs: Vec, +} + +#[derive(Debug, Clone, serde::Serialize, strum::Display)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum DataType { + Number, + EnumVariant, + MetadataValue, + StrValue, +} + +impl EuclidKey { + pub fn key_type(&self) -> DataType { + match self { + Self::PaymentMethod => DataType::EnumVariant, + Self::CardBin => DataType::StrValue, + Self::Metadata => DataType::MetadataValue, + Self::PaymentMethodType => DataType::EnumVariant, + Self::CardNetwork => DataType::EnumVariant, + Self::AuthenticationType => DataType::EnumVariant, + Self::CaptureMethod => DataType::EnumVariant, + Self::PaymentAmount => DataType::Number, + Self::PaymentCurrency => DataType::EnumVariant, + Self::BusinessCountry => DataType::EnumVariant, + Self::BillingCountry => DataType::EnumVariant, + Self::MandateType => DataType::EnumVariant, + Self::MandateAcceptanceType => DataType::EnumVariant, + Self::PaymentType => DataType::EnumVariant, + Self::BusinessLabel => DataType::StrValue, + Self::SetupFutureUsage => DataType::EnumVariant, + } + } +} + +enums::collect_variants!(EuclidKey); + +#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize)] +#[serde(rename_all = "snake_case")] +pub enum NumValueRefinement { + NotEqual, + GreaterThan, + LessThan, + GreaterThanEqual, + LessThanEqual, +} + +impl From for Option { + fn from(comp_type: ast::ComparisonType) -> Self { + match comp_type { + ast::ComparisonType::Equal => None, + ast::ComparisonType::NotEqual => Some(NumValueRefinement::NotEqual), + ast::ComparisonType::GreaterThan => Some(NumValueRefinement::GreaterThan), + ast::ComparisonType::LessThan => Some(NumValueRefinement::LessThan), + ast::ComparisonType::LessThanEqual => Some(NumValueRefinement::LessThanEqual), + ast::ComparisonType::GreaterThanEqual => Some(NumValueRefinement::GreaterThanEqual), + } + } +} + +impl From for ast::ComparisonType { + fn from(value: NumValueRefinement) -> Self { + match value { + NumValueRefinement::NotEqual => Self::NotEqual, + NumValueRefinement::LessThan => Self::LessThan, + NumValueRefinement::GreaterThan => Self::GreaterThan, + NumValueRefinement::GreaterThanEqual => Self::GreaterThanEqual, + NumValueRefinement::LessThanEqual => Self::LessThanEqual, + } + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, serde::Serialize)] +pub struct StrValue { + pub value: String, +} + +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, serde::Serialize)] +pub struct MetadataValue { + pub key: String, + pub value: String, +} + +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, serde::Serialize)] +pub struct NumValue { + pub number: i64, + pub refinement: Option, +} + +impl NumValue { + pub fn fits(&self, other: &Self) -> bool { + let this_num = self.number; + let other_num = other.number; + + match (&self.refinement, &other.refinement) { + (None, None) => this_num == other_num, + + (Some(NumValueRefinement::GreaterThan), None) => other_num > this_num, + + (Some(NumValueRefinement::LessThan), None) => other_num < this_num, + + (Some(NumValueRefinement::NotEqual), Some(NumValueRefinement::NotEqual)) => { + other_num == this_num + } + + (Some(NumValueRefinement::GreaterThan), Some(NumValueRefinement::GreaterThan)) => { + other_num > this_num + } + (Some(NumValueRefinement::LessThan), Some(NumValueRefinement::LessThan)) => { + other_num < this_num + } + + (Some(NumValueRefinement::GreaterThanEqual), None) => other_num >= this_num, + (Some(NumValueRefinement::LessThanEqual), None) => other_num <= this_num, + ( + Some(NumValueRefinement::GreaterThanEqual), + Some(NumValueRefinement::GreaterThanEqual), + ) => other_num >= this_num, + + (Some(NumValueRefinement::LessThanEqual), Some(NumValueRefinement::LessThanEqual)) => { + other_num <= this_num + } + + _ => false, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum EuclidValue { + PaymentMethod(enums::PaymentMethod), + CardBin(StrValue), + Metadata(MetadataValue), + PaymentMethodType(enums::PaymentMethodType), + CardNetwork(enums::CardNetwork), + AuthenticationType(enums::AuthenticationType), + CaptureMethod(enums::CaptureMethod), + PaymentType(enums::PaymentType), + MandateAcceptanceType(enums::MandateAcceptanceType), + MandateType(enums::MandateType), + PaymentAmount(NumValue), + PaymentCurrency(enums::Currency), + BusinessCountry(enums::Country), + BillingCountry(enums::Country), + BusinessLabel(StrValue), + SetupFutureUsage(enums::SetupFutureUsage), +} + +impl EuclidValue { + pub fn get_num_value(&self) -> Option { + match self { + Self::PaymentAmount(val) => Some(val.clone()), + _ => None, + } + } + + pub fn get_key(&self) -> EuclidKey { + match self { + Self::PaymentMethod(_) => EuclidKey::PaymentMethod, + Self::CardBin(_) => EuclidKey::CardBin, + Self::Metadata(_) => EuclidKey::Metadata, + Self::PaymentMethodType(_) => EuclidKey::PaymentMethodType, + Self::MandateType(_) => EuclidKey::MandateType, + Self::PaymentType(_) => EuclidKey::PaymentType, + Self::MandateAcceptanceType(_) => EuclidKey::MandateAcceptanceType, + Self::CardNetwork(_) => EuclidKey::CardNetwork, + Self::AuthenticationType(_) => EuclidKey::AuthenticationType, + Self::CaptureMethod(_) => EuclidKey::CaptureMethod, + Self::PaymentAmount(_) => EuclidKey::PaymentAmount, + Self::PaymentCurrency(_) => EuclidKey::PaymentCurrency, + Self::BusinessCountry(_) => EuclidKey::BusinessCountry, + Self::BillingCountry(_) => EuclidKey::BillingCountry, + Self::BusinessLabel(_) => EuclidKey::BusinessLabel, + Self::SetupFutureUsage(_) => EuclidKey::SetupFutureUsage, + } + } +} + +#[cfg(test)] +mod global_type_tests { + use super::*; + + #[test] + fn test_num_value_fits_greater_than() { + let val1 = NumValue { + number: 10, + refinement: Some(NumValueRefinement::GreaterThan), + }; + let val2 = NumValue { + number: 30, + refinement: Some(NumValueRefinement::GreaterThan), + }; + + assert!(val1.fits(&val2)) + } + + #[test] + fn test_num_value_fits_less_than() { + let val1 = NumValue { + number: 30, + refinement: Some(NumValueRefinement::LessThan), + }; + let val2 = NumValue { + number: 10, + refinement: Some(NumValueRefinement::LessThan), + }; + + assert!(val1.fits(&val2)); + } +} diff --git a/crates/euclid/src/types/transformers.rs b/crates/euclid/src/types/transformers.rs new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/crates/euclid/src/types/transformers.rs @@ -0,0 +1 @@ + diff --git a/crates/euclid/src/utils.rs b/crates/euclid/src/utils.rs new file mode 100644 index 000000000000..e8cb7901f0d7 --- /dev/null +++ b/crates/euclid/src/utils.rs @@ -0,0 +1,3 @@ +pub mod dense_map; + +pub use dense_map::{DenseMap, EntityId}; diff --git a/crates/euclid/src/utils/dense_map.rs b/crates/euclid/src/utils/dense_map.rs new file mode 100644 index 000000000000..8bd4487c77b9 --- /dev/null +++ b/crates/euclid/src/utils/dense_map.rs @@ -0,0 +1,224 @@ +use std::{fmt, iter, marker::PhantomData, ops, slice, vec}; + +pub trait EntityId { + fn get_id(&self) -> usize; + fn with_id(id: usize) -> Self; +} + +pub struct DenseMap { + data: Vec, + _marker: PhantomData, +} + +impl DenseMap { + pub fn new() -> Self { + Self { + data: Vec::new(), + _marker: PhantomData, + } + } +} + +impl Default for DenseMap { + fn default() -> Self { + Self::new() + } +} + +impl DenseMap +where + K: EntityId, +{ + pub fn push(&mut self, elem: V) -> K { + let curr_len = self.data.len(); + self.data.push(elem); + K::with_id(curr_len) + } + + #[inline] + pub fn get(&self, idx: K) -> Option<&V> { + self.data.get(idx.get_id()) + } + + #[inline] + pub fn get_mut(&mut self, idx: K) -> Option<&mut V> { + self.data.get_mut(idx.get_id()) + } + + #[inline] + pub fn contains_key(&self, key: K) -> bool { + key.get_id() < self.data.len() + } + + #[inline] + pub fn keys(&self) -> Keys { + Keys::new(0..self.data.len()) + } + + #[inline] + pub fn into_keys(self) -> Keys { + Keys::new(0..self.data.len()) + } + + #[inline] + pub fn values(&self) -> slice::Iter<'_, V> { + self.data.iter() + } + + #[inline] + pub fn values_mut(&mut self) -> slice::IterMut<'_, V> { + self.data.iter_mut() + } + + #[inline] + pub fn into_values(self) -> vec::IntoIter { + self.data.into_iter() + } + + #[inline] + pub fn iter(&self) -> Iter<'_, K, V> { + Iter::new(self.data.iter()) + } + + #[inline] + pub fn iter_mut(&mut self) -> IterMut<'_, K, V> { + IterMut::new(self.data.iter_mut()) + } +} + +impl fmt::Debug for DenseMap +where + K: EntityId + fmt::Debug, + V: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map().entries(self.iter()).finish() + } +} + +pub struct Keys { + inner: ops::Range, + _marker: PhantomData, +} + +impl Keys { + fn new(range: ops::Range) -> Self { + Self { + inner: range, + _marker: PhantomData, + } + } +} + +impl Iterator for Keys +where + K: EntityId, +{ + type Item = K; + + fn next(&mut self) -> Option { + self.inner.next().map(K::with_id) + } +} + +pub struct Iter<'a, K, V> { + inner: iter::Enumerate>, + _marker: PhantomData, +} + +impl<'a, K, V> Iter<'a, K, V> { + fn new(iter: slice::Iter<'a, V>) -> Self { + Self { + inner: iter.enumerate(), + _marker: PhantomData, + } + } +} + +impl<'a, K, V> Iterator for Iter<'a, K, V> +where + K: EntityId, +{ + type Item = (K, &'a V); + + fn next(&mut self) -> Option { + self.inner.next().map(|(id, val)| (K::with_id(id), val)) + } +} + +pub struct IterMut<'a, K, V> { + inner: iter::Enumerate>, + _marker: PhantomData, +} + +impl<'a, K, V> IterMut<'a, K, V> { + fn new(iter: slice::IterMut<'a, V>) -> Self { + Self { + inner: iter.enumerate(), + _marker: PhantomData, + } + } +} + +impl<'a, K, V> Iterator for IterMut<'a, K, V> +where + K: EntityId, +{ + type Item = (K, &'a mut V); + + fn next(&mut self) -> Option { + self.inner.next().map(|(id, val)| (K::with_id(id), val)) + } +} + +pub struct IntoIter { + inner: iter::Enumerate>, + _marker: PhantomData, +} + +impl IntoIter { + fn new(iter: vec::IntoIter) -> Self { + Self { + inner: iter.enumerate(), + _marker: PhantomData, + } + } +} + +impl Iterator for IntoIter +where + K: EntityId, +{ + type Item = (K, V); + + fn next(&mut self) -> Option { + self.inner.next().map(|(id, val)| (K::with_id(id), val)) + } +} + +impl IntoIterator for DenseMap +where + K: EntityId, +{ + type Item = (K, V); + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.data.into_iter()) + } +} + +impl FromIterator for DenseMap +where + K: EntityId, +{ + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self { + data: Vec::from_iter(iter), + _marker: PhantomData, + } + } +} diff --git a/crates/euclid_macros/Cargo.toml b/crates/euclid_macros/Cargo.toml new file mode 100644 index 000000000000..2524887a8a0f --- /dev/null +++ b/crates/euclid_macros/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "euclid_macros" +description = "Macros for Euclid DSL" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.51" +quote = "1.0.23" +rustc-hash = "1.1.0" +strum = { version = "0.24", features = ["derive"] } +syn = "1.0.109" diff --git a/crates/euclid_macros/src/inner.rs b/crates/euclid_macros/src/inner.rs new file mode 100644 index 000000000000..979527560dd6 --- /dev/null +++ b/crates/euclid_macros/src/inner.rs @@ -0,0 +1,5 @@ +mod enum_nums; +mod knowledge; + +pub(crate) use enum_nums::enum_nums_inner; +pub(crate) use knowledge::knowledge_inner; diff --git a/crates/euclid_macros/src/inner/enum_nums.rs b/crates/euclid_macros/src/inner/enum_nums.rs new file mode 100644 index 000000000000..61f6765fce0e --- /dev/null +++ b/crates/euclid_macros/src/inner/enum_nums.rs @@ -0,0 +1,47 @@ +use proc_macro::TokenStream; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::quote; + +fn error() -> TokenStream2 { + syn::Error::new( + Span::call_site(), + "'EnumNums' can only be derived on enums with unit variants".to_string(), + ) + .to_compile_error() +} + +pub(crate) fn enum_nums_inner(ts: TokenStream) -> TokenStream { + let derive_input = syn::parse_macro_input!(ts as syn::DeriveInput); + + let enum_obj = match derive_input.data { + syn::Data::Enum(e) => e, + _ => return error().into(), + }; + + let enum_name = derive_input.ident; + + let mut match_arms = Vec::::with_capacity(enum_obj.variants.len()); + + for (i, variant) in enum_obj.variants.iter().enumerate() { + match variant.fields { + syn::Fields::Unit => {} + _ => return error().into(), + } + + let var_ident = &variant.ident; + + match_arms.push(quote! { Self::#var_ident => #i }); + } + + let impl_block = quote! { + impl #enum_name { + pub fn to_num(&self) -> usize { + match self { + #(#match_arms),* + } + } + } + }; + + impl_block.into() +} diff --git a/crates/euclid_macros/src/inner/knowledge.rs b/crates/euclid_macros/src/inner/knowledge.rs new file mode 100644 index 000000000000..73b94919c903 --- /dev/null +++ b/crates/euclid_macros/src/inner/knowledge.rs @@ -0,0 +1,680 @@ +use std::{hash::Hash, rc::Rc}; + +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use rustc_hash::{FxHashMap, FxHashSet}; +use syn::{parse::Parse, Token}; + +mod strength { + syn::custom_punctuation!(Normal, ->); + syn::custom_punctuation!(Strong, ->>); +} + +mod kw { + syn::custom_keyword!(any); + syn::custom_keyword!(not); +} + +#[derive(Clone, PartialEq, Eq, Hash)] +enum Comparison { + LessThan, + Equal, + GreaterThan, + GreaterThanEqual, + LessThanEqual, +} + +impl ToString for Comparison { + fn to_string(&self) -> String { + match self { + Self::LessThan => "< ".to_string(), + Self::Equal => String::new(), + Self::GreaterThanEqual => ">= ".to_string(), + Self::LessThanEqual => "<= ".to_string(), + Self::GreaterThan => "> ".to_string(), + } + } +} + +impl Parse for Comparison { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + if input.peek(Token![>]) { + input.parse::]>()?; + Ok(Self::GreaterThan) + } else if input.peek(Token![<]) { + input.parse::()?; + Ok(Self::LessThan) + } else if input.peek(Token!(<=)) { + input.parse::()?; + Ok(Self::LessThanEqual) + } else if input.peek(Token!(>=)) { + input.parse::=]>()?; + Ok(Self::GreaterThanEqual) + } else { + Ok(Self::Equal) + } + } +} + +#[derive(Clone, PartialEq, Eq, Hash)] +enum ValueType { + Any, + EnumVariant(String), + Number { number: i64, comparison: Comparison }, +} + +impl ValueType { + fn to_string(&self, key: &str) -> String { + match self { + Self::Any => format!("{key}(any)"), + Self::EnumVariant(s) => format!("{key}({s})"), + Self::Number { number, comparison } => { + format!("{}({}{})", key, comparison.to_string(), number) + } + } + } +} + +impl Parse for ValueType { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(syn::Ident) { + let ident: syn::Ident = input.parse()?; + Ok(Self::EnumVariant(ident.to_string())) + } else if lookahead.peek(Token![>]) + || lookahead.peek(Token![<]) + || lookahead.peek(syn::LitInt) + { + let comparison: Comparison = input.parse()?; + let number: syn::LitInt = input.parse()?; + let num_val = number.base10_parse::()?; + Ok(Self::Number { + number: num_val, + comparison, + }) + } else { + Err(lookahead.error()) + } + } +} + +#[derive(Clone, PartialEq, Eq, Hash)] +struct Atom { + key: String, + value: ValueType, +} + +impl ToString for Atom { + fn to_string(&self) -> String { + self.value.to_string(&self.key) + } +} + +impl Parse for Atom { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let maybe_any: syn::Ident = input.parse()?; + if maybe_any == "any" { + let actual_key: syn::Ident = input.parse()?; + Ok(Self { + key: actual_key.to_string(), + value: ValueType::Any, + }) + } else { + let content; + syn::parenthesized!(content in input); + let value: ValueType = content.parse()?; + Ok(Self { + key: maybe_any.to_string(), + value, + }) + } + } +} + +#[derive(Clone, PartialEq, Eq, Hash, strum::Display)] +enum Strength { + Normal, + Strong, +} + +impl Parse for Strength { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(strength::Strong) { + input.parse::()?; + Ok(Self::Strong) + } else if lookahead.peek(strength::Normal) { + input.parse::()?; + Ok(Self::Normal) + } else { + Err(lookahead.error()) + } + } +} + +#[derive(Clone, PartialEq, Eq, Hash, strum::Display)] +enum Relation { + Positive, + Negative, +} + +enum AtomType { + Value { + relation: Relation, + atom: Rc, + }, + + InAggregator { + key: String, + values: Vec, + relation: Relation, + }, +} + +fn parse_atom_type_inner( + input: syn::parse::ParseStream<'_>, + key: syn::Ident, + relation: Relation, +) -> syn::Result { + let result = if input.peek(Token![in]) { + input.parse::()?; + + let bracketed; + syn::bracketed!(bracketed in input); + + let mut values = Vec::::new(); + let first: syn::Ident = bracketed.parse()?; + values.push(first.to_string()); + while !bracketed.is_empty() { + bracketed.parse::()?; + let next: syn::Ident = bracketed.parse()?; + values.push(next.to_string()); + } + + AtomType::InAggregator { + key: key.to_string(), + values, + relation, + } + } else if input.peek(kw::any) { + input.parse::()?; + AtomType::Value { + relation, + atom: Rc::new(Atom { + key: key.to_string(), + value: ValueType::Any, + }), + } + } else { + let value: ValueType = input.parse()?; + AtomType::Value { + relation, + atom: Rc::new(Atom { + key: key.to_string(), + value, + }), + } + }; + + Ok(result) +} + +impl Parse for AtomType { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let key: syn::Ident = input.parse()?; + let content; + syn::parenthesized!(content in input); + + let relation = if content.peek(kw::not) { + content.parse::()?; + Relation::Negative + } else { + Relation::Positive + }; + + let result = parse_atom_type_inner(&content, key, relation)?; + + if !content.is_empty() { + Err(content.error("Unexpected input received after atom value")) + } else { + Ok(result) + } + } +} + +fn parse_rhs_atom(input: syn::parse::ParseStream<'_>) -> syn::Result { + let key: syn::Ident = input.parse()?; + let content; + syn::parenthesized!(content in input); + + let lookahead = content.lookahead1(); + + let value_type = if lookahead.peek(kw::any) { + content.parse::()?; + ValueType::Any + } else if lookahead.peek(syn::Ident) { + let variant = content.parse::()?; + ValueType::EnumVariant(variant.to_string()) + } else { + return Err(lookahead.error()); + }; + + if !content.is_empty() { + Err(content.error("Unexpected input received after atom value")) + } else { + Ok(Atom { + key: key.to_string(), + value: value_type, + }) + } +} + +struct Rule { + lhs: Vec, + strength: Strength, + rhs: Rc, +} + +impl Parse for Rule { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let first_atom: AtomType = input.parse()?; + let mut lhs: Vec = vec![first_atom]; + + while input.peek(Token![&]) { + input.parse::()?; + let and_atom: AtomType = input.parse()?; + lhs.push(and_atom); + } + + let strength: Strength = input.parse()?; + + let rhs: Rc = Rc::new(parse_rhs_atom(input)?); + + input.parse::()?; + + Ok(Self { lhs, strength, rhs }) + } +} + +#[derive(Clone)] +enum Scope { + Crate, + Extern, +} + +impl Parse for Scope { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(Token![crate]) { + input.parse::()?; + Ok(Self::Crate) + } else if lookahead.peek(Token![extern]) { + input.parse::()?; + Ok(Self::Extern) + } else { + Err(lookahead.error()) + } + } +} + +impl ToString for Scope { + fn to_string(&self) -> String { + match self { + Self::Crate => "crate".to_string(), + Self::Extern => "euclid".to_string(), + } + } +} + +#[derive(Clone)] +struct Program { + rules: Vec>, + scope: Scope, +} + +impl Parse for Program { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let scope: Scope = input.parse()?; + let mut rules: Vec> = Vec::new(); + + while !input.is_empty() { + rules.push(Rc::new(input.parse::()?)); + } + + Ok(Self { rules, scope }) + } +} + +struct GenContext { + next_idx: usize, + next_node_idx: usize, + idx2atom: FxHashMap>, + atom2idx: FxHashMap, usize>, + edges: FxHashMap>, + compiled_atoms: FxHashMap, proc_macro2::Ident>, +} + +impl GenContext { + fn new() -> Self { + Self { + next_idx: 1, + next_node_idx: 1, + idx2atom: FxHashMap::default(), + atom2idx: FxHashMap::default(), + edges: FxHashMap::default(), + compiled_atoms: FxHashMap::default(), + } + } + + fn register_node(&mut self, atom: Rc) -> usize { + if let Some(idx) = self.atom2idx.get(&atom) { + *idx + } else { + let this_idx = self.next_idx; + self.next_idx += 1; + + self.idx2atom.insert(this_idx, Rc::clone(&atom)); + self.atom2idx.insert(atom, this_idx); + + this_idx + } + } + + fn register_edge(&mut self, from: usize, to: usize) -> Result<(), String> { + let node_children = self.edges.entry(from).or_default(); + if node_children.contains(&to) { + Err("Duplicate edge detected".to_string()) + } else { + node_children.insert(to); + self.edges.entry(to).or_default(); + Ok(()) + } + } + + fn register_rule(&mut self, rule: &Rule) -> Result<(), String> { + let to_idx = self.register_node(Rc::clone(&rule.rhs)); + + for atom_type in &rule.lhs { + if let AtomType::Value { atom, .. } = atom_type { + let from_idx = self.register_node(Rc::clone(atom)); + self.register_edge(from_idx, to_idx)?; + } + } + + Ok(()) + } + + fn cycle_dfs( + &self, + node_id: usize, + explored: &mut FxHashSet, + visited: &mut FxHashSet, + order: &mut Vec, + ) -> Result>, String> { + if explored.contains(&node_id) { + let position = order + .iter() + .position(|v| *v == node_id) + .ok_or_else(|| "Error deciding cycle order".to_string())?; + + let cycle_order = order[position..].to_vec(); + Ok(Some(cycle_order)) + } else if visited.contains(&node_id) { + Ok(None) + } else { + visited.insert(node_id); + explored.insert(node_id); + order.push(node_id); + let dests = self + .edges + .get(&node_id) + .ok_or_else(|| "Error getting edges of node".to_string())?; + + for dest in dests.iter().copied() { + if let Some(cycle) = self.cycle_dfs(dest, explored, visited, order)? { + return Ok(Some(cycle)); + } + } + + order.pop(); + + Ok(None) + } + } + + fn detect_graph_cycles(&self) -> Result<(), String> { + let start_nodes = self.edges.keys().copied().collect::>(); + + let mut total_visited = FxHashSet::::default(); + + for node_id in start_nodes.iter().copied() { + let mut explored = FxHashSet::::default(); + let mut order = Vec::::new(); + + match self.cycle_dfs(node_id, &mut explored, &mut total_visited, &mut order)? { + None => {} + Some(order) => { + let mut display_strings = Vec::::with_capacity(order.len() + 1); + + for cycle_node_id in order { + let node = self.idx2atom.get(&cycle_node_id).ok_or_else(|| { + "Failed to find node during cycle display creation".to_string() + })?; + + display_strings.push(node.to_string()); + } + + let first = display_strings + .first() + .cloned() + .ok_or("Unable to fill cycle display array")?; + + display_strings.push(first); + + return Err(format!("Found cycle: {}", display_strings.join(" -> "))); + } + } + } + + Ok(()) + } + + fn next_node_ident(&mut self) -> (proc_macro2::Ident, usize) { + let this_idx = self.next_node_idx; + self.next_node_idx += 1; + (format_ident!("_node_{this_idx}"), this_idx) + } + + fn compile_atom( + &mut self, + atom: &Rc, + tokens: &mut TokenStream, + ) -> Result { + let maybe_ident = self.compiled_atoms.get(atom); + + if let Some(ident) = maybe_ident { + Ok(ident.clone()) + } else { + let (identifier, _) = self.next_node_ident(); + let key = format_ident!("{}", &atom.key); + let the_value = match &atom.value { + ValueType::Any => quote! { + NodeValue::Key(DirKey::new(DirKeyKind::#key,None)) + }, + ValueType::EnumVariant(variant) => { + let variant = format_ident!("{}", variant); + quote! { + NodeValue::Value(DirValue::#key(#key::#variant)) + } + } + ValueType::Number { number, comparison } => { + let comp_type = match comparison { + Comparison::Equal => quote! { + None + }, + Comparison::LessThan => quote! { + Some(NumValueRefinement::LessThan) + }, + Comparison::GreaterThan => quote! { + Some(NumValueRefinement::GreaterThan) + }, + Comparison::GreaterThanEqual => quote! { + Some(NumValueRefinement::GreaterThanEqual) + }, + Comparison::LessThanEqual => quote! { + Some(NumValueRefinement::LessThanEqual) + }, + }; + + quote! { + NodeValue::Value(DirValue::#key(NumValue { + number: #number, + refinement: #comp_type, + })) + } + } + }; + + let compiled = quote! { + let #identifier = graph.make_value_node(#the_value, None, Vec::new(), None::<()>).expect("NodeId derivation failed"); + }; + + tokens.extend(compiled); + self.compiled_atoms + .insert(Rc::clone(atom), identifier.clone()); + + Ok(identifier) + } + } + + fn compile_atom_type( + &mut self, + atom_type: &AtomType, + tokens: &mut TokenStream, + ) -> Result<(proc_macro2::Ident, Relation), String> { + match atom_type { + AtomType::Value { relation, atom } => { + let node_ident = self.compile_atom(atom, tokens)?; + + Ok((node_ident, relation.clone())) + } + + AtomType::InAggregator { + key, + values, + relation, + } => { + let key_ident = format_ident!("{key}"); + let mut values_tokens: Vec = Vec::new(); + + for value in values { + let value_ident = format_ident!("{value}"); + values_tokens.push(quote! { DirValue::#key_ident(#key_ident::#value_ident) }); + } + + let (node_ident, _) = self.next_node_ident(); + let node_code = quote! { + let #node_ident = graph.make_in_aggregator( + Vec::from_iter([#(#values_tokens),*]), + None, + None::<()>, + Vec::new(), + ).expect("Failed to make In aggregator"); + }; + + tokens.extend(node_code); + + Ok((node_ident, relation.clone())) + } + } + } + + fn compile_rule(&mut self, rule: &Rule, tokens: &mut TokenStream) -> Result<(), String> { + let rhs_ident = self.compile_atom(&rule.rhs, tokens)?; + let mut node_details: Vec<(proc_macro2::Ident, Relation)> = + Vec::with_capacity(rule.lhs.len()); + for lhs_atom_type in &rule.lhs { + let details = self.compile_atom_type(lhs_atom_type, tokens)?; + node_details.push(details); + } + + if node_details.len() <= 1 { + let strength = format_ident!("{}", rule.strength.to_string()); + for (from_node, relation) in &node_details { + let relation = format_ident!("{}", relation.to_string()); + tokens.extend(quote! { + graph.make_edge(#from_node, #rhs_ident, Strength::#strength, Relation::#relation) + .expect("Failed to make edge"); + }); + } + } else { + let mut all_agg_nodes: Vec = Vec::with_capacity(node_details.len()); + for (from_node, relation) in &node_details { + let relation = format_ident!("{}", relation.to_string()); + all_agg_nodes.push(quote! { (#from_node, Relation::#relation, Strength::Strong) }); + } + + let strength = format_ident!("{}", rule.strength.to_string()); + let (agg_node_ident, _) = self.next_node_ident(); + tokens.extend(quote! { + let #agg_node_ident = graph.make_all_aggregator(&[#(#all_agg_nodes),*], None, None::<()>, Vec::new()) + .expect("Failed to make all aggregator node"); + + graph.make_edge(#agg_node_ident, #rhs_ident, Strength::#strength, Relation::Positive) + .expect("Failed to create all aggregator edge"); + + }); + } + + Ok(()) + } + + fn compile(&mut self, program: Program) -> Result { + let mut tokens = TokenStream::new(); + for rule in &program.rules { + self.compile_rule(rule, &mut tokens)?; + } + + let scope = match &program.scope { + Scope::Crate => quote! { crate }, + Scope::Extern => quote! { euclid }, + }; + + let compiled = quote! {{ + use #scope::{ + dssa::graph::*, + types::*, + frontend::dir::{*, enums::*}, + }; + + use rustc_hash::{FxHashMap, FxHashSet}; + + let mut graph = KnowledgeGraphBuilder::new(); + + #tokens + + graph.build() + }}; + + Ok(compiled) + } +} + +pub(crate) fn knowledge_inner(ts: TokenStream) -> syn::Result { + let program = syn::parse::(ts.into())?; + let mut gen_context = GenContext::new(); + + for rule in &program.rules { + gen_context + .register_rule(rule) + .map_err(|msg| syn::Error::new(Span::call_site(), msg))?; + } + + gen_context + .detect_graph_cycles() + .map_err(|msg| syn::Error::new(Span::call_site(), msg))?; + + gen_context + .compile(program) + .map_err(|msg| syn::Error::new(Span::call_site(), msg)) +} diff --git a/crates/euclid_macros/src/lib.rs b/crates/euclid_macros/src/lib.rs new file mode 100644 index 000000000000..97b42aaa64c1 --- /dev/null +++ b/crates/euclid_macros/src/lib.rs @@ -0,0 +1,16 @@ +mod inner; + +use proc_macro::TokenStream; + +#[proc_macro_derive(EnumNums)] +pub fn enum_nums(ts: TokenStream) -> TokenStream { + inner::enum_nums_inner(ts) +} + +#[proc_macro] +pub fn knowledge(ts: TokenStream) -> TokenStream { + match inner::knowledge_inner(ts.into()) { + Ok(ts) => ts.into(), + Err(e) => e.into_compile_error().into(), + } +} diff --git a/crates/euclid_wasm/Cargo.toml b/crates/euclid_wasm/Cargo.toml new file mode 100644 index 000000000000..90489eb78bf6 --- /dev/null +++ b/crates/euclid_wasm/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "euclid_wasm" +description = "WASM bindings for Euclid DSL" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +crate-type = ["cdylib"] + +[features] +default = ["connector_choice_bcompat", "payouts"] +connector_choice_bcompat = [ + "euclid/connector_choice_bcompat", + "api_models/connector_choice_bcompat", + "kgraph_utils/backwards_compatibility" +] +connector_choice_mca_id = [ + "api_models/connector_choice_mca_id", + "euclid/connector_choice_mca_id", + "kgraph_utils/connector_choice_mca_id" +] +dummy_connector = ["kgraph_utils/dummy_connector"] +payouts = [] + +[dependencies] +api_models = { version = "0.1.0", path = "../api_models", package = "api_models" } +euclid = { path = "../euclid", features = [] } +kgraph_utils = { version = "0.1.0", path = "../kgraph_utils" } +getrandom = { version = "0.2.10", features = ["js"] } +once_cell = "1.18.0" +serde = { version = "1.0", features = [] } +serde-wasm-bindgen = "0.5" +strum = { version = "0.25", features = ["derive"] } +wasm-bindgen = { version = "0.2.86" } +ron-parser = "0.1.4" diff --git a/crates/euclid_wasm/src/lib.rs b/crates/euclid_wasm/src/lib.rs new file mode 100644 index 000000000000..e85a002544ff --- /dev/null +++ b/crates/euclid_wasm/src/lib.rs @@ -0,0 +1,227 @@ +#![allow(non_upper_case_globals)] +mod types; +mod utils; +use std::{ + collections::{HashMap, HashSet}, + str::FromStr, +}; + +use api_models::{admin as admin_api, routing::ConnectorSelection}; +use euclid::{ + backend::{inputs, interpreter::InterpreterBackend, EuclidBackend}, + dssa::{ + self, analyzer, + graph::{self, Memoization}, + state_machine, truth, + }, + enums, + frontend::{ + ast, + dir::{self, enums as dir_enums}, + }, +}; +use once_cell::sync::OnceCell; +use strum::{EnumMessage, EnumProperty, VariantNames}; +use wasm_bindgen::prelude::*; + +use crate::utils::JsResultExt; +type JsResult = Result; + +struct SeedData<'a> { + kgraph: graph::KnowledgeGraph<'a>, + connectors: Vec, +} + +static SEED_DATA: OnceCell> = OnceCell::new(); + +/// This function can be used by the frontend to provide the WASM with information about +/// all the merchant's connector accounts. The input argument is a vector of all the merchant's +/// connector accounts from the API. +#[wasm_bindgen(js_name = seedKnowledgeGraph)] +pub fn seed_knowledge_graph(mcas: JsValue) -> JsResult { + let mcas: Vec = serde_wasm_bindgen::from_value(mcas)?; + let connectors: Vec = mcas + .iter() + .map(|mca| { + Ok::<_, strum::ParseError>(ast::ConnectorChoice { + connector: dir_enums::Connector::from_str(&mca.connector_name)?, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: mca.business_sub_label.clone(), + }) + }) + .collect::>() + .map_err(|_| "invalid connector name received") + .err_to_js()?; + + let mca_graph = kgraph_utils::mca::make_mca_graph(mcas).err_to_js()?; + let analysis_graph = + graph::KnowledgeGraph::combine(&mca_graph, &truth::ANALYSIS_GRAPH).err_to_js()?; + + SEED_DATA + .set(SeedData { + kgraph: analysis_graph, + connectors, + }) + .map_err(|_| "Knowledge Graph has been already seeded".to_string()) + .err_to_js()?; + + Ok(JsValue::NULL) +} + +/// This function allows the frontend to get all the merchant's configured +/// connectors that are valid for a rule based on the conditions specified in +/// the rule +#[wasm_bindgen(js_name = getValidConnectorsForRule)] +pub fn get_valid_connectors_for_rule(rule: JsValue) -> JsResult { + let seed_data = SEED_DATA.get().ok_or("Data not seeded").err_to_js()?; + + let rule: ast::Rule = serde_wasm_bindgen::from_value(rule)?; + let dir_rule = ast::lowering::lower_rule(rule).err_to_js()?; + let mut valid_connectors: Vec<(ast::ConnectorChoice, dir::DirValue)> = seed_data + .connectors + .iter() + .cloned() + .map(|choice| (choice.clone(), dir::DirValue::Connector(Box::new(choice)))) + .collect(); + let mut invalid_connectors: HashSet = HashSet::new(); + + let mut ctx_manager = state_machine::RuleContextManager::new(&dir_rule, &[]); + + let dummy_meta = HashMap::new(); + + // For every conjunctive context in the Rule, verify validity of all still-valid connectors + // using the knowledge graph + while let Some(ctx) = ctx_manager.advance_mut().err_to_js()? { + // Standalone conjunctive context analysis to ensure the context itself is valid before + // checking it against merchant's connectors + seed_data + .kgraph + .perform_context_analysis(ctx, &mut Memoization::new()) + .err_to_js()?; + + // Update conjunctive context and run analysis on all of merchant's connectors. + for (conn, choice) in &valid_connectors { + if invalid_connectors.contains(conn) { + continue; + } + + let ctx_val = dssa::types::ContextValue::assertion(choice, &dummy_meta); + ctx.push(ctx_val); + let analysis_result = seed_data + .kgraph + .perform_context_analysis(ctx, &mut Memoization::new()); + if analysis_result.is_err() { + invalid_connectors.insert(conn.clone()); + } + ctx.pop(); + } + } + + valid_connectors.retain(|(k, _)| !invalid_connectors.contains(k)); + + let valid_connectors: Vec = + valid_connectors.into_iter().map(|c| c.0).collect(); + + Ok(serde_wasm_bindgen::to_value(&valid_connectors)?) +} + +#[wasm_bindgen(js_name = analyzeProgram)] +pub fn analyze_program(js_program: JsValue) -> JsResult { + let program: ast::Program = serde_wasm_bindgen::from_value(js_program)?; + analyzer::analyze(program, SEED_DATA.get().map(|sd| &sd.kgraph)).err_to_js()?; + Ok(JsValue::NULL) +} + +#[wasm_bindgen(js_name = runProgram)] +pub fn run_program(program: JsValue, input: JsValue) -> JsResult { + let program: ast::Program = serde_wasm_bindgen::from_value(program)?; + let input: inputs::BackendInput = serde_wasm_bindgen::from_value(input)?; + + let backend = InterpreterBackend::with_program(program).err_to_js()?; + + let res: euclid::backend::BackendOutput = + backend.execute(input).err_to_js()?; + + Ok(serde_wasm_bindgen::to_value(&res)?) +} + +#[wasm_bindgen(js_name = getAllConnectors)] +pub fn get_all_connectors() -> JsResult { + Ok(serde_wasm_bindgen::to_value(enums::Connector::VARIANTS)?) +} + +#[wasm_bindgen(js_name = getAllKeys)] +pub fn get_all_keys() -> JsResult { + let keys: Vec<&'static str> = dir::DirKeyKind::VARIANTS + .iter() + .copied() + .filter(|s| s != &"Connector") + .collect(); + Ok(serde_wasm_bindgen::to_value(&keys)?) +} + +#[wasm_bindgen(js_name = getKeyType)] +pub fn get_key_type(key: &str) -> Result { + let key = dir::DirKeyKind::from_str(key).map_err(|_| "Invalid key received".to_string())?; + let key_str = key.get_type().to_string(); + Ok(key_str) +} + +#[wasm_bindgen(js_name=parseToString)] +pub fn parser(val: String) -> String { + ron_parser::my_parse(val) +} + +#[wasm_bindgen(js_name = getVariantValues)] +pub fn get_variant_values(key: &str) -> Result { + let key = dir::DirKeyKind::from_str(key).map_err(|_| "Invalid key received".to_string())?; + + let variants: &[&str] = match key { + dir::DirKeyKind::PaymentMethod => dir_enums::PaymentMethod::VARIANTS, + dir::DirKeyKind::CardType => dir_enums::CardType::VARIANTS, + dir::DirKeyKind::CardNetwork => dir_enums::CardNetwork::VARIANTS, + dir::DirKeyKind::PayLaterType => dir_enums::PayLaterType::VARIANTS, + dir::DirKeyKind::WalletType => dir_enums::WalletType::VARIANTS, + dir::DirKeyKind::BankRedirectType => dir_enums::BankRedirectType::VARIANTS, + dir::DirKeyKind::CryptoType => dir_enums::CryptoType::VARIANTS, + dir::DirKeyKind::RewardType => dir_enums::RewardType::VARIANTS, + dir::DirKeyKind::AuthenticationType => dir_enums::AuthenticationType::VARIANTS, + dir::DirKeyKind::CaptureMethod => dir_enums::CaptureMethod::VARIANTS, + dir::DirKeyKind::PaymentCurrency => dir_enums::PaymentCurrency::VARIANTS, + dir::DirKeyKind::BusinessCountry => dir_enums::Country::VARIANTS, + dir::DirKeyKind::BillingCountry => dir_enums::Country::VARIANTS, + dir::DirKeyKind::BankTransferType => dir_enums::BankTransferType::VARIANTS, + dir::DirKeyKind::UpiType => dir_enums::UpiType::VARIANTS, + dir::DirKeyKind::SetupFutureUsage => dir_enums::SetupFutureUsage::VARIANTS, + dir::DirKeyKind::PaymentType => dir_enums::PaymentType::VARIANTS, + dir::DirKeyKind::MandateType => dir_enums::MandateType::VARIANTS, + dir::DirKeyKind::MandateAcceptanceType => dir_enums::MandateAcceptanceType::VARIANTS, + dir::DirKeyKind::CardRedirectType => dir_enums::CardRedirectType::VARIANTS, + dir::DirKeyKind::GiftCardType => dir_enums::GiftCardType::VARIANTS, + dir::DirKeyKind::VoucherType => dir_enums::VoucherType::VARIANTS, + dir::DirKeyKind::PaymentAmount + | dir::DirKeyKind::Connector + | dir::DirKeyKind::CardBin + | dir::DirKeyKind::BusinessLabel + | dir::DirKeyKind::MetaData => Err("Key does not have variants".to_string())?, + dir::DirKeyKind::BankDebitType => dir_enums::BankDebitType::VARIANTS, + }; + + Ok(serde_wasm_bindgen::to_value(variants)?) +} + +#[wasm_bindgen(js_name = addTwo)] +pub fn add_two(n1: i64, n2: i64) -> i64 { + n1 + n2 +} + +#[wasm_bindgen(js_name = getDescriptionCategory)] +pub fn get_description_category(key: &str) -> JsResult { + let key = dir::DirKeyKind::from_str(key).map_err(|_| "Invalid key received".to_string())?; + + let result = types::Details { + description: key.get_detailed_message(), + category: key.get_str("Category"), + }; + Ok(serde_wasm_bindgen::to_value(&result)?) +} diff --git a/crates/euclid_wasm/src/types.rs b/crates/euclid_wasm/src/types.rs new file mode 100644 index 000000000000..ea40449971bc --- /dev/null +++ b/crates/euclid_wasm/src/types.rs @@ -0,0 +1,7 @@ +use serde::Serialize; + +#[derive(Serialize, Clone)] +pub struct Details<'a> { + pub description: Option<&'a str>, + pub category: Option<&'a str>, +} diff --git a/crates/euclid_wasm/src/utils.rs b/crates/euclid_wasm/src/utils.rs new file mode 100644 index 000000000000..c531dabd7e2a --- /dev/null +++ b/crates/euclid_wasm/src/utils.rs @@ -0,0 +1,17 @@ +use wasm_bindgen::prelude::*; + +pub trait JsResultExt { + fn err_to_js(self) -> Result; +} + +impl JsResultExt for Result +where + E: serde::Serialize, +{ + fn err_to_js(self) -> Result { + match self { + Ok(t) => Ok(t), + Err(e) => Err(serde_wasm_bindgen::to_value(&e)?), + } + } +} diff --git a/crates/kgraph_utils/Cargo.toml b/crates/kgraph_utils/Cargo.toml new file mode 100644 index 000000000000..fa90b3974c20 --- /dev/null +++ b/crates/kgraph_utils/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "kgraph_utils" +description = "Utilities for constructing and working with Knowledge Graphs" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true + +[features] +dummy_connector = ["api_models/dummy_connector", "euclid/dummy_connector"] +backwards_compatibility = ["euclid/backwards_compatibility", "euclid/backwards_compatibility"] +connector_choice_mca_id = ["api_models/connector_choice_mca_id", "euclid/connector_choice_mca_id"] + +[dependencies] +api_models = { version = "0.1.0", path = "../api_models", package = "api_models" } +euclid = { version = "0.1.0", path = "../euclid" } +masking = { version = "0.1.0", path = "../masking/"} + +serde = "1.0.163" +serde_json = "1.0.96" +thiserror = "1.0.43" + +[dev-dependencies] +criterion = "0.5" + +[[bench]] +name = "evaluation" +harness = false diff --git a/crates/kgraph_utils/benches/evaluation.rs b/crates/kgraph_utils/benches/evaluation.rs new file mode 100644 index 000000000000..ecea12203f8a --- /dev/null +++ b/crates/kgraph_utils/benches/evaluation.rs @@ -0,0 +1,113 @@ +#![allow(unused, clippy::expect_used)] + +use std::str::FromStr; + +use api_models::{ + admin as admin_api, enums as api_enums, payment_methods::RequestPaymentMethodTypes, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use euclid::{ + dirval, + dssa::graph::{self, Memoization}, + frontend::dir, + types::{NumValue, NumValueRefinement}, +}; +use kgraph_utils::{error::KgraphError, transformers::IntoDirValue}; + +fn build_test_data<'a>(total_enabled: usize, total_pm_types: usize) -> graph::KnowledgeGraph<'a> { + use api_models::{admin::*, payment_methods::*}; + + let mut pms_enabled: Vec = Vec::new(); + + for _ in (0..total_enabled) { + let mut pm_types: Vec = Vec::new(); + for _ in (0..total_pm_types) { + pm_types.push(RequestPaymentMethodTypes { + payment_method_type: api_enums::PaymentMethodType::Credit, + payment_experience: None, + card_networks: Some(vec![ + api_enums::CardNetwork::Visa, + api_enums::CardNetwork::Mastercard, + ]), + accepted_currencies: Some(AcceptedCurrencies::EnableOnly(vec![ + api_enums::Currency::USD, + api_enums::Currency::INR, + ])), + accepted_countries: None, + minimum_amount: Some(10), + maximum_amount: Some(1000), + recurring_enabled: true, + installment_payment_enabled: true, + }); + } + + pms_enabled.push(PaymentMethodsEnabled { + payment_method: api_enums::PaymentMethod::Card, + payment_method_types: Some(pm_types), + }); + } + + let stripe_account = MerchantConnectorResponse { + connector_type: api_enums::ConnectorType::FizOperations, + connector_name: "stripe".to_string(), + merchant_connector_id: "something".to_string(), + connector_account_details: masking::Secret::new(serde_json::json!({})), + test_mode: None, + disabled: None, + metadata: None, + payment_methods_enabled: Some(pms_enabled), + business_country: Some(api_enums::CountryAlpha2::US), + business_label: Some("hello".to_string()), + connector_label: Some("something".to_string()), + business_sub_label: Some("something".to_string()), + frm_configs: None, + connector_webhook_details: None, + profile_id: None, + applepay_verified_domains: None, + pm_auth_config: None, + }; + + kgraph_utils::mca::make_mca_graph(vec![stripe_account]).expect("Failed graph construction") +} + +fn evaluation(c: &mut Criterion) { + let small_graph = build_test_data(3, 8); + let big_graph = build_test_data(20, 20); + + c.bench_function("MCA Small Graph Evaluation", |b| { + b.iter(|| { + small_graph.key_value_analysis( + dirval!(Connector = Stripe), + &graph::AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentMethod = Card), + dirval!(CardType = Credit), + dirval!(CardNetwork = Visa), + dirval!(PaymentCurrency = BWP), + dirval!(PaymentAmount = 100), + ]), + &mut Memoization::new(), + ); + }); + }); + + c.bench_function("MCA Big Graph Evaluation", |b| { + b.iter(|| { + big_graph.key_value_analysis( + dirval!(Connector = Stripe), + &graph::AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentMethod = Card), + dirval!(CardType = Credit), + dirval!(CardNetwork = Visa), + dirval!(PaymentCurrency = BWP), + dirval!(PaymentAmount = 100), + ]), + &mut Memoization::new(), + ); + }); + }); +} + +criterion_group!(benches, evaluation); +criterion_main!(benches); diff --git a/crates/kgraph_utils/src/error.rs b/crates/kgraph_utils/src/error.rs new file mode 100644 index 000000000000..5a16c6375b06 --- /dev/null +++ b/crates/kgraph_utils/src/error.rs @@ -0,0 +1,14 @@ +use euclid::dssa::{graph::GraphError, types::AnalysisErrorType}; + +#[derive(Debug, thiserror::Error, serde::Serialize)] +#[serde(tag = "type", content = "info", rename_all = "snake_case")] +pub enum KgraphError { + #[error("Invalid connector name encountered: '{0}'")] + InvalidConnectorName(String), + #[error("There was an error constructing the graph: {0}")] + GraphConstructionError(GraphError), + #[error("There was an error constructing the context")] + ContextConstructionError(AnalysisErrorType), + #[error("there was an unprecedented indexing error")] + IndexingError, +} diff --git a/crates/kgraph_utils/src/lib.rs b/crates/kgraph_utils/src/lib.rs new file mode 100644 index 000000000000..eb8eef6dedb5 --- /dev/null +++ b/crates/kgraph_utils/src/lib.rs @@ -0,0 +1,3 @@ +pub mod error; +pub mod mca; +pub mod transformers; diff --git a/crates/kgraph_utils/src/mca.rs b/crates/kgraph_utils/src/mca.rs new file mode 100644 index 000000000000..34babd7a02bd --- /dev/null +++ b/crates/kgraph_utils/src/mca.rs @@ -0,0 +1,739 @@ +use std::str::FromStr; + +use api_models::{ + admin as admin_api, enums as api_enums, payment_methods::RequestPaymentMethodTypes, +}; +use euclid::{ + dssa::graph::{self, DomainIdentifier}, + frontend::{ + ast, + dir::{self, enums as dir_enums}, + }, + types::{NumValue, NumValueRefinement}, +}; + +use crate::{error::KgraphError, transformers::IntoDirValue}; + +pub const DOMAIN_IDENTIFIER: &str = "payment_methods_enabled_for_merchantconnectoraccount"; + +fn compile_request_pm_types( + builder: &mut graph::KnowledgeGraphBuilder<'_>, + pm_types: RequestPaymentMethodTypes, + pm: api_enums::PaymentMethod, +) -> Result { + let mut agg_nodes: Vec<(graph::NodeId, graph::Relation, graph::Strength)> = Vec::new(); + + let pmt_info = "PaymentMethodType"; + let pmt_id = builder + .make_value_node( + (pm_types.payment_method_type, pm) + .into_dir_value() + .map(Into::into)?, + Some(pmt_info), + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None::<()>, + ) + .map_err(KgraphError::GraphConstructionError)?; + agg_nodes.push(( + pmt_id, + graph::Relation::Positive, + match pm_types.payment_method_type { + api_enums::PaymentMethodType::Credit | api_enums::PaymentMethodType::Debit => { + graph::Strength::Weak + } + + _ => graph::Strength::Strong, + }, + )); + + if let Some(card_networks) = pm_types.card_networks { + if !card_networks.is_empty() { + let dir_vals: Vec = card_networks + .into_iter() + .map(IntoDirValue::into_dir_value) + .collect::>()?; + + let card_network_info = "Card Networks"; + let card_network_id = builder + .make_in_aggregator(dir_vals, Some(card_network_info), None::<()>, Vec::new()) + .map_err(KgraphError::GraphConstructionError)?; + + agg_nodes.push(( + card_network_id, + graph::Relation::Positive, + graph::Strength::Weak, + )); + } + } + + let currencies_data = pm_types + .accepted_currencies + .and_then(|accepted_currencies| match accepted_currencies { + admin_api::AcceptedCurrencies::EnableOnly(curr) if !curr.is_empty() => Some(( + curr.into_iter() + .map(IntoDirValue::into_dir_value) + .collect::>() + .ok()?, + graph::Relation::Positive, + )), + + admin_api::AcceptedCurrencies::DisableOnly(curr) if !curr.is_empty() => Some(( + curr.into_iter() + .map(IntoDirValue::into_dir_value) + .collect::>() + .ok()?, + graph::Relation::Negative, + )), + + _ => None, + }); + + if let Some((currencies, relation)) = currencies_data { + let accepted_currencies_info = "Accepted Currencies"; + let accepted_currencies_id = builder + .make_in_aggregator( + currencies, + Some(accepted_currencies_info), + None::<()>, + Vec::new(), + ) + .map_err(KgraphError::GraphConstructionError)?; + + agg_nodes.push((accepted_currencies_id, relation, graph::Strength::Strong)); + } + + let mut amount_nodes = Vec::with_capacity(2); + + if let Some(min_amt) = pm_types.minimum_amount { + let num_val = NumValue { + number: min_amt.into(), + refinement: Some(NumValueRefinement::GreaterThanEqual), + }; + + let min_amt_info = "Minimum Amount"; + let min_amt_id = builder + .make_value_node( + dir::DirValue::PaymentAmount(num_val).into(), + Some(min_amt_info), + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None::<()>, + ) + .map_err(KgraphError::GraphConstructionError)?; + + amount_nodes.push(min_amt_id); + } + + if let Some(max_amt) = pm_types.maximum_amount { + let num_val = NumValue { + number: max_amt.into(), + refinement: Some(NumValueRefinement::LessThanEqual), + }; + + let max_amt_info = "Maximum Amount"; + let max_amt_id = builder + .make_value_node( + dir::DirValue::PaymentAmount(num_val).into(), + Some(max_amt_info), + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None::<()>, + ) + .map_err(KgraphError::GraphConstructionError)?; + + amount_nodes.push(max_amt_id); + } + + if !amount_nodes.is_empty() { + let zero_num_val = NumValue { + number: 0, + refinement: None, + }; + + let zero_amt_id = builder + .make_value_node( + dir::DirValue::PaymentAmount(zero_num_val).into(), + Some("zero_amount"), + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None::<()>, + ) + .map_err(KgraphError::GraphConstructionError)?; + + let or_node_neighbor_id = if amount_nodes.len() == 1 { + amount_nodes + .get(0) + .copied() + .ok_or(KgraphError::IndexingError)? + } else { + let nodes = amount_nodes + .iter() + .copied() + .map(|node_id| (node_id, graph::Relation::Positive, graph::Strength::Strong)) + .collect::>(); + + builder + .make_all_aggregator( + &nodes, + Some("amount_constraint_aggregator"), + None::<()>, + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + ) + .map_err(KgraphError::GraphConstructionError)? + }; + + let any_aggregator = builder + .make_any_aggregator( + &[ + (zero_amt_id, graph::Relation::Positive), + (or_node_neighbor_id, graph::Relation::Positive), + ], + Some("zero_plus_limits_amount_aggregator"), + None::<()>, + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + ) + .map_err(KgraphError::GraphConstructionError)?; + + agg_nodes.push(( + any_aggregator, + graph::Relation::Positive, + graph::Strength::Strong, + )); + } + + let pmt_all_aggregator_info = "All Aggregator for PaymentMethodType"; + builder + .make_all_aggregator( + &agg_nodes, + Some(pmt_all_aggregator_info), + None::<()>, + Vec::new(), + ) + .map_err(KgraphError::GraphConstructionError) +} + +fn compile_payment_method_enabled( + builder: &mut graph::KnowledgeGraphBuilder<'_>, + enabled: admin_api::PaymentMethodsEnabled, +) -> Result, KgraphError> { + let agg_id = if !enabled + .payment_method_types + .as_ref() + .map(|v| v.is_empty()) + .unwrap_or(true) + { + let pm_info = "PaymentMethod"; + let pm_id = builder + .make_value_node( + enabled.payment_method.into_dir_value().map(Into::into)?, + Some(pm_info), + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None::<()>, + ) + .map_err(KgraphError::GraphConstructionError)?; + + let mut agg_nodes: Vec<(graph::NodeId, graph::Relation)> = Vec::new(); + + if let Some(pm_types) = enabled.payment_method_types { + for pm_type in pm_types { + let node_id = compile_request_pm_types(builder, pm_type, enabled.payment_method)?; + agg_nodes.push((node_id, graph::Relation::Positive)); + } + } + + let any_aggregator_info = "Any aggregation for PaymentMethodsType"; + let pm_type_agg_id = builder + .make_any_aggregator( + &agg_nodes, + Some(any_aggregator_info), + None::<()>, + Vec::new(), + ) + .map_err(KgraphError::GraphConstructionError)?; + + let all_aggregator_info = "All aggregation for PaymentMethod"; + let enabled_pm_agg_id = builder + .make_all_aggregator( + &[ + (pm_id, graph::Relation::Positive, graph::Strength::Strong), + ( + pm_type_agg_id, + graph::Relation::Positive, + graph::Strength::Strong, + ), + ], + Some(all_aggregator_info), + None::<()>, + Vec::new(), + ) + .map_err(KgraphError::GraphConstructionError)?; + + Some(enabled_pm_agg_id) + } else { + None + }; + + Ok(agg_id) +} + +fn compile_merchant_connector_graph( + builder: &mut graph::KnowledgeGraphBuilder<'_>, + mca: admin_api::MerchantConnectorResponse, +) -> Result<(), KgraphError> { + let connector = dir_enums::Connector::from_str(&mca.connector_name) + .map_err(|_| KgraphError::InvalidConnectorName(mca.connector_name.clone()))?; + + let mut agg_nodes: Vec<(graph::NodeId, graph::Relation)> = Vec::new(); + + if let Some(pms_enabled) = mca.payment_methods_enabled { + for pm_enabled in pms_enabled { + let maybe_pm_enabled_id = compile_payment_method_enabled(builder, pm_enabled)?; + if let Some(pm_enabled_id) = maybe_pm_enabled_id { + agg_nodes.push((pm_enabled_id, graph::Relation::Positive)); + } + } + } + + let aggregator_info = "Available Payment methods for connector"; + let pms_enabled_agg_id = builder + .make_any_aggregator(&agg_nodes, Some(aggregator_info), None::<()>, Vec::new()) + .map_err(KgraphError::GraphConstructionError)?; + + let connector_dir_val = dir::DirValue::Connector(Box::new(ast::ConnectorChoice { + connector, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: mca.business_sub_label, + })); + + let connector_info = "Connector"; + let connector_node_id = builder + .make_value_node( + connector_dir_val.into(), + Some(connector_info), + vec![DomainIdentifier::new(DOMAIN_IDENTIFIER)], + None::<()>, + ) + .map_err(KgraphError::GraphConstructionError)?; + + builder + .make_edge( + pms_enabled_agg_id, + connector_node_id, + graph::Strength::Normal, + graph::Relation::Positive, + ) + .map_err(KgraphError::GraphConstructionError)?; + + Ok(()) +} + +pub fn make_mca_graph<'a>( + accts: Vec, +) -> Result, KgraphError> { + let mut builder = graph::KnowledgeGraphBuilder::new(); + let _domain = builder.make_domain( + DomainIdentifier::new(DOMAIN_IDENTIFIER), + "Payment methods enabled for MerchantConnectorAccount".to_string(), + ); + for acct in accts { + compile_merchant_connector_graph(&mut builder, acct)?; + } + + Ok(builder.build()) +} + +#[cfg(test)] +mod tests { + #![allow(clippy::expect_used)] + + use api_models::enums as api_enums; + use euclid::{ + dirval, + dssa::graph::{AnalysisContext, Memoization}, + }; + + use super::*; + + fn build_test_data<'a>() -> graph::KnowledgeGraph<'a> { + use api_models::{admin::*, payment_methods::*}; + + let stripe_account = MerchantConnectorResponse { + connector_type: api_enums::ConnectorType::FizOperations, + connector_name: "stripe".to_string(), + merchant_connector_id: "something".to_string(), + business_country: Some(api_enums::CountryAlpha2::US), + connector_label: Some("something".to_string()), + business_label: Some("food".to_string()), + business_sub_label: None, + connector_account_details: masking::Secret::new(serde_json::json!({})), + test_mode: None, + disabled: None, + metadata: None, + payment_methods_enabled: Some(vec![PaymentMethodsEnabled { + payment_method: api_enums::PaymentMethod::Card, + payment_method_types: Some(vec![ + RequestPaymentMethodTypes { + payment_method_type: api_enums::PaymentMethodType::Credit, + payment_experience: None, + card_networks: Some(vec![ + api_enums::CardNetwork::Visa, + api_enums::CardNetwork::Mastercard, + ]), + accepted_currencies: Some(AcceptedCurrencies::EnableOnly(vec![ + api_enums::Currency::USD, + api_enums::Currency::INR, + ])), + accepted_countries: None, + minimum_amount: Some(10), + maximum_amount: Some(1000), + recurring_enabled: true, + installment_payment_enabled: true, + }, + RequestPaymentMethodTypes { + payment_method_type: api_enums::PaymentMethodType::Debit, + payment_experience: None, + card_networks: Some(vec![ + api_enums::CardNetwork::Maestro, + api_enums::CardNetwork::JCB, + ]), + accepted_currencies: Some(AcceptedCurrencies::EnableOnly(vec![ + api_enums::Currency::GBP, + api_enums::Currency::PHP, + ])), + accepted_countries: None, + minimum_amount: Some(10), + maximum_amount: Some(1000), + recurring_enabled: true, + installment_payment_enabled: true, + }, + ]), + }]), + frm_configs: None, + connector_webhook_details: None, + profile_id: None, + applepay_verified_domains: None, + pm_auth_config: None, + }; + + make_mca_graph(vec![stripe_account]).expect("Failed graph construction") + } + + #[test] + fn test_credit_card_success_case() { + let graph = build_test_data(); + + let result = graph.key_value_analysis( + dirval!(Connector = Stripe), + &AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentMethod = Card), + dirval!(CardType = Credit), + dirval!(CardNetwork = Visa), + dirval!(PaymentCurrency = USD), + dirval!(PaymentAmount = 100), + ]), + &mut Memoization::new(), + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_debit_card_success_case() { + let graph = build_test_data(); + + let result = graph.key_value_analysis( + dirval!(Connector = Stripe), + &AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentMethod = Card), + dirval!(CardType = Debit), + dirval!(CardNetwork = Maestro), + dirval!(PaymentCurrency = GBP), + dirval!(PaymentAmount = 100), + ]), + &mut Memoization::new(), + ); + + assert!(result.is_ok()); + } + + #[test] + fn test_single_mismatch_failure_case() { + let graph = build_test_data(); + + let result = graph.key_value_analysis( + dirval!(Connector = Stripe), + &AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentMethod = Card), + dirval!(CardType = Debit), + dirval!(CardNetwork = DinersClub), + dirval!(PaymentCurrency = GBP), + dirval!(PaymentAmount = 100), + ]), + &mut Memoization::new(), + ); + + assert!(result.is_err()); + } + + #[test] + fn test_amount_mismatch_failure_case() { + let graph = build_test_data(); + + let result = graph.key_value_analysis( + dirval!(Connector = Stripe), + &AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentMethod = Card), + dirval!(CardType = Debit), + dirval!(CardNetwork = Visa), + dirval!(PaymentCurrency = GBP), + dirval!(PaymentAmount = 7), + ]), + &mut Memoization::new(), + ); + + assert!(result.is_err()); + } + + #[test] + fn test_incomplete_data_failure_case() { + let graph = build_test_data(); + + let result = graph.key_value_analysis( + dirval!(Connector = Stripe), + &AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentMethod = Card), + dirval!(CardType = Debit), + dirval!(PaymentCurrency = GBP), + dirval!(PaymentAmount = 7), + ]), + &mut Memoization::new(), + ); + + //println!("{:#?}", result); + //println!("{}", serde_json::to_string_pretty(&result).expect("Hello")); + + assert!(result.is_err()); + } + + #[test] + fn test_incomplete_data_failure_case2() { + let graph = build_test_data(); + + let result = graph.key_value_analysis( + dirval!(Connector = Stripe), + &AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(CardType = Debit), + dirval!(CardNetwork = Visa), + dirval!(PaymentCurrency = GBP), + dirval!(PaymentAmount = 100), + ]), + &mut Memoization::new(), + ); + + //println!("{:#?}", result); + //println!("{}", serde_json::to_string_pretty(&result).expect("Hello")); + + assert!(result.is_err()); + } + + #[test] + fn test_sandbox_applepay_bug_usecase() { + let value = serde_json::json!([ + { + "connector_type": "payment_processor", + "connector_name": "bluesnap", + "merchant_connector_id": "REDACTED", + "connector_account_details": { + "auth_type": "BodyKey", + "api_key": "REDACTED", + "key1": "REDACTED" + }, + "test_mode": true, + "disabled": false, + "payment_methods_enabled": [ + { + "payment_method": "card", + "payment_method_types": [ + { + "payment_method_type": "credit", + "payment_experience": null, + "card_networks": [ + "Mastercard", + "Visa", + "AmericanExpress", + "JCB", + "DinersClub", + "Discover", + "CartesBancaires", + "UnionPay" + ], + "accepted_currencies": null, + "accepted_countries": null, + "minimum_amount": 1, + "maximum_amount": 68607706, + "recurring_enabled": true, + "installment_payment_enabled": true + }, + { + "payment_method_type": "debit", + "payment_experience": null, + "card_networks": [ + "Mastercard", + "Visa", + "Interac", + "AmericanExpress", + "JCB", + "DinersClub", + "Discover", + "CartesBancaires", + "UnionPay" + ], + "accepted_currencies": null, + "accepted_countries": null, + "minimum_amount": 1, + "maximum_amount": 68607706, + "recurring_enabled": true, + "installment_payment_enabled": true + } + ] + }, + { + "payment_method": "wallet", + "payment_method_types": [ + { + "payment_method_type": "google_pay", + "payment_experience": "invoke_sdk_client", + "card_networks": null, + "accepted_currencies": null, + "accepted_countries": null, + "minimum_amount": 1, + "maximum_amount": 68607706, + "recurring_enabled": true, + "installment_payment_enabled": true + } + ] + } + ], + "metadata": {}, + "business_country": "US", + "business_label": "default", + "business_sub_label": null, + "frm_configs": null + }, + { + "connector_type": "payment_processor", + "connector_name": "stripe", + "merchant_connector_id": "REDACTED", + "connector_account_details": { + "auth_type": "HeaderKey", + "api_key": "REDACTED" + }, + "test_mode": true, + "disabled": false, + "payment_methods_enabled": [ + { + "payment_method": "card", + "payment_method_types": [ + { + "payment_method_type": "credit", + "payment_experience": null, + "card_networks": [ + "Mastercard", + "Visa", + "AmericanExpress", + "JCB", + "DinersClub", + "Discover", + "CartesBancaires", + "UnionPay" + ], + "accepted_currencies": null, + "accepted_countries": null, + "minimum_amount": 1, + "maximum_amount": 68607706, + "recurring_enabled": true, + "installment_payment_enabled": true + }, + { + "payment_method_type": "debit", + "payment_experience": null, + "card_networks": [ + "Mastercard", + "Visa", + "Interac", + "AmericanExpress", + "JCB", + "DinersClub", + "Discover", + "CartesBancaires", + "UnionPay" + ], + "accepted_currencies": null, + "accepted_countries": null, + "minimum_amount": 1, + "maximum_amount": 68607706, + "recurring_enabled": true, + "installment_payment_enabled": true + } + ] + }, + { + "payment_method": "wallet", + "payment_method_types": [ + { + "payment_method_type": "apple_pay", + "payment_experience": "invoke_sdk_client", + "card_networks": null, + "accepted_currencies": null, + "accepted_countries": null, + "minimum_amount": 1, + "maximum_amount": 68607706, + "recurring_enabled": true, + "installment_payment_enabled": true + } + ] + }, + { + "payment_method": "pay_later", + "payment_method_types": [] + } + ], + "metadata": {}, + "business_country": "US", + "business_label": "default", + "business_sub_label": null, + "frm_configs": null + } + ]); + + let data: Vec = + serde_json::from_value(value).expect("data"); + + let graph = make_mca_graph(data).expect("graph"); + let context = AnalysisContext::from_dir_values([ + dirval!(Connector = Stripe), + dirval!(PaymentAmount = 212), + dirval!(PaymentCurrency = ILS), + dirval!(PaymentMethod = Wallet), + dirval!(WalletType = ApplePay), + ]); + + let result = graph.key_value_analysis( + dirval!(Connector = Stripe), + &context, + &mut Memoization::new(), + ); + + assert!(result.is_ok(), "stripe validation failed"); + + let result = graph.key_value_analysis( + dirval!(Connector = Bluesnap), + &context, + &mut Memoization::new(), + ); + assert!(result.is_err(), "bluesnap validation failed"); + } +} diff --git a/crates/kgraph_utils/src/transformers.rs b/crates/kgraph_utils/src/transformers.rs new file mode 100644 index 000000000000..3d32cce38bd8 --- /dev/null +++ b/crates/kgraph_utils/src/transformers.rs @@ -0,0 +1,724 @@ +use api_models::enums as api_enums; +use euclid::{ + backend::BackendInput, + dirval, + dssa::types::AnalysisErrorType, + frontend::{ast, dir}, + types::{NumValue, StrValue}, +}; + +use crate::error::KgraphError; + +pub trait IntoContext { + fn into_context(self) -> Result, KgraphError>; +} + +impl IntoContext for BackendInput { + fn into_context(self) -> Result, KgraphError> { + let mut ctx: Vec = Vec::new(); + + ctx.push(dir::DirValue::PaymentAmount(NumValue { + number: self.payment.amount, + refinement: None, + })); + + ctx.push(dir::DirValue::PaymentCurrency(self.payment.currency)); + + if let Some(auth_type) = self.payment.authentication_type { + ctx.push(dir::DirValue::AuthenticationType(auth_type)); + } + + if let Some(capture_method) = self.payment.capture_method { + ctx.push(dir::DirValue::CaptureMethod(capture_method)); + } + + if let Some(business_country) = self.payment.business_country { + ctx.push(dir::DirValue::BusinessCountry(business_country)); + } + if let Some(business_label) = self.payment.business_label { + ctx.push(dir::DirValue::BusinessLabel(StrValue { + value: business_label, + })); + } + if let Some(billing_country) = self.payment.billing_country { + ctx.push(dir::DirValue::BillingCountry(billing_country)); + } + + if let Some(payment_method) = self.payment_method.payment_method { + ctx.push(dir::DirValue::PaymentMethod(payment_method)); + } + + if let (Some(pm_type), Some(payment_method)) = ( + self.payment_method.payment_method_type, + self.payment_method.payment_method, + ) { + ctx.push((pm_type, payment_method).into_dir_value()?) + } + + if let Some(card_network) = self.payment_method.card_network { + ctx.push(dir::DirValue::CardNetwork(card_network)); + } + if let Some(setup_future_usage) = self.payment.setup_future_usage { + ctx.push(dir::DirValue::SetupFutureUsage(setup_future_usage)); + } + if let Some(mandate_acceptance_type) = self.mandate.mandate_acceptance_type { + ctx.push(dir::DirValue::MandateAcceptanceType( + mandate_acceptance_type, + )); + } + if let Some(mandate_type) = self.mandate.mandate_type { + ctx.push(dir::DirValue::MandateType(mandate_type)); + } + if let Some(payment_type) = self.mandate.payment_type { + ctx.push(dir::DirValue::PaymentType(payment_type)); + } + + Ok(ctx) + } +} + +pub trait IntoDirValue { + fn into_dir_value(self) -> Result; +} + +impl IntoDirValue for ast::ConnectorChoice { + fn into_dir_value(self) -> Result { + Ok(dir::DirValue::Connector(Box::new(self))) + } +} + +impl IntoDirValue for api_enums::PaymentMethod { + fn into_dir_value(self) -> Result { + match self { + Self::Card => Ok(dirval!(PaymentMethod = Card)), + Self::Wallet => Ok(dirval!(PaymentMethod = Wallet)), + Self::PayLater => Ok(dirval!(PaymentMethod = PayLater)), + Self::BankRedirect => Ok(dirval!(PaymentMethod = BankRedirect)), + Self::Crypto => Ok(dirval!(PaymentMethod = Crypto)), + Self::BankDebit => Ok(dirval!(PaymentMethod = BankDebit)), + Self::BankTransfer => Ok(dirval!(PaymentMethod = BankTransfer)), + Self::Reward => Ok(dirval!(PaymentMethod = Reward)), + Self::Upi => Ok(dirval!(PaymentMethod = Upi)), + Self::Voucher => Ok(dirval!(PaymentMethod = Voucher)), + Self::GiftCard => Ok(dirval!(PaymentMethod = GiftCard)), + Self::CardRedirect => Ok(dirval!(PaymentMethod = CardRedirect)), + } + } +} + +impl IntoDirValue for api_enums::AuthenticationType { + fn into_dir_value(self) -> Result { + match self { + Self::ThreeDs => Ok(dirval!(AuthenticationType = ThreeDs)), + Self::NoThreeDs => Ok(dirval!(AuthenticationType = NoThreeDs)), + } + } +} + +impl IntoDirValue for api_enums::FutureUsage { + fn into_dir_value(self) -> Result { + match self { + Self::OnSession => Ok(dirval!(SetupFutureUsage = OnSession)), + Self::OffSession => Ok(dirval!(SetupFutureUsage = OffSession)), + } + } +} + +impl IntoDirValue for (api_enums::PaymentMethodType, api_enums::PaymentMethod) { + fn into_dir_value(self) -> Result { + match self.0 { + api_enums::PaymentMethodType::Credit => Ok(dirval!(CardType = Credit)), + api_enums::PaymentMethodType::Debit => Ok(dirval!(CardType = Debit)), + api_enums::PaymentMethodType::Giropay => Ok(dirval!(BankRedirectType = Giropay)), + api_enums::PaymentMethodType::Ideal => Ok(dirval!(BankRedirectType = Ideal)), + api_enums::PaymentMethodType::Sofort => Ok(dirval!(BankRedirectType = Sofort)), + api_enums::PaymentMethodType::Eps => Ok(dirval!(BankRedirectType = Eps)), + api_enums::PaymentMethodType::Klarna => Ok(dirval!(PayLaterType = Klarna)), + api_enums::PaymentMethodType::Affirm => Ok(dirval!(PayLaterType = Affirm)), + api_enums::PaymentMethodType::AfterpayClearpay => { + Ok(dirval!(PayLaterType = AfterpayClearpay)) + } + api_enums::PaymentMethodType::GooglePay => Ok(dirval!(WalletType = GooglePay)), + api_enums::PaymentMethodType::ApplePay => Ok(dirval!(WalletType = ApplePay)), + api_enums::PaymentMethodType::Paypal => Ok(dirval!(WalletType = Paypal)), + api_enums::PaymentMethodType::CryptoCurrency => { + Ok(dirval!(CryptoType = CryptoCurrency)) + } + api_enums::PaymentMethodType::Ach => match self.1 { + api_enums::PaymentMethod::BankDebit => Ok(dirval!(BankDebitType = Ach)), + api_enums::PaymentMethod::BankTransfer => Ok(dirval!(BankTransferType = Ach)), + api_enums::PaymentMethod::BankRedirect + | api_enums::PaymentMethod::Card + | api_enums::PaymentMethod::CardRedirect + | api_enums::PaymentMethod::PayLater + | api_enums::PaymentMethod::Wallet + | api_enums::PaymentMethod::Crypto + | api_enums::PaymentMethod::Reward + | api_enums::PaymentMethod::Upi + | api_enums::PaymentMethod::Voucher + | api_enums::PaymentMethod::GiftCard => Err(KgraphError::ContextConstructionError( + AnalysisErrorType::NotSupported, + )), + }, + api_enums::PaymentMethodType::Bacs => match self.1 { + api_enums::PaymentMethod::BankDebit => Ok(dirval!(BankDebitType = Bacs)), + api_enums::PaymentMethod::BankTransfer => Ok(dirval!(BankTransferType = Bacs)), + api_enums::PaymentMethod::BankRedirect + | api_enums::PaymentMethod::Card + | api_enums::PaymentMethod::CardRedirect + | api_enums::PaymentMethod::PayLater + | api_enums::PaymentMethod::Wallet + | api_enums::PaymentMethod::Crypto + | api_enums::PaymentMethod::Reward + | api_enums::PaymentMethod::Upi + | api_enums::PaymentMethod::Voucher + | api_enums::PaymentMethod::GiftCard => Err(KgraphError::ContextConstructionError( + AnalysisErrorType::NotSupported, + )), + }, + api_enums::PaymentMethodType::Becs => Ok(dirval!(BankDebitType = Becs)), + api_enums::PaymentMethodType::Sepa => match self.1 { + api_enums::PaymentMethod::BankDebit => Ok(dirval!(BankDebitType = Sepa)), + api_enums::PaymentMethod::BankTransfer => Ok(dirval!(BankTransferType = Sepa)), + api_enums::PaymentMethod::BankRedirect + | api_enums::PaymentMethod::Card + | api_enums::PaymentMethod::CardRedirect + | api_enums::PaymentMethod::PayLater + | api_enums::PaymentMethod::Wallet + | api_enums::PaymentMethod::Crypto + | api_enums::PaymentMethod::Reward + | api_enums::PaymentMethod::Upi + | api_enums::PaymentMethod::Voucher + | api_enums::PaymentMethod::GiftCard => Err(KgraphError::ContextConstructionError( + AnalysisErrorType::NotSupported, + )), + }, + api_enums::PaymentMethodType::AliPay => Ok(dirval!(WalletType = AliPay)), + api_enums::PaymentMethodType::AliPayHk => Ok(dirval!(WalletType = AliPayHk)), + api_enums::PaymentMethodType::BancontactCard => { + Ok(dirval!(BankRedirectType = BancontactCard)) + } + api_enums::PaymentMethodType::Blik => Ok(dirval!(BankRedirectType = Blik)), + api_enums::PaymentMethodType::MbWay => Ok(dirval!(WalletType = MbWay)), + api_enums::PaymentMethodType::MobilePay => Ok(dirval!(WalletType = MobilePay)), + api_enums::PaymentMethodType::Cashapp => Ok(dirval!(WalletType = Cashapp)), + api_enums::PaymentMethodType::Multibanco => Ok(dirval!(BankTransferType = Multibanco)), + api_enums::PaymentMethodType::Pix => Ok(dirval!(BankTransferType = Pix)), + api_enums::PaymentMethodType::Pse => Ok(dirval!(BankTransferType = Pse)), + api_enums::PaymentMethodType::Interac => Ok(dirval!(BankRedirectType = Interac)), + api_enums::PaymentMethodType::OnlineBankingCzechRepublic => { + Ok(dirval!(BankRedirectType = OnlineBankingCzechRepublic)) + } + api_enums::PaymentMethodType::OnlineBankingFinland => { + Ok(dirval!(BankRedirectType = OnlineBankingFinland)) + } + api_enums::PaymentMethodType::OnlineBankingPoland => { + Ok(dirval!(BankRedirectType = OnlineBankingPoland)) + } + api_enums::PaymentMethodType::OnlineBankingSlovakia => { + Ok(dirval!(BankRedirectType = OnlineBankingSlovakia)) + } + api_enums::PaymentMethodType::Swish => Ok(dirval!(WalletType = Swish)), + api_enums::PaymentMethodType::Trustly => Ok(dirval!(BankRedirectType = Trustly)), + api_enums::PaymentMethodType::Bizum => Ok(dirval!(BankRedirectType = Bizum)), + + api_enums::PaymentMethodType::PayBright => Ok(dirval!(PayLaterType = PayBright)), + api_enums::PaymentMethodType::Walley => Ok(dirval!(PayLaterType = Walley)), + api_enums::PaymentMethodType::Przelewy24 => Ok(dirval!(BankRedirectType = Przelewy24)), + api_enums::PaymentMethodType::WeChatPay => Ok(dirval!(WalletType = WeChatPay)), + + api_enums::PaymentMethodType::ClassicReward => Ok(dirval!(RewardType = ClassicReward)), + api_enums::PaymentMethodType::Evoucher => Ok(dirval!(RewardType = Evoucher)), + api_enums::PaymentMethodType::UpiCollect => Ok(dirval!(UpiType = UpiCollect)), + api_enums::PaymentMethodType::SamsungPay => Ok(dirval!(WalletType = SamsungPay)), + api_enums::PaymentMethodType::GoPay => Ok(dirval!(WalletType = GoPay)), + api_enums::PaymentMethodType::KakaoPay => Ok(dirval!(WalletType = KakaoPay)), + api_enums::PaymentMethodType::Twint => Ok(dirval!(WalletType = Twint)), + api_enums::PaymentMethodType::Gcash => Ok(dirval!(WalletType = Gcash)), + api_enums::PaymentMethodType::Vipps => Ok(dirval!(WalletType = Vipps)), + api_enums::PaymentMethodType::Momo => Ok(dirval!(WalletType = Momo)), + api_enums::PaymentMethodType::Alma => Ok(dirval!(PayLaterType = Alma)), + api_enums::PaymentMethodType::Dana => Ok(dirval!(WalletType = Dana)), + api_enums::PaymentMethodType::OnlineBankingFpx => { + Ok(dirval!(BankRedirectType = OnlineBankingFpx)) + } + api_enums::PaymentMethodType::OnlineBankingThailand => { + Ok(dirval!(BankRedirectType = OnlineBankingThailand)) + } + api_enums::PaymentMethodType::TouchNGo => Ok(dirval!(WalletType = TouchNGo)), + api_enums::PaymentMethodType::Atome => Ok(dirval!(PayLaterType = Atome)), + api_enums::PaymentMethodType::Boleto => Ok(dirval!(VoucherType = Boleto)), + api_enums::PaymentMethodType::Efecty => Ok(dirval!(VoucherType = Efecty)), + api_enums::PaymentMethodType::PagoEfectivo => Ok(dirval!(VoucherType = PagoEfectivo)), + api_enums::PaymentMethodType::RedCompra => Ok(dirval!(VoucherType = RedCompra)), + api_enums::PaymentMethodType::RedPagos => Ok(dirval!(VoucherType = RedPagos)), + api_enums::PaymentMethodType::Alfamart => Ok(dirval!(VoucherType = Alfamart)), + api_enums::PaymentMethodType::BcaBankTransfer => { + Ok(dirval!(BankTransferType = BcaBankTransfer)) + } + api_enums::PaymentMethodType::BniVa => Ok(dirval!(BankTransferType = BniVa)), + api_enums::PaymentMethodType::BriVa => Ok(dirval!(BankTransferType = BriVa)), + api_enums::PaymentMethodType::CimbVa => Ok(dirval!(BankTransferType = CimbVa)), + api_enums::PaymentMethodType::DanamonVa => Ok(dirval!(BankTransferType = DanamonVa)), + api_enums::PaymentMethodType::Indomaret => Ok(dirval!(VoucherType = Indomaret)), + api_enums::PaymentMethodType::MandiriVa => Ok(dirval!(BankTransferType = MandiriVa)), + api_enums::PaymentMethodType::PermataBankTransfer => { + Ok(dirval!(BankTransferType = PermataBankTransfer)) + } + api_enums::PaymentMethodType::PaySafeCard => Ok(dirval!(GiftCardType = PaySafeCard)), + api_enums::PaymentMethodType::SevenEleven => Ok(dirval!(VoucherType = SevenEleven)), + api_enums::PaymentMethodType::Lawson => Ok(dirval!(VoucherType = Lawson)), + api_enums::PaymentMethodType::MiniStop => Ok(dirval!(VoucherType = MiniStop)), + api_enums::PaymentMethodType::FamilyMart => Ok(dirval!(VoucherType = FamilyMart)), + api_enums::PaymentMethodType::Seicomart => Ok(dirval!(VoucherType = Seicomart)), + api_enums::PaymentMethodType::PayEasy => Ok(dirval!(VoucherType = PayEasy)), + api_enums::PaymentMethodType::Givex => Ok(dirval!(GiftCardType = Givex)), + api_enums::PaymentMethodType::Benefit => Ok(dirval!(CardRedirectType = Benefit)), + api_enums::PaymentMethodType::Knet => Ok(dirval!(CardRedirectType = Knet)), + api_enums::PaymentMethodType::OpenBankingUk => { + Ok(dirval!(BankRedirectType = OpenBankingUk)) + } + api_enums::PaymentMethodType::MomoAtm => Ok(dirval!(CardRedirectType = MomoAtm)), + api_enums::PaymentMethodType::Oxxo => Ok(dirval!(VoucherType = Oxxo)), + } + } +} + +impl IntoDirValue for api_enums::CardNetwork { + fn into_dir_value(self) -> Result { + match self { + Self::Visa => Ok(dirval!(CardNetwork = Visa)), + Self::Mastercard => Ok(dirval!(CardNetwork = Mastercard)), + Self::AmericanExpress => Ok(dirval!(CardNetwork = AmericanExpress)), + Self::JCB => Ok(dirval!(CardNetwork = JCB)), + Self::DinersClub => Ok(dirval!(CardNetwork = DinersClub)), + Self::Discover => Ok(dirval!(CardNetwork = Discover)), + Self::CartesBancaires => Ok(dirval!(CardNetwork = CartesBancaires)), + Self::UnionPay => Ok(dirval!(CardNetwork = UnionPay)), + Self::Interac => Ok(dirval!(CardNetwork = Interac)), + Self::RuPay => Ok(dirval!(CardNetwork = RuPay)), + Self::Maestro => Ok(dirval!(CardNetwork = Maestro)), + } + } +} + +impl IntoDirValue for api_enums::Currency { + fn into_dir_value(self) -> Result { + match self { + Self::AED => Ok(dirval!(PaymentCurrency = AED)), + Self::ALL => Ok(dirval!(PaymentCurrency = ALL)), + Self::AMD => Ok(dirval!(PaymentCurrency = AMD)), + Self::ANG => Ok(dirval!(PaymentCurrency = ANG)), + Self::ARS => Ok(dirval!(PaymentCurrency = ARS)), + Self::AUD => Ok(dirval!(PaymentCurrency = AUD)), + Self::AWG => Ok(dirval!(PaymentCurrency = AWG)), + Self::AZN => Ok(dirval!(PaymentCurrency = AZN)), + Self::BBD => Ok(dirval!(PaymentCurrency = BBD)), + Self::BDT => Ok(dirval!(PaymentCurrency = BDT)), + Self::BHD => Ok(dirval!(PaymentCurrency = BHD)), + Self::BIF => Ok(dirval!(PaymentCurrency = BIF)), + Self::BMD => Ok(dirval!(PaymentCurrency = BMD)), + Self::BND => Ok(dirval!(PaymentCurrency = BND)), + Self::BOB => Ok(dirval!(PaymentCurrency = BOB)), + Self::BRL => Ok(dirval!(PaymentCurrency = BRL)), + Self::BSD => Ok(dirval!(PaymentCurrency = BSD)), + Self::BWP => Ok(dirval!(PaymentCurrency = BWP)), + Self::BZD => Ok(dirval!(PaymentCurrency = BZD)), + Self::CAD => Ok(dirval!(PaymentCurrency = CAD)), + Self::CHF => Ok(dirval!(PaymentCurrency = CHF)), + Self::CLP => Ok(dirval!(PaymentCurrency = CLP)), + Self::CNY => Ok(dirval!(PaymentCurrency = CNY)), + Self::COP => Ok(dirval!(PaymentCurrency = COP)), + Self::CRC => Ok(dirval!(PaymentCurrency = CRC)), + Self::CUP => Ok(dirval!(PaymentCurrency = CUP)), + Self::CZK => Ok(dirval!(PaymentCurrency = CZK)), + Self::DJF => Ok(dirval!(PaymentCurrency = DJF)), + Self::DKK => Ok(dirval!(PaymentCurrency = DKK)), + Self::DOP => Ok(dirval!(PaymentCurrency = DOP)), + Self::DZD => Ok(dirval!(PaymentCurrency = DZD)), + Self::EGP => Ok(dirval!(PaymentCurrency = EGP)), + Self::ETB => Ok(dirval!(PaymentCurrency = ETB)), + Self::EUR => Ok(dirval!(PaymentCurrency = EUR)), + Self::FJD => Ok(dirval!(PaymentCurrency = FJD)), + Self::GBP => Ok(dirval!(PaymentCurrency = GBP)), + Self::GHS => Ok(dirval!(PaymentCurrency = GHS)), + Self::GIP => Ok(dirval!(PaymentCurrency = GIP)), + Self::GMD => Ok(dirval!(PaymentCurrency = GMD)), + Self::GNF => Ok(dirval!(PaymentCurrency = GNF)), + Self::GTQ => Ok(dirval!(PaymentCurrency = GTQ)), + Self::GYD => Ok(dirval!(PaymentCurrency = GYD)), + Self::HKD => Ok(dirval!(PaymentCurrency = HKD)), + Self::HNL => Ok(dirval!(PaymentCurrency = HNL)), + Self::HRK => Ok(dirval!(PaymentCurrency = HRK)), + Self::HTG => Ok(dirval!(PaymentCurrency = HTG)), + Self::HUF => Ok(dirval!(PaymentCurrency = HUF)), + Self::IDR => Ok(dirval!(PaymentCurrency = IDR)), + Self::ILS => Ok(dirval!(PaymentCurrency = ILS)), + Self::INR => Ok(dirval!(PaymentCurrency = INR)), + Self::JMD => Ok(dirval!(PaymentCurrency = JMD)), + Self::JOD => Ok(dirval!(PaymentCurrency = JOD)), + Self::JPY => Ok(dirval!(PaymentCurrency = JPY)), + Self::KES => Ok(dirval!(PaymentCurrency = KES)), + Self::KGS => Ok(dirval!(PaymentCurrency = KGS)), + Self::KHR => Ok(dirval!(PaymentCurrency = KHR)), + Self::KMF => Ok(dirval!(PaymentCurrency = KMF)), + Self::KRW => Ok(dirval!(PaymentCurrency = KRW)), + Self::KWD => Ok(dirval!(PaymentCurrency = KWD)), + Self::KYD => Ok(dirval!(PaymentCurrency = KYD)), + Self::KZT => Ok(dirval!(PaymentCurrency = KZT)), + Self::LAK => Ok(dirval!(PaymentCurrency = LAK)), + Self::LBP => Ok(dirval!(PaymentCurrency = LBP)), + Self::LKR => Ok(dirval!(PaymentCurrency = LKR)), + Self::LRD => Ok(dirval!(PaymentCurrency = LRD)), + Self::LSL => Ok(dirval!(PaymentCurrency = LSL)), + Self::MAD => Ok(dirval!(PaymentCurrency = MAD)), + Self::MDL => Ok(dirval!(PaymentCurrency = MDL)), + Self::MGA => Ok(dirval!(PaymentCurrency = MGA)), + Self::MKD => Ok(dirval!(PaymentCurrency = MKD)), + Self::MMK => Ok(dirval!(PaymentCurrency = MMK)), + Self::MNT => Ok(dirval!(PaymentCurrency = MNT)), + Self::MOP => Ok(dirval!(PaymentCurrency = MOP)), + Self::MUR => Ok(dirval!(PaymentCurrency = MUR)), + Self::MVR => Ok(dirval!(PaymentCurrency = MVR)), + Self::MWK => Ok(dirval!(PaymentCurrency = MWK)), + Self::MXN => Ok(dirval!(PaymentCurrency = MXN)), + Self::MYR => Ok(dirval!(PaymentCurrency = MYR)), + Self::NAD => Ok(dirval!(PaymentCurrency = NAD)), + Self::NGN => Ok(dirval!(PaymentCurrency = NGN)), + Self::NIO => Ok(dirval!(PaymentCurrency = NIO)), + Self::NOK => Ok(dirval!(PaymentCurrency = NOK)), + Self::NPR => Ok(dirval!(PaymentCurrency = NPR)), + Self::NZD => Ok(dirval!(PaymentCurrency = NZD)), + Self::OMR => Ok(dirval!(PaymentCurrency = OMR)), + Self::PEN => Ok(dirval!(PaymentCurrency = PEN)), + Self::PGK => Ok(dirval!(PaymentCurrency = PGK)), + Self::PHP => Ok(dirval!(PaymentCurrency = PHP)), + Self::PKR => Ok(dirval!(PaymentCurrency = PKR)), + Self::PLN => Ok(dirval!(PaymentCurrency = PLN)), + Self::PYG => Ok(dirval!(PaymentCurrency = PYG)), + Self::QAR => Ok(dirval!(PaymentCurrency = QAR)), + Self::RON => Ok(dirval!(PaymentCurrency = RON)), + Self::RUB => Ok(dirval!(PaymentCurrency = RUB)), + Self::RWF => Ok(dirval!(PaymentCurrency = RWF)), + Self::SAR => Ok(dirval!(PaymentCurrency = SAR)), + Self::SCR => Ok(dirval!(PaymentCurrency = SCR)), + Self::SEK => Ok(dirval!(PaymentCurrency = SEK)), + Self::SGD => Ok(dirval!(PaymentCurrency = SGD)), + Self::SLL => Ok(dirval!(PaymentCurrency = SLL)), + Self::SOS => Ok(dirval!(PaymentCurrency = SOS)), + Self::SSP => Ok(dirval!(PaymentCurrency = SSP)), + Self::SVC => Ok(dirval!(PaymentCurrency = SVC)), + Self::SZL => Ok(dirval!(PaymentCurrency = SZL)), + Self::THB => Ok(dirval!(PaymentCurrency = THB)), + Self::TRY => Ok(dirval!(PaymentCurrency = TRY)), + Self::TTD => Ok(dirval!(PaymentCurrency = TTD)), + Self::TWD => Ok(dirval!(PaymentCurrency = TWD)), + Self::TZS => Ok(dirval!(PaymentCurrency = TZS)), + Self::UGX => Ok(dirval!(PaymentCurrency = UGX)), + Self::USD => Ok(dirval!(PaymentCurrency = USD)), + Self::UYU => Ok(dirval!(PaymentCurrency = UYU)), + Self::UZS => Ok(dirval!(PaymentCurrency = UZS)), + Self::VND => Ok(dirval!(PaymentCurrency = VND)), + Self::VUV => Ok(dirval!(PaymentCurrency = VUV)), + Self::XAF => Ok(dirval!(PaymentCurrency = XAF)), + Self::XOF => Ok(dirval!(PaymentCurrency = XOF)), + Self::XPF => Ok(dirval!(PaymentCurrency = XPF)), + Self::YER => Ok(dirval!(PaymentCurrency = YER)), + Self::ZAR => Ok(dirval!(PaymentCurrency = ZAR)), + } + } +} + +pub fn get_dir_country_dir_value(c: api_enums::Country) -> dir::enums::Country { + match c { + api_enums::Country::Afghanistan => dir::enums::Country::Afghanistan, + api_enums::Country::AlandIslands => dir::enums::Country::AlandIslands, + api_enums::Country::Albania => dir::enums::Country::Albania, + api_enums::Country::Algeria => dir::enums::Country::Algeria, + api_enums::Country::AmericanSamoa => dir::enums::Country::AmericanSamoa, + api_enums::Country::Andorra => dir::enums::Country::Andorra, + api_enums::Country::Angola => dir::enums::Country::Angola, + api_enums::Country::Anguilla => dir::enums::Country::Anguilla, + api_enums::Country::Antarctica => dir::enums::Country::Antarctica, + api_enums::Country::AntiguaAndBarbuda => dir::enums::Country::AntiguaAndBarbuda, + api_enums::Country::Argentina => dir::enums::Country::Argentina, + api_enums::Country::Armenia => dir::enums::Country::Armenia, + api_enums::Country::Aruba => dir::enums::Country::Aruba, + api_enums::Country::Australia => dir::enums::Country::Australia, + api_enums::Country::Austria => dir::enums::Country::Austria, + api_enums::Country::Azerbaijan => dir::enums::Country::Azerbaijan, + api_enums::Country::Bahamas => dir::enums::Country::Bahamas, + api_enums::Country::Bahrain => dir::enums::Country::Bahrain, + api_enums::Country::Bangladesh => dir::enums::Country::Bangladesh, + api_enums::Country::Barbados => dir::enums::Country::Barbados, + api_enums::Country::Belarus => dir::enums::Country::Belarus, + api_enums::Country::Belgium => dir::enums::Country::Belgium, + api_enums::Country::Belize => dir::enums::Country::Belize, + api_enums::Country::Benin => dir::enums::Country::Benin, + api_enums::Country::Bermuda => dir::enums::Country::Bermuda, + api_enums::Country::Bhutan => dir::enums::Country::Bhutan, + api_enums::Country::BoliviaPlurinationalState => { + dir::enums::Country::BoliviaPlurinationalState + } + api_enums::Country::BonaireSintEustatiusAndSaba => { + dir::enums::Country::BonaireSintEustatiusAndSaba + } + api_enums::Country::BosniaAndHerzegovina => dir::enums::Country::BosniaAndHerzegovina, + api_enums::Country::Botswana => dir::enums::Country::Botswana, + api_enums::Country::BouvetIsland => dir::enums::Country::BouvetIsland, + api_enums::Country::Brazil => dir::enums::Country::Brazil, + api_enums::Country::BritishIndianOceanTerritory => { + dir::enums::Country::BritishIndianOceanTerritory + } + api_enums::Country::BruneiDarussalam => dir::enums::Country::BruneiDarussalam, + api_enums::Country::Bulgaria => dir::enums::Country::Bulgaria, + api_enums::Country::BurkinaFaso => dir::enums::Country::BurkinaFaso, + api_enums::Country::Burundi => dir::enums::Country::Burundi, + api_enums::Country::CaboVerde => dir::enums::Country::CaboVerde, + api_enums::Country::Cambodia => dir::enums::Country::Cambodia, + api_enums::Country::Cameroon => dir::enums::Country::Cameroon, + api_enums::Country::Canada => dir::enums::Country::Canada, + api_enums::Country::CaymanIslands => dir::enums::Country::CaymanIslands, + api_enums::Country::CentralAfricanRepublic => dir::enums::Country::CentralAfricanRepublic, + api_enums::Country::Chad => dir::enums::Country::Chad, + api_enums::Country::Chile => dir::enums::Country::Chile, + api_enums::Country::China => dir::enums::Country::China, + api_enums::Country::ChristmasIsland => dir::enums::Country::ChristmasIsland, + api_enums::Country::CocosKeelingIslands => dir::enums::Country::CocosKeelingIslands, + api_enums::Country::Colombia => dir::enums::Country::Colombia, + api_enums::Country::Comoros => dir::enums::Country::Comoros, + api_enums::Country::Congo => dir::enums::Country::Congo, + api_enums::Country::CongoDemocraticRepublic => dir::enums::Country::CongoDemocraticRepublic, + api_enums::Country::CookIslands => dir::enums::Country::CookIslands, + api_enums::Country::CostaRica => dir::enums::Country::CostaRica, + api_enums::Country::CotedIvoire => dir::enums::Country::CotedIvoire, + api_enums::Country::Croatia => dir::enums::Country::Croatia, + api_enums::Country::Cuba => dir::enums::Country::Cuba, + api_enums::Country::Curacao => dir::enums::Country::Curacao, + api_enums::Country::Cyprus => dir::enums::Country::Cyprus, + api_enums::Country::Czechia => dir::enums::Country::Czechia, + api_enums::Country::Denmark => dir::enums::Country::Denmark, + api_enums::Country::Djibouti => dir::enums::Country::Djibouti, + api_enums::Country::Dominica => dir::enums::Country::Dominica, + api_enums::Country::DominicanRepublic => dir::enums::Country::DominicanRepublic, + api_enums::Country::Ecuador => dir::enums::Country::Ecuador, + api_enums::Country::Egypt => dir::enums::Country::Egypt, + api_enums::Country::ElSalvador => dir::enums::Country::ElSalvador, + api_enums::Country::EquatorialGuinea => dir::enums::Country::EquatorialGuinea, + api_enums::Country::Eritrea => dir::enums::Country::Eritrea, + api_enums::Country::Estonia => dir::enums::Country::Estonia, + api_enums::Country::Ethiopia => dir::enums::Country::Ethiopia, + api_enums::Country::FalklandIslandsMalvinas => dir::enums::Country::FalklandIslandsMalvinas, + api_enums::Country::FaroeIslands => dir::enums::Country::FaroeIslands, + api_enums::Country::Fiji => dir::enums::Country::Fiji, + api_enums::Country::Finland => dir::enums::Country::Finland, + api_enums::Country::France => dir::enums::Country::France, + api_enums::Country::FrenchGuiana => dir::enums::Country::FrenchGuiana, + api_enums::Country::FrenchPolynesia => dir::enums::Country::FrenchPolynesia, + api_enums::Country::FrenchSouthernTerritories => { + dir::enums::Country::FrenchSouthernTerritories + } + api_enums::Country::Gabon => dir::enums::Country::Gabon, + api_enums::Country::Gambia => dir::enums::Country::Gambia, + api_enums::Country::Georgia => dir::enums::Country::Georgia, + api_enums::Country::Germany => dir::enums::Country::Germany, + api_enums::Country::Ghana => dir::enums::Country::Ghana, + api_enums::Country::Gibraltar => dir::enums::Country::Gibraltar, + api_enums::Country::Greece => dir::enums::Country::Greece, + api_enums::Country::Greenland => dir::enums::Country::Greenland, + api_enums::Country::Grenada => dir::enums::Country::Grenada, + api_enums::Country::Guadeloupe => dir::enums::Country::Guadeloupe, + api_enums::Country::Guam => dir::enums::Country::Guam, + api_enums::Country::Guatemala => dir::enums::Country::Guatemala, + api_enums::Country::Guernsey => dir::enums::Country::Guernsey, + api_enums::Country::Guinea => dir::enums::Country::Guinea, + api_enums::Country::GuineaBissau => dir::enums::Country::GuineaBissau, + api_enums::Country::Guyana => dir::enums::Country::Guyana, + api_enums::Country::Haiti => dir::enums::Country::Haiti, + api_enums::Country::HeardIslandAndMcDonaldIslands => { + dir::enums::Country::HeardIslandAndMcDonaldIslands + } + api_enums::Country::HolySee => dir::enums::Country::HolySee, + api_enums::Country::Honduras => dir::enums::Country::Honduras, + api_enums::Country::HongKong => dir::enums::Country::HongKong, + api_enums::Country::Hungary => dir::enums::Country::Hungary, + api_enums::Country::Iceland => dir::enums::Country::Iceland, + api_enums::Country::India => dir::enums::Country::India, + api_enums::Country::Indonesia => dir::enums::Country::Indonesia, + api_enums::Country::IranIslamicRepublic => dir::enums::Country::IranIslamicRepublic, + api_enums::Country::Iraq => dir::enums::Country::Iraq, + api_enums::Country::Ireland => dir::enums::Country::Ireland, + api_enums::Country::IsleOfMan => dir::enums::Country::IsleOfMan, + api_enums::Country::Israel => dir::enums::Country::Israel, + api_enums::Country::Italy => dir::enums::Country::Italy, + api_enums::Country::Jamaica => dir::enums::Country::Jamaica, + api_enums::Country::Japan => dir::enums::Country::Japan, + api_enums::Country::Jersey => dir::enums::Country::Jersey, + api_enums::Country::Jordan => dir::enums::Country::Jordan, + api_enums::Country::Kazakhstan => dir::enums::Country::Kazakhstan, + api_enums::Country::Kenya => dir::enums::Country::Kenya, + api_enums::Country::Kiribati => dir::enums::Country::Kiribati, + api_enums::Country::KoreaDemocraticPeoplesRepublic => { + dir::enums::Country::KoreaDemocraticPeoplesRepublic + } + api_enums::Country::KoreaRepublic => dir::enums::Country::KoreaRepublic, + api_enums::Country::Kuwait => dir::enums::Country::Kuwait, + api_enums::Country::Kyrgyzstan => dir::enums::Country::Kyrgyzstan, + api_enums::Country::LaoPeoplesDemocraticRepublic => { + dir::enums::Country::LaoPeoplesDemocraticRepublic + } + api_enums::Country::Latvia => dir::enums::Country::Latvia, + api_enums::Country::Lebanon => dir::enums::Country::Lebanon, + api_enums::Country::Lesotho => dir::enums::Country::Lesotho, + api_enums::Country::Liberia => dir::enums::Country::Liberia, + api_enums::Country::Libya => dir::enums::Country::Libya, + api_enums::Country::Liechtenstein => dir::enums::Country::Liechtenstein, + api_enums::Country::Lithuania => dir::enums::Country::Lithuania, + api_enums::Country::Luxembourg => dir::enums::Country::Luxembourg, + api_enums::Country::Macao => dir::enums::Country::Macao, + api_enums::Country::MacedoniaTheFormerYugoslavRepublic => { + dir::enums::Country::MacedoniaTheFormerYugoslavRepublic + } + api_enums::Country::Madagascar => dir::enums::Country::Madagascar, + api_enums::Country::Malawi => dir::enums::Country::Malawi, + api_enums::Country::Malaysia => dir::enums::Country::Malaysia, + api_enums::Country::Maldives => dir::enums::Country::Maldives, + api_enums::Country::Mali => dir::enums::Country::Mali, + api_enums::Country::Malta => dir::enums::Country::Malta, + api_enums::Country::MarshallIslands => dir::enums::Country::MarshallIslands, + api_enums::Country::Martinique => dir::enums::Country::Martinique, + api_enums::Country::Mauritania => dir::enums::Country::Mauritania, + api_enums::Country::Mauritius => dir::enums::Country::Mauritius, + api_enums::Country::Mayotte => dir::enums::Country::Mayotte, + api_enums::Country::Mexico => dir::enums::Country::Mexico, + api_enums::Country::MicronesiaFederatedStates => { + dir::enums::Country::MicronesiaFederatedStates + } + api_enums::Country::MoldovaRepublic => dir::enums::Country::MoldovaRepublic, + api_enums::Country::Monaco => dir::enums::Country::Monaco, + api_enums::Country::Mongolia => dir::enums::Country::Mongolia, + api_enums::Country::Montenegro => dir::enums::Country::Montenegro, + api_enums::Country::Montserrat => dir::enums::Country::Montserrat, + api_enums::Country::Morocco => dir::enums::Country::Morocco, + api_enums::Country::Mozambique => dir::enums::Country::Mozambique, + api_enums::Country::Myanmar => dir::enums::Country::Myanmar, + api_enums::Country::Namibia => dir::enums::Country::Namibia, + api_enums::Country::Nauru => dir::enums::Country::Nauru, + api_enums::Country::Nepal => dir::enums::Country::Nepal, + api_enums::Country::Netherlands => dir::enums::Country::Netherlands, + api_enums::Country::NewCaledonia => dir::enums::Country::NewCaledonia, + api_enums::Country::NewZealand => dir::enums::Country::NewZealand, + api_enums::Country::Nicaragua => dir::enums::Country::Nicaragua, + api_enums::Country::Niger => dir::enums::Country::Niger, + api_enums::Country::Nigeria => dir::enums::Country::Nigeria, + api_enums::Country::Niue => dir::enums::Country::Niue, + api_enums::Country::NorfolkIsland => dir::enums::Country::NorfolkIsland, + api_enums::Country::NorthernMarianaIslands => dir::enums::Country::NorthernMarianaIslands, + api_enums::Country::Norway => dir::enums::Country::Norway, + api_enums::Country::Oman => dir::enums::Country::Oman, + api_enums::Country::Pakistan => dir::enums::Country::Pakistan, + api_enums::Country::Palau => dir::enums::Country::Palau, + api_enums::Country::PalestineState => dir::enums::Country::PalestineState, + api_enums::Country::Panama => dir::enums::Country::Panama, + api_enums::Country::PapuaNewGuinea => dir::enums::Country::PapuaNewGuinea, + api_enums::Country::Paraguay => dir::enums::Country::Paraguay, + api_enums::Country::Peru => dir::enums::Country::Peru, + api_enums::Country::Philippines => dir::enums::Country::Philippines, + api_enums::Country::Pitcairn => dir::enums::Country::Pitcairn, + + api_enums::Country::Poland => dir::enums::Country::Poland, + api_enums::Country::Portugal => dir::enums::Country::Portugal, + api_enums::Country::PuertoRico => dir::enums::Country::PuertoRico, + + api_enums::Country::Qatar => dir::enums::Country::Qatar, + api_enums::Country::Reunion => dir::enums::Country::Reunion, + api_enums::Country::Romania => dir::enums::Country::Romania, + api_enums::Country::RussianFederation => dir::enums::Country::RussianFederation, + api_enums::Country::Rwanda => dir::enums::Country::Rwanda, + api_enums::Country::SaintBarthelemy => dir::enums::Country::SaintBarthelemy, + api_enums::Country::SaintHelenaAscensionAndTristandaCunha => { + dir::enums::Country::SaintHelenaAscensionAndTristandaCunha + } + api_enums::Country::SaintKittsAndNevis => dir::enums::Country::SaintKittsAndNevis, + api_enums::Country::SaintLucia => dir::enums::Country::SaintLucia, + api_enums::Country::SaintMartinFrenchpart => dir::enums::Country::SaintMartinFrenchpart, + api_enums::Country::SaintPierreAndMiquelon => dir::enums::Country::SaintPierreAndMiquelon, + api_enums::Country::SaintVincentAndTheGrenadines => { + dir::enums::Country::SaintVincentAndTheGrenadines + } + api_enums::Country::Samoa => dir::enums::Country::Samoa, + api_enums::Country::SanMarino => dir::enums::Country::SanMarino, + api_enums::Country::SaoTomeAndPrincipe => dir::enums::Country::SaoTomeAndPrincipe, + api_enums::Country::SaudiArabia => dir::enums::Country::SaudiArabia, + api_enums::Country::Senegal => dir::enums::Country::Senegal, + api_enums::Country::Serbia => dir::enums::Country::Serbia, + api_enums::Country::Seychelles => dir::enums::Country::Seychelles, + api_enums::Country::SierraLeone => dir::enums::Country::SierraLeone, + api_enums::Country::Singapore => dir::enums::Country::Singapore, + api_enums::Country::SintMaartenDutchpart => dir::enums::Country::SintMaartenDutchpart, + api_enums::Country::Slovakia => dir::enums::Country::Slovakia, + api_enums::Country::Slovenia => dir::enums::Country::Slovenia, + api_enums::Country::SolomonIslands => dir::enums::Country::SolomonIslands, + api_enums::Country::Somalia => dir::enums::Country::Somalia, + api_enums::Country::SouthAfrica => dir::enums::Country::SouthAfrica, + api_enums::Country::SouthGeorgiaAndTheSouthSandwichIslands => { + dir::enums::Country::SouthGeorgiaAndTheSouthSandwichIslands + } + api_enums::Country::SouthSudan => dir::enums::Country::SouthSudan, + api_enums::Country::Spain => dir::enums::Country::Spain, + api_enums::Country::SriLanka => dir::enums::Country::SriLanka, + api_enums::Country::Sudan => dir::enums::Country::Sudan, + api_enums::Country::Suriname => dir::enums::Country::Suriname, + api_enums::Country::SvalbardAndJanMayen => dir::enums::Country::SvalbardAndJanMayen, + api_enums::Country::Swaziland => dir::enums::Country::Swaziland, + api_enums::Country::Sweden => dir::enums::Country::Sweden, + api_enums::Country::Switzerland => dir::enums::Country::Switzerland, + api_enums::Country::SyrianArabRepublic => dir::enums::Country::SyrianArabRepublic, + api_enums::Country::TaiwanProvinceOfChina => dir::enums::Country::TaiwanProvinceOfChina, + api_enums::Country::Tajikistan => dir::enums::Country::Tajikistan, + api_enums::Country::TanzaniaUnitedRepublic => dir::enums::Country::TanzaniaUnitedRepublic, + api_enums::Country::Thailand => dir::enums::Country::Thailand, + api_enums::Country::TimorLeste => dir::enums::Country::TimorLeste, + api_enums::Country::Togo => dir::enums::Country::Togo, + api_enums::Country::Tokelau => dir::enums::Country::Tokelau, + api_enums::Country::Tonga => dir::enums::Country::Tonga, + api_enums::Country::TrinidadAndTobago => dir::enums::Country::TrinidadAndTobago, + api_enums::Country::Tunisia => dir::enums::Country::Tunisia, + api_enums::Country::Turkey => dir::enums::Country::Turkey, + api_enums::Country::Turkmenistan => dir::enums::Country::Turkmenistan, + api_enums::Country::TurksAndCaicosIslands => dir::enums::Country::TurksAndCaicosIslands, + api_enums::Country::Tuvalu => dir::enums::Country::Tuvalu, + api_enums::Country::Uganda => dir::enums::Country::Uganda, + api_enums::Country::Ukraine => dir::enums::Country::Ukraine, + api_enums::Country::UnitedArabEmirates => dir::enums::Country::UnitedArabEmirates, + api_enums::Country::UnitedKingdomOfGreatBritainAndNorthernIreland => { + dir::enums::Country::UnitedKingdomOfGreatBritainAndNorthernIreland + } + api_enums::Country::UnitedStatesOfAmerica => dir::enums::Country::UnitedStatesOfAmerica, + api_enums::Country::UnitedStatesMinorOutlyingIslands => { + dir::enums::Country::UnitedStatesMinorOutlyingIslands + } + api_enums::Country::Uruguay => dir::enums::Country::Uruguay, + api_enums::Country::Uzbekistan => dir::enums::Country::Uzbekistan, + api_enums::Country::Vanuatu => dir::enums::Country::Vanuatu, + api_enums::Country::VenezuelaBolivarianRepublic => { + dir::enums::Country::VenezuelaBolivarianRepublic + } + api_enums::Country::Vietnam => dir::enums::Country::Vietnam, + api_enums::Country::VirginIslandsBritish => dir::enums::Country::VirginIslandsBritish, + api_enums::Country::VirginIslandsUS => dir::enums::Country::VirginIslandsUS, + api_enums::Country::WallisAndFutuna => dir::enums::Country::WallisAndFutuna, + api_enums::Country::WesternSahara => dir::enums::Country::WesternSahara, + api_enums::Country::Yemen => dir::enums::Country::Yemen, + api_enums::Country::Zambia => dir::enums::Country::Zambia, + api_enums::Country::Zimbabwe => dir::enums::Country::Zimbabwe, + } +} + +pub fn business_country_to_dir_value(c: api_enums::Country) -> dir::DirValue { + dir::DirValue::BusinessCountry(get_dir_country_dir_value(c)) +} + +pub fn billing_country_to_dir_value(c: api_enums::Country) -> dir::DirValue { + dir::DirValue::BillingCountry(get_dir_country_dir_value(c)) +} diff --git a/crates/router/Cargo.toml b/crates/router/Cargo.toml index 81b23314ffb8..9ab955813336 100644 --- a/crates/router/Cargo.toml +++ b/crates/router/Cargo.toml @@ -9,20 +9,23 @@ readme = "README.md" license.workspace = true [features] -default = ["kv_store", "stripe", "oltp", "olap", "accounts_cache", "dummy_connector", "payouts"] +default = ["kv_store", "stripe", "oltp", "olap", "backwards_compatibility", "accounts_cache", "dummy_connector", "payouts"] s3 = ["dep:aws-sdk-s3", "dep:aws-config"] kms = ["external_services/kms", "dep:aws-config"] email = ["external_services/email", "dep:aws-config"] basilisk = ["kms"] stripe = ["dep:serde_qs"] -release = ["kms", "stripe","basilisk","s3", "email","accounts_cache","kv_store"] +release = ["kms", "stripe", "basilisk", "s3", "email", "business_profile_routing", "accounts_cache", "kv_store"] olap = ["data_models/olap", "storage_impl/olap", "scheduler/olap"] oltp = ["data_models/oltp", "storage_impl/oltp"] kv_store = ["scheduler/kv_store"] accounts_cache = [] openapi = ["olap", "oltp", "payouts"] vergen = ["router_env/vergen"] -dummy_connector = ["api_models/dummy_connector"] +backwards_compatibility = ["api_models/backwards_compatibility", "euclid/backwards_compatibility", "kgraph_utils/backwards_compatibility"] +business_profile_routing=["api_models/business_profile_routing"] +dummy_connector = ["api_models/dummy_connector", "euclid/dummy_connector", "kgraph_utils/dummy_connector"] +connector_choice_mca_id = ["api_models/connector_choice_mca_id", "euclid/connector_choice_mca_id", "kgraph_utils/connector_choice_mca_id"] external_access_dc = ["dummy_connector"] detailed_errors = ["api_models/detailed_errors", "error-stack/serde"] payouts = [] @@ -66,10 +69,12 @@ num_cpus = "1.15.0" once_cell = "1.18.0" qrcode = "0.12.0" rand = "0.8.5" +rand_chacha = "0.3.1" regex = "1.8.4" reqwest = { version = "0.11.18", features = ["json", "native-tls", "gzip", "multipart"] } ring = "0.16.20" roxmltree = "0.18.0" +rustc-hash = "1.1.0" serde = { version = "1.0.163", features = ["derive"] } serde_json = "1.0.96" serde_path_to_error = "0.1.11" @@ -96,6 +101,7 @@ api_models = { version = "0.1.0", path = "../api_models", features = ["errors"] cards = { version = "0.1.0", path = "../cards" } common_utils = { version = "0.1.0", path = "../common_utils", features = ["signals", "async_ext", "logs"] } external_services = { version = "0.1.0", path = "../external_services" } +euclid = { version = "0.1.0", path = "../euclid", features = ["valued_jit"] } masking = { version = "0.1.0", path = "../masking" } redis_interface = { version = "0.1.0", path = "../redis_interface" } router_derive = { version = "0.1.0", path = "../router_derive" } @@ -103,6 +109,7 @@ router_env = { version = "0.1.0", path = "../router_env", features = ["log_extra diesel_models = { version = "0.1.0", path = "../diesel_models", features = ["kv_store"] } scheduler = { version = "0.1.0", path = "../scheduler", default-features = false} data_models = { version = "0.1.0", path = "../data_models", default-features = false } +kgraph_utils = { version = "0.1.0", path = "../kgraph_utils" } storage_impl = { version = "0.1.0", path = "../storage_impl", default-features = false } [target.'cfg(not(target_os = "windows"))'.dependencies] diff --git a/crates/router/src/compatibility/stripe/payment_intents.rs b/crates/router/src/compatibility/stripe/payment_intents.rs index 1076dfe410fc..c237f21dde66 100644 --- a/crates/router/src/compatibility/stripe/payment_intents.rs +++ b/crates/router/src/compatibility/stripe/payment_intents.rs @@ -9,7 +9,7 @@ use crate::{ core::{api_locking::GetLockingInput, payment_methods::Oss, payments}, routes, services::{api, authentication as auth}, - types::api::{self as api_types}, + types::api as api_types, }; #[instrument(skip_all, fields(flow = ?Flow::PaymentsCreate))] @@ -50,6 +50,7 @@ pub async fn payment_intents_create( &req, create_payment_req, |state, auth, req| { + let eligible_connectors = req.connector.clone(); payments::payments_core::( state, auth.merchant_account, @@ -58,6 +59,7 @@ pub async fn payment_intents_create( req, api::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + eligible_connectors, api_types::HeaderPayload::default(), ) }, @@ -117,6 +119,7 @@ pub async fn payment_intents_retrieve( payload, auth_flow, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, @@ -180,6 +183,7 @@ pub async fn payment_intents_retrieve_with_gateway_creds( req, api::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, @@ -236,6 +240,7 @@ pub async fn payment_intents_update( &req, payload, |state, auth, req| { + let eligible_connectors = req.connector.clone(); payments::payments_core::( state, auth.merchant_account, @@ -244,6 +249,7 @@ pub async fn payment_intents_update( req, auth_flow, payments::CallConnectorAction::Trigger, + eligible_connectors, api_types::HeaderPayload::default(), ) }, @@ -302,6 +308,7 @@ pub async fn payment_intents_confirm( &req, payload, |state, auth, req| { + let eligible_connectors = req.connector.clone(); payments::payments_core::( state, auth.merchant_account, @@ -310,6 +317,7 @@ pub async fn payment_intents_confirm( req, auth_flow, payments::CallConnectorAction::Trigger, + eligible_connectors, api_types::HeaderPayload::default(), ) }, @@ -366,6 +374,7 @@ pub async fn payment_intents_capture( payload, api::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, @@ -426,6 +435,7 @@ pub async fn payment_intents_cancel( req, auth_flow, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, diff --git a/crates/router/src/compatibility/stripe/payment_intents/types.rs b/crates/router/src/compatibility/stripe/payment_intents/types.rs index 4d9632f8885e..c713011b80c8 100644 --- a/crates/router/src/compatibility/stripe/payment_intents/types.rs +++ b/crates/router/src/compatibility/stripe/payment_intents/types.rs @@ -282,9 +282,17 @@ impl TryFrom for payments::PaymentsRequest { let routing = routable_connector .map(|connector| { - crate::types::api::RoutingAlgorithm::Single( - api_models::admin::RoutableConnectorChoice::ConnectorName(connector), - ) + api_models::routing::RoutingAlgorithm::Single(Box::new( + api_models::routing::RoutableConnectorChoice { + #[cfg(feature = "backwards_compatibility")] + choice_kind: api_models::routing::RoutableChoiceKind::FullStruct, + connector, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: None, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: None, + }, + )) }) .map(|r| { serde_json::to_value(r) diff --git a/crates/router/src/compatibility/stripe/setup_intents.rs b/crates/router/src/compatibility/stripe/setup_intents.rs index 311498e1af58..515e41ec91fa 100644 --- a/crates/router/src/compatibility/stripe/setup_intents.rs +++ b/crates/router/src/compatibility/stripe/setup_intents.rs @@ -69,6 +69,7 @@ pub async fn setup_intents_create( req, api::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, @@ -128,6 +129,7 @@ pub async fn setup_intents_retrieve( payload, auth_flow, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, @@ -200,6 +202,7 @@ pub async fn setup_intents_update( req, auth_flow, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, @@ -273,6 +276,7 @@ pub async fn setup_intents_confirm( req, auth_flow, payments::CallConnectorAction::Trigger, + None, api_types::HeaderPayload::default(), ) }, diff --git a/crates/router/src/compatibility/stripe/setup_intents/types.rs b/crates/router/src/compatibility/stripe/setup_intents/types.rs index 661a08e090e0..dde378e55925 100644 --- a/crates/router/src/compatibility/stripe/setup_intents/types.rs +++ b/crates/router/src/compatibility/stripe/setup_intents/types.rs @@ -185,9 +185,17 @@ impl TryFrom for payments::PaymentsRequest { let routing = routable_connector .map(|connector| { - crate::types::api::RoutingAlgorithm::Single( - api_models::admin::RoutableConnectorChoice::ConnectorName(connector), - ) + api_models::routing::RoutingAlgorithm::Single(Box::new( + api_models::routing::RoutableConnectorChoice { + #[cfg(feature = "backwards_compatibility")] + choice_kind: api_models::routing::RoutableChoiceKind::FullStruct, + connector, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: None, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: None, + }, + )) }) .map(|r| { serde_json::to_value(r) diff --git a/crates/router/src/consts.rs b/crates/router/src/consts.rs index 02db8b1754ed..f76df7466581 100644 --- a/crates/router/src/consts.rs +++ b/crates/router/src/consts.rs @@ -46,3 +46,5 @@ pub(crate) const QR_IMAGE_DATA_SOURCE_STRING: &str = "data:image/png;base64"; pub(crate) const MERCHANT_ID_FIELD_EXTENSION_ID: &str = "1.2.840.113635.100.6.32"; pub(crate) const METRICS_HOST_TAG_NAME: &str = "host"; +pub const MAX_ROUTING_CONFIGS_PER_MERCHANT: usize = 100; +pub const ROUTING_CONFIG_ID_LENGTH: usize = 10; diff --git a/crates/router/src/core.rs b/crates/router/src/core.rs index a3bb3c78915c..d87ff64003b4 100644 --- a/crates/router/src/core.rs +++ b/crates/router/src/core.rs @@ -16,6 +16,7 @@ pub mod payments; #[cfg(feature = "payouts")] pub mod payouts; pub mod refunds; +pub mod routing; pub mod utils; #[cfg(all(feature = "olap", feature = "kms"))] pub mod verification; diff --git a/crates/router/src/core/admin.rs b/crates/router/src/core/admin.rs index 5c9f44ffe575..5de273de0cef 100644 --- a/crates/router/src/core/admin.rs +++ b/crates/router/src/core/admin.rs @@ -1,6 +1,8 @@ +use std::str::FromStr; + use api_models::{ admin::{self as admin_types}, - enums as api_enums, + enums as api_enums, routing as routing_types, }; use common_utils::{ crypto::{generate_cryptographically_secure_random_string, OptionalSecretValue}, @@ -18,6 +20,7 @@ use crate::{ core::{ errors::{self, RouterResponse, RouterResult, StorageErrorExt}, payments::helpers, + routing::helpers as routing_helpers, utils as core_utils, }, db::StorageInterface, @@ -89,7 +92,7 @@ pub async fn create_merchant_account( .transpose()?; if let Some(ref routing_algorithm) = req.routing_algorithm { - let _: api::RoutingAlgorithm = routing_algorithm + let _: api_models::routing::RoutingAlgorithm = routing_algorithm .clone() .parse_value("RoutingAlgorithm") .change_context(errors::ApiErrorResponse::InvalidDataValue { @@ -178,7 +181,10 @@ pub async fn create_merchant_account( .await?, return_url: req.return_url.map(|a| a.to_string()), webhook_details, - routing_algorithm: req.routing_algorithm, + routing_algorithm: Some(serde_json::json!({ + "algorithm_id": null, + "timestamp": 0 + })), sub_merchants_enabled: req.sub_merchants_enabled, parent_merchant_id, enable_payment_response_hash, @@ -470,7 +476,7 @@ pub async fn merchant_account_update( } if let Some(ref routing_algorithm) = req.routing_algorithm { - let _: api::RoutingAlgorithm = routing_algorithm + let _: api_models::routing::RoutingAlgorithm = routing_algorithm .clone() .parse_value("RoutingAlgorithm") .change_context(errors::ApiErrorResponse::InvalidDataValue { @@ -756,6 +762,9 @@ pub async fn create_payment_connector( ) .await?; + let routable_connector = + api_enums::RoutableConnectors::from_str(&req.connector_name.to_string()).ok(); + let business_profile = state .store .find_business_profile_by_profile_id(&profile_id) @@ -828,6 +837,37 @@ pub async fn create_payment_connector( let frm_configs = get_frm_config_as_secret(req.frm_configs); + // The purpose of this merchant account update is just to update the + // merchant account `modified_at` field for KGraph cache invalidation + let merchant_account_update = storage::MerchantAccountUpdate::Update { + merchant_name: None, + merchant_details: None, + return_url: None, + webhook_details: None, + sub_merchants_enabled: None, + parent_merchant_id: None, + enable_payment_response_hash: None, + locker_id: None, + payment_response_hash_key: None, + primary_business_details: None, + metadata: None, + publishable_key: None, + redirect_to_merchant_with_http_post: None, + routing_algorithm: None, + intent_fulfillment_time: None, + frm_routing_algorithm: None, + payout_routing_algorithm: None, + default_profile: None, + payment_link_config: None, + }; + + state + .store + .update_specific_fields_in_merchant(merchant_id, merchant_account_update, &key_store) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error updating the merchant account when creating payment connector")?; + let merchant_connector_account = domain::MerchantConnectorAccount { merchant_id: merchant_id.to_string(), connector_type: req.connector_type, @@ -852,7 +892,7 @@ pub async fn create_payment_connector( connector_label: Some(connector_label), business_country: req.business_country, business_label: req.business_label.clone(), - business_sub_label: req.business_sub_label, + business_sub_label: req.business_sub_label.clone(), created_at: common_utils::date_time::now(), modified_at: common_utils::date_time::now(), id: None, @@ -873,6 +913,9 @@ pub async fn create_payment_connector( pm_auth_config: req.pm_auth_config.clone(), }; + let mut default_routing_config = + routing_helpers::get_merchant_default_config(&*state.store, merchant_id).await?; + let mca = state .store .insert_merchant_connector_account(merchant_connector_account, &key_store) @@ -884,6 +927,28 @@ pub async fn create_payment_connector( }, )?; + if let Some(routable_connector_val) = routable_connector { + let choice = routing_types::RoutableConnectorChoice { + #[cfg(feature = "backwards_compatibility")] + choice_kind: routing_types::RoutableChoiceKind::FullStruct, + connector: routable_connector_val, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: Some(mca.merchant_connector_id.clone()), + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: req.business_sub_label, + }; + + if !default_routing_config.contains(&choice) { + default_routing_config.push(choice); + routing_helpers::update_merchant_default_config( + &*state.store, + merchant_id, + default_routing_config, + ) + .await?; + } + } + metrics::MCA_CREATE.add( &metrics::CONTEXT, 1, @@ -1248,7 +1313,7 @@ pub async fn create_business_profile( .to_not_found_response(errors::ApiErrorResponse::MerchantAccountNotFound)?; if let Some(ref routing_algorithm) = request.routing_algorithm { - let _: api::RoutingAlgorithm = routing_algorithm + let _: api_models::routing::RoutingAlgorithm = routing_algorithm .clone() .parse_value("RoutingAlgorithm") .change_context(errors::ApiErrorResponse::InvalidDataValue { @@ -1360,7 +1425,7 @@ pub async fn update_business_profile( .transpose()?; if let Some(ref routing_algorithm) = request.routing_algorithm { - let _: api::RoutingAlgorithm = routing_algorithm + let _: api_models::routing::RoutingAlgorithm = routing_algorithm .clone() .parse_value("RoutingAlgorithm") .change_context(errors::ApiErrorResponse::InvalidDataValue { diff --git a/crates/router/src/core/errors.rs b/crates/router/src/core/errors.rs index 1c062b7035af..dc1d56721e88 100644 --- a/crates/router/src/core/errors.rs +++ b/crates/router/src/core/errors.rs @@ -325,3 +325,49 @@ pub mod error_stack_parsing { } #[cfg(feature = "detailed_errors")] pub use error_stack_parsing::*; + +#[derive(Debug, Clone, thiserror::Error)] +pub enum RoutingError { + #[error("Merchant routing algorithm not found in cache")] + CacheMiss, + #[error("Final connector selection failed")] + ConnectorSelectionFailed, + #[error("[DSL] Missing required field in payment data: '{field_name}'")] + DslMissingRequiredField { field_name: String }, + #[error("The lock on the DSL cache is most probably poisoned")] + DslCachePoisoned, + #[error("Expected DSL to be saved in DB but did not find")] + DslMissingInDb, + #[error("Unable to parse DSL from JSON")] + DslParsingError, + #[error("Failed to initialize DSL backend")] + DslBackendInitError, + #[error("Error updating merchant with latest dsl cache contents")] + DslMerchantUpdateError, + #[error("Error executing the DSL")] + DslExecutionError, + #[error("Final connector selection failed")] + DslFinalConnectorSelectionFailed, + #[error("[DSL] Received incorrect selection algorithm as DSL output")] + DslIncorrectSelectionAlgorithm, + #[error("there was an error saving/retrieving values from the kgraph cache")] + KgraphCacheFailure, + #[error("failed to refresh the kgraph cache")] + KgraphCacheRefreshFailed, + #[error("there was an error during the kgraph analysis phase")] + KgraphAnalysisError, + #[error("'profile_id' was not provided")] + ProfileIdMissing, + #[error("the profile was not found in the database")] + ProfileNotFound, + #[error("failed to fetch the fallback config for the merchant")] + FallbackConfigFetchFailed, + #[error("Invalid connector name received: '{0}'")] + InvalidConnectorName(String), + #[error("The routing algorithm in merchant account had invalid structure")] + InvalidRoutingAlgorithmStructure, + #[error("Volume split failed")] + VolumeSplitFailed, + #[error("Unable to parse metadata")] + MetadataParsingError, +} diff --git a/crates/router/src/core/payment_methods/cards.rs b/crates/router/src/core/payment_methods/cards.rs index 2161ab69222e..417b030f5494 100644 --- a/crates/router/src/core/payment_methods/cards.rs +++ b/crates/router/src/core/payment_methods/cards.rs @@ -31,7 +31,10 @@ use crate::{ transformers::{self as payment_methods}, vault, }, - payments::helpers, + payments::{ + helpers, + routing::{self, SessionFlowRoutingInput}, + }, }, db, logger, pii::prelude::*, @@ -42,7 +45,7 @@ use crate::{ }, services, types::{ - api::{self, PaymentMethodCreateExt}, + api::{self, routing as routing_types, PaymentMethodCreateExt}, domain::{ self, types::{decrypt, encrypt_optional, AsyncLift}, @@ -933,6 +936,135 @@ pub async fn list_payment_methods( .await?; } + if let Some((payment_attempt, payment_intent)) = + payment_attempt.as_ref().zip(payment_intent.as_ref()) + { + let routing_enabled_pms = HashSet::from([ + api_enums::PaymentMethod::BankTransfer, + api_enums::PaymentMethod::BankDebit, + api_enums::PaymentMethod::BankRedirect, + ]); + + let routing_enabled_pm_types = HashSet::from([ + api_enums::PaymentMethodType::GooglePay, + api_enums::PaymentMethodType::ApplePay, + api_enums::PaymentMethodType::Klarna, + api_enums::PaymentMethodType::Paypal, + ]); + + let mut chosen = Vec::::new(); + for intermediate in &response { + if routing_enabled_pm_types.contains(&intermediate.payment_method_type) + || routing_enabled_pms.contains(&intermediate.payment_method) + { + let connector_data = api::ConnectorData::get_connector_by_name( + &state.clone().conf.connectors, + &intermediate.connector, + api::GetToken::from(intermediate.payment_method_type), + None, + ) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("invalid connector name received")?; + + chosen.push(api::SessionConnectorData { + payment_method_type: intermediate.payment_method_type, + connector: connector_data, + business_sub_label: None, + }); + } + } + let sfr = SessionFlowRoutingInput { + state: &state, + country: shipping_address.clone().and_then(|ad| ad.country), + key_store: &key_store, + merchant_account: &merchant_account, + payment_attempt, + payment_intent, + chosen, + }; + let result = routing::perform_session_flow_routing(sfr) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error performing session flow routing")?; + + response.retain(|intermediate| { + if !routing_enabled_pm_types.contains(&intermediate.payment_method_type) + && !routing_enabled_pms.contains(&intermediate.payment_method) + { + return true; + } + + if let Some(choice) = result.get(&intermediate.payment_method_type) { + intermediate.connector == choice.connector.connector_name.to_string() + } else { + false + } + }); + + let mut routing_info: storage::PaymentRoutingInfo = payment_attempt + .straight_through_algorithm + .clone() + .map(|val| val.parse_value("PaymentRoutingInfo")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Invalid PaymentRoutingInfo format found in payment attempt")? + .unwrap_or_else(|| storage::PaymentRoutingInfo { + algorithm: None, + pre_routing_results: None, + }); + + let mut pre_routing_results: HashMap< + api_enums::PaymentMethodType, + routing_types::RoutableConnectorChoice, + > = HashMap::new(); + + for (pm_type, choice) in result { + let routable_choice = routing_types::RoutableConnectorChoice { + #[cfg(feature = "backwards_compatibility")] + choice_kind: routing_types::RoutableChoiceKind::FullStruct, + connector: choice + .connector + .connector_name + .to_string() + .parse() + .into_report() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("")?, + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: choice.connector.merchant_connector_id, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: choice.sub_label, + }; + + pre_routing_results.insert(pm_type, routable_choice); + } + + routing_info.pre_routing_results = Some(pre_routing_results); + + let encoded = utils::Encode::::encode_to_value(&routing_info) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Unable to serialize payment routing info to value")?; + + let attempt_update = storage::PaymentAttemptUpdate::UpdateTrackers { + payment_token: None, + connector: None, + straight_through_algorithm: Some(encoded), + amount_capturable: None, + updated_by: merchant_account.storage_scheme.to_string(), + merchant_connector_id: None, + }; + + state + .store + .update_payment_attempt_with_attempt_id( + payment_attempt.clone(), + attempt_update, + merchant_account.storage_scheme, + ) + .await + .to_not_found_response(errors::ApiErrorResponse::PaymentNotFound)?; + } + let req = api_models::payments::PaymentsRequest::foreign_from(( payment_attempt.as_ref(), shipping_address.as_ref(), diff --git a/crates/router/src/core/payments.rs b/crates/router/src/core/payments.rs index 9aa0e3c70d25..586126467e18 100644 --- a/crates/router/src/core/payments.rs +++ b/crates/router/src/core/payments.rs @@ -3,11 +3,12 @@ pub mod customers; pub mod flows; pub mod helpers; pub mod operations; +pub mod routing; pub mod tokenization; pub mod transformers; pub mod types; -use std::{fmt::Debug, marker::PhantomData, ops::Deref, time::Instant}; +use std::{fmt::Debug, marker::PhantomData, ops::Deref, time::Instant, vec::IntoIter}; use api_models::{ enums, @@ -35,6 +36,7 @@ pub use self::operations::{ use self::{ flows::{ConstructFlowSpecificData, Feature}, operations::{payment_complete_authorize, BoxedOperation, Operation}, + routing::{self as self_routing, SessionFlowRoutingInput}, }; use super::errors::StorageErrorExt; use crate::{ @@ -49,8 +51,11 @@ use crate::{ routes::{metrics, payment_methods::ParentPaymentMethodToken, AppState}, services::{self, api::Authenticate}, types::{ - self as router_types, api, domain, + self as router_types, + api::{self, ConnectorCallType}, + domain, storage::{self, enums as storage_enums}, + transformers::ForeignTryInto, }, utils::{ add_apple_pay_flow_metrics, add_connector_http_status_code_metrics, Encode, OptionExt, @@ -69,6 +74,7 @@ pub async fn payments_operation_core( req: Req, call_connector_action: CallConnectorAction, auth_flow: services::AuthFlow, + eligible_connectors: Option>, header_payload: HeaderPayload, ) -> RouterResult<( PaymentData, @@ -136,28 +142,11 @@ where &merchant_account, &key_store, &mut payment_data, + eligible_connectors, ) .await?; - let schedule_time = match &connector { - Some(api::ConnectorCallType::Single(connector_data)) => { - if should_add_task_to_process_tracker(&payment_data) { - payment_sync::get_sync_process_schedule_time( - &*state.store, - connector_data.connector.id(), - &merchant_account.merchant_id, - 0, - ) - .await - .into_report() - .change_context(errors::ApiErrorResponse::InternalServerError) - .attach_printable("Failed while getting process schedule time")? - } else { - None - } - } - _ => None, - }; + let should_add_task_to_process_tracker = should_add_task_to_process_tracker(&payment_data); payment_data = tokenize_in_router_when_confirm_false( state, @@ -171,7 +160,21 @@ where let mut external_latency = None; if let Some(connector_details) = connector { payment_data = match connector_details { - api::ConnectorCallType::Single(connector) => { + api::ConnectorCallType::PreDetermined(connector) => { + let schedule_time = if should_add_task_to_process_tracker { + payment_sync::get_sync_process_schedule_time( + &*state.store, + connector.connector.id(), + &merchant_account.merchant_id, + 0, + ) + .await + .into_report() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed while getting process schedule time")? + } else { + None + }; let router_data = call_connector_service( state, &merchant_account, @@ -186,6 +189,57 @@ where header_payload, ) .await?; + let operation = Box::new(PaymentResponse); + let db = &*state.store; + connector_http_status_code = router_data.connector_http_status_code; + external_latency = router_data.external_latency; + //add connector http status code metrics + add_connector_http_status_code_metrics(connector_http_status_code); + operation + .to_post_update_tracker()? + .update_tracker( + db, + &validate_result.payment_id, + payment_data, + router_data, + merchant_account.storage_scheme, + ) + .await? + } + + api::ConnectorCallType::Retryable(connectors) => { + let mut connectors = connectors.into_iter(); + + let connector_data = get_connector_data(&mut connectors)?; + + let schedule_time = if should_add_task_to_process_tracker { + payment_sync::get_sync_process_schedule_time( + &*state.store, + connector_data.connector.id(), + &merchant_account.merchant_id, + 0, + ) + .await + .into_report() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed while getting process schedule time")? + } else { + None + }; + let router_data = call_connector_service( + state, + &merchant_account, + &key_store, + connector_data, + &operation, + &mut payment_data, + &customer, + call_connector_action, + &validate_result, + schedule_time, + header_payload, + ) + .await?; let operation = Box::new(PaymentResponse); let db = &*state.store; @@ -205,7 +259,7 @@ where .await? } - api::ConnectorCallType::Multiple(connectors) => { + api::ConnectorCallType::SessionMultiple(connectors) => { call_multiple_connectors_service( state, &merchant_account, @@ -258,6 +312,17 @@ where )) } +#[inline] +pub fn get_connector_data( + connectors: &mut IntoIter, +) -> RouterResult { + connectors + .next() + .ok_or(errors::ApiErrorResponse::InternalServerError) + .into_report() + .attach_printable("Connector not found in connectors iterator") +} + #[allow(clippy::too_many_arguments)] pub async fn payments_core( state: AppState, @@ -267,6 +332,7 @@ pub async fn payments_core( req: Req, auth_flow: services::AuthFlow, call_connector_action: CallConnectorAction, + eligible_connectors: Option>, header_payload: HeaderPayload, ) -> RouterResponse where @@ -287,6 +353,12 @@ where // To perform router related operation for PaymentResponse PaymentResponse: Operation, { + let eligible_routable_connectors = eligible_connectors.map(|connectors| { + connectors + .into_iter() + .flat_map(|c| c.foreign_try_into()) + .collect() + }); let (payment_data, req, customer, connector_http_status_code, external_latency) = payments_operation_core::<_, _, _, _, Ctx>( &state, @@ -296,6 +368,7 @@ where req, call_connector_action, auth_flow, + eligible_routable_connectors, header_payload, ) .await?; @@ -470,6 +543,7 @@ impl PaymentRedirectFlow for PaymentRedirectCom payment_confirm_req, services::api::AuthFlow::Merchant, connector_action, + None, HeaderPayload::default(), ) .await @@ -565,11 +639,11 @@ impl PaymentRedirectFlow for PaymentRedirectSyn payment_sync_req, services::api::AuthFlow::Merchant, connector_action, + None, HeaderPayload::default(), ) .await } - fn generate_response( &self, payments_response: api_models::payments::PaymentsResponse, @@ -1842,7 +1916,7 @@ pub fn update_straight_through_routing( where F: Send + Clone, { - let _: api::RoutingAlgorithm = request_straight_through + let _: api_models::routing::RoutingAlgorithm = request_straight_through .clone() .parse_value("RoutingAlgorithm") .attach_printable("Invalid straight through routing rules format")?; @@ -1859,7 +1933,8 @@ pub async fn get_connector_choice( merchant_account: &domain::MerchantAccount, key_store: &domain::MerchantKeyStore, payment_data: &mut PaymentData, -) -> RouterResult> + eligible_connectors: Option>, +) -> RouterResult> where F: Send + Clone, Ctx: PaymentMethodRetrieve, @@ -1868,7 +1943,7 @@ where .to_domain()? .get_connector( merchant_account, - state, + &state.clone(), req, &payment_data.payment_intent, key_store, @@ -1877,39 +1952,132 @@ where let connector = if should_call_connector(operation, payment_data) { Some(match connector_choice { - api::ConnectorChoice::SessionMultiple(session_connectors) => { - api::ConnectorCallType::Multiple(session_connectors) + api::ConnectorChoice::SessionMultiple(connectors) => { + let routing_output = perform_session_token_routing( + state.clone(), + merchant_account, + key_store, + payment_data, + connectors, + ) + .await?; + api::ConnectorCallType::SessionMultiple(routing_output) } - api::ConnectorChoice::StraightThrough(straight_through) => connector_selection( - state, - merchant_account, - payment_data, - Some(straight_through), - )?, + api::ConnectorChoice::StraightThrough(straight_through) => { + connector_selection( + state, + merchant_account, + key_store, + payment_data, + Some(straight_through), + eligible_connectors, + ) + .await? + } api::ConnectorChoice::Decide => { - connector_selection(state, merchant_account, payment_data, None)? + connector_selection( + state, + merchant_account, + key_store, + payment_data, + None, + eligible_connectors, + ) + .await? } }) - } else if let api::ConnectorChoice::StraightThrough(val) = connector_choice { - update_straight_through_routing(payment_data, val) + } else if let api::ConnectorChoice::StraightThrough(algorithm) = connector_choice { + update_straight_through_routing(payment_data, algorithm) .change_context(errors::ApiErrorResponse::InternalServerError) .attach_printable("Failed to update straight through routing algorithm")?; + None } else { None }; - Ok(connector) } -pub fn connector_selection( +pub async fn connector_selection( state: &AppState, merchant_account: &domain::MerchantAccount, + key_store: &domain::MerchantKeyStore, payment_data: &mut PaymentData, request_straight_through: Option, -) -> RouterResult + eligible_connectors: Option>, +) -> RouterResult +where + F: Send + Clone, +{ + let request_straight_through: Option = + request_straight_through + .map(|val| val.parse_value("RoutingAlgorithm")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Invalid straight through routing rules format")?; + + let mut routing_data = storage::RoutingData { + routed_through: payment_data.payment_attempt.connector.clone(), + #[cfg(feature = "connector_choice_mca_id")] + merchant_connector_id: payment_data.payment_attempt.merchant_connector_id.clone(), + #[cfg(not(feature = "connector_choice_mca_id"))] + business_sub_label: payment_data.payment_attempt.business_sub_label.clone(), + algorithm: request_straight_through.clone(), + routing_info: payment_data + .payment_attempt + .straight_through_algorithm + .clone() + .map(|val| val.parse_value("PaymentRoutingInfo")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Invalid straight through algorithm format found in payment attempt")? + .unwrap_or_else(|| storage::PaymentRoutingInfo { + algorithm: None, + pre_routing_results: None, + }), + }; + + let decided_connector = decide_connector( + state.clone(), + merchant_account, + key_store, + payment_data, + request_straight_through, + &mut routing_data, + eligible_connectors, + ) + .await?; + + let encoded_info = + Encode::::encode_to_value(&routing_data.routing_info) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error serializing payment routing info to serde value")?; + + payment_data.payment_attempt.connector = routing_data.routed_through; + #[cfg(feature = "connector_choice_mca_id")] + { + payment_data.payment_attempt.merchant_connector_id = routing_data.merchant_connector_id; + } + #[cfg(not(feature = "connector_choice_mca_id"))] + { + payment_data.payment_attempt.business_sub_label = routing_data.business_sub_label; + } + payment_data.payment_attempt.straight_through_algorithm = Some(encoded_info); + + Ok(decided_connector) +} + +pub async fn decide_connector( + state: AppState, + merchant_account: &domain::MerchantAccount, + key_store: &domain::MerchantKeyStore, + payment_data: &mut PaymentData, + request_straight_through: Option, + routing_data: &mut storage::RoutingData, + eligible_connectors: Option>, +) -> RouterResult where F: Send + Clone, { @@ -1925,111 +2093,424 @@ where payment_data.payment_attempt.merchant_connector_id.clone(), ) .change_context(errors::ApiErrorResponse::InternalServerError) - .attach_printable("invalid connector name received in payment attempt")?; + .attach_printable("Invalid connector name received in 'routed_through'")?; - return Ok(api::ConnectorCallType::Single(connector_data)); + routing_data.routed_through = Some(connector_name.clone()); + return Ok(api::ConnectorCallType::PreDetermined(connector_data)); } - let request_straight_through = request_straight_through - .map(|val| val.parse_value::("StraightThroughAlgorithm")) - .transpose() + if let Some(mandate_connector_details) = payment_data.mandate_connector.as_ref() { + let connector_data = api::ConnectorData::get_connector_by_name( + &state.conf.connectors, + &mandate_connector_details.connector, + api::GetToken::Connector, + #[cfg(feature = "connector_choice_mca_id")] + mandate_connector_details.merchant_connector_id.clone(), + #[cfg(not(feature = "connector_choice_mca_id"))] + None, + ) .change_context(errors::ApiErrorResponse::InternalServerError) - .attach_printable("Invalid straight through routing rules format") - .transpose(); + .attach_printable("Invalid connector name received in 'routed_through'")?; + + routing_data.routed_through = Some(mandate_connector_details.connector.clone()); + #[cfg(feature = "connector_choice_mca_id")] + { + routing_data.merchant_connector_id = + mandate_connector_details.merchant_connector_id.clone(); + } + return Ok(api::ConnectorCallType::PreDetermined(connector_data)); + } + + if let Some((pre_routing_results, storage_pm_type)) = routing_data + .routing_info + .pre_routing_results + .as_ref() + .zip(payment_data.payment_attempt.payment_method_type.as_ref()) + { + if let Some(choice) = pre_routing_results.get(storage_pm_type) { + let connector_data = api::ConnectorData::get_connector_by_name( + &state.conf.connectors, + &choice.connector.to_string(), + api::GetToken::Connector, + #[cfg(feature = "connector_choice_mca_id")] + choice.merchant_connector_id.clone(), + #[cfg(not(feature = "connector_choice_mca_id"))] + None, + ) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Invalid connector name received")?; + + routing_data.routed_through = Some(choice.connector.to_string()); + #[cfg(feature = "connector_choice_mca_id")] + { + routing_data.merchant_connector_id = choice.merchant_connector_id.clone(); + } + #[cfg(not(feature = "connector_choice_mca_id"))] + { + routing_data.business_sub_label = choice.sub_label.clone(); + } + return Ok(api::ConnectorCallType::PreDetermined(connector_data)); + } + } + + if let Some(routing_algorithm) = request_straight_through { + let (mut connectors, check_eligibility) = + routing::perform_straight_through_routing(&routing_algorithm, payment_data) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed execution of straight through routing")?; + + if check_eligibility { + connectors = routing::perform_eligibility_analysis_with_fallback( + &state.clone(), + key_store, + merchant_account.modified_at.assume_utc().unix_timestamp(), + connectors, + payment_data, + eligible_connectors, + #[cfg(feature = "business_profile_routing")] + payment_data.payment_intent.profile_id.clone(), + ) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("failed eligibility analysis and fallback")?; + } + + let first_connector_choice = connectors + .first() + .ok_or(errors::ApiErrorResponse::IncorrectPaymentMethodConfiguration) + .into_report() + .attach_printable("Empty connector list returned")? + .clone(); - let payment_routing_algorithm = request_straight_through.or(payment_data + let connector_data = connectors + .into_iter() + .map(|conn| { + api::ConnectorData::get_connector_by_name( + &state.conf.connectors, + &conn.connector.to_string(), + api::GetToken::Connector, + #[cfg(feature = "connector_choice_mca_id")] + conn.merchant_connector_id.clone(), + #[cfg(not(feature = "connector_choice_mca_id"))] + None, + ) + }) + .collect::, _>>() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Invalid connector name received")?; + + routing_data.routed_through = Some(first_connector_choice.connector.to_string()); + #[cfg(feature = "connector_choice_mca_id")] + { + routing_data.merchant_connector_id = first_connector_choice.merchant_connector_id; + } + #[cfg(not(feature = "connector_choice_mca_id"))] + { + routing_data.business_sub_label = first_connector_choice.sub_label.clone(); + } + routing_data.routing_info.algorithm = Some(routing_algorithm); + return Ok(api::ConnectorCallType::Retryable(connector_data)); + } + + if let Some(ref routing_algorithm) = routing_data.routing_info.algorithm { + let (mut connectors, check_eligibility) = + routing::perform_straight_through_routing(routing_algorithm, payment_data) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed execution of straight through routing")?; + + if check_eligibility { + connectors = routing::perform_eligibility_analysis_with_fallback( + &state, + key_store, + merchant_account.modified_at.assume_utc().unix_timestamp(), + connectors, + payment_data, + eligible_connectors, + #[cfg(feature = "business_profile_routing")] + payment_data.payment_intent.profile_id.clone(), + ) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("failed eligibility analysis and fallback")?; + } + + let first_connector_choice = connectors + .first() + .ok_or(errors::ApiErrorResponse::IncorrectPaymentMethodConfiguration) + .into_report() + .attach_printable("Empty connector list returned")? + .clone(); + + let connector_data = connectors + .into_iter() + .map(|conn| { + api::ConnectorData::get_connector_by_name( + &state.conf.connectors, + &conn.connector.to_string(), + api::GetToken::Connector, + #[cfg(feature = "connector_choice_mca_id")] + conn.merchant_connector_id, + #[cfg(not(feature = "connector_choice_mca_id"))] + None, + ) + }) + .collect::, _>>() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Invalid connector name received")?; + + routing_data.routed_through = Some(first_connector_choice.connector.to_string()); + #[cfg(feature = "connector_choice_mca_id")] + { + routing_data.merchant_connector_id = first_connector_choice.merchant_connector_id; + } + #[cfg(not(feature = "connector_choice_mca_id"))] + { + routing_data.business_sub_label = first_connector_choice.sub_label; + } + return Ok(api::ConnectorCallType::Retryable(connector_data)); + } + + route_connector_v1( + &state, + merchant_account, + key_store, + payment_data, + routing_data, + eligible_connectors, + ) + .await +} + +pub fn should_add_task_to_process_tracker(payment_data: &PaymentData) -> bool { + let connector = payment_data.payment_attempt.connector.as_deref(); + + !matches!( + (payment_data.payment_attempt.payment_method, connector), + ( + Some(storage_enums::PaymentMethod::BankTransfer), + Some("stripe") + ) + ) +} + +pub async fn perform_session_token_routing( + state: AppState, + merchant_account: &domain::MerchantAccount, + key_store: &domain::MerchantKeyStore, + payment_data: &mut PaymentData, + connectors: Vec, +) -> RouterResult> +where + F: Clone, +{ + let routing_info: Option = payment_data .payment_attempt .straight_through_algorithm .clone() - .map(|val| val.parse_value::("RoutingAlgorithm")) + .map(|val| val.parse_value("PaymentRoutingInfo")) .transpose() .change_context(errors::ApiErrorResponse::InternalServerError) - .attach_printable("Invalid straight through algorithm format in payment attempt") - .transpose()); + .attach_printable("invalid payment routing info format found in payment attempt")?; - let routing_algorithm = payment_routing_algorithm - .or(merchant_account - .routing_algorithm - .clone() - .map(|merchant_routing_algorithm| { - merchant_routing_algorithm - .parse_value::("RoutingAlgorithm") - .change_context(errors::ApiErrorResponse::InternalServerError) // Deserialization failed - .attach_printable("Unable to deserialize merchant routing algorithm") - })) - .get_required_value("RoutingAlgorithm") - .change_context(errors::ApiErrorResponse::PreconditionFailed { - message: "no routing algorithm has been configured".to_string(), - })??; + if let Some(storage::PaymentRoutingInfo { + pre_routing_results: Some(pre_routing_results), + .. + }) = routing_info + { + let mut payment_methods: rustc_hash::FxHashMap< + (String, enums::PaymentMethodType), + api::SessionConnectorData, + > = rustc_hash::FxHashMap::from_iter(connectors.iter().map(|c| { + ( + ( + c.connector.connector_name.to_string(), + c.payment_method_type, + ), + c.clone(), + ) + })); - let mut routing_data = storage::RoutingData { - routed_through: payment_data.payment_attempt.connector.clone(), - algorithm: Some(routing_algorithm), - }; + let mut final_list: Vec = Vec::new(); + for (routed_pm_type, choice) in pre_routing_results.into_iter() { + if let Some(session_connector_data) = + payment_methods.remove(&(choice.to_string(), routed_pm_type)) + { + final_list.push(session_connector_data); + } + } - let (decided_connector, connector_id) = decide_connector(state, &mut routing_data)?; + if !final_list.is_empty() { + return Ok(final_list); + } + } - let encoded_algorithm = routing_data - .algorithm - .map(|algo| Encode::::encode_to_value(&algo)) - .transpose() + let routing_enabled_pms = std::collections::HashSet::from([ + enums::PaymentMethodType::GooglePay, + enums::PaymentMethodType::ApplePay, + enums::PaymentMethodType::Klarna, + enums::PaymentMethodType::Paypal, + ]); + + let mut chosen = Vec::::new(); + for connector_data in &connectors { + if routing_enabled_pms.contains(&connector_data.payment_method_type) { + chosen.push(connector_data.clone()); + } + } + let sfr = SessionFlowRoutingInput { + state: &state, + country: payment_data + .address + .billing + .as_ref() + .and_then(|address| address.address.as_ref()) + .and_then(|details| details.country), + key_store, + merchant_account, + payment_attempt: &payment_data.payment_attempt, + payment_intent: &payment_data.payment_intent, + + chosen, + }; + let result = self_routing::perform_session_flow_routing(sfr) + .await .change_context(errors::ApiErrorResponse::InternalServerError) - .attach_printable("Unable to serialize routing algorithm to serde value")?; + .attach_printable("error performing session flow routing")?; + + let mut final_list: Vec = Vec::new(); + + #[cfg(not(feature = "connector_choice_mca_id"))] + for mut connector_data in connectors { + if !routing_enabled_pms.contains(&connector_data.payment_method_type) { + final_list.push(connector_data); + } else if let Some(choice) = result.get(&connector_data.payment_method_type) { + if connector_data.connector.connector_name == choice.connector.connector_name { + connector_data.business_sub_label = choice.sub_label.clone(); + final_list.push(connector_data); + } + } + } - payment_data.payment_attempt.connector = routing_data.routed_through; - payment_data.payment_attempt.straight_through_algorithm = encoded_algorithm; - payment_data.payment_attempt.merchant_connector_id = connector_id; + #[cfg(feature = "connector_choice_mca_id")] + for connector_data in connectors { + if !routing_enabled_pms.contains(&connector_data.payment_method_type) { + final_list.push(connector_data); + } else if let Some(choice) = result.get(&connector_data.payment_method_type) { + if connector_data.connector.connector_name == choice.connector.connector_name { + final_list.push(connector_data); + } + } + } - Ok(decided_connector) + Ok(final_list) } -pub fn decide_connector( +pub async fn route_connector_v1( state: &AppState, + merchant_account: &domain::MerchantAccount, + key_store: &domain::MerchantKeyStore, + payment_data: &mut PaymentData, routing_data: &mut storage::RoutingData, -) -> RouterResult<(api::ConnectorCallType, Option)> { - let routing_algorithm = routing_data - .algorithm + eligible_connectors: Option>, +) -> RouterResult +where + F: Send + Clone, +{ + #[cfg(not(feature = "business_profile_routing"))] + let algorithm_ref: api::routing::RoutingAlgorithmRef = merchant_account + .routing_algorithm .clone() - .get_required_value("Routing algorithm")?; + .map(|ra| ra.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Could not decode merchant routing algorithm ref")? + .unwrap_or_default(); - let (connector_name, merchant_connector_id) = match routing_algorithm { - api::StraightThroughAlgorithm::Single(routable_connector_choice) => { - match routable_connector_choice { - api_models::admin::RoutableConnectorChoice::ConnectorName(routable_connector) => { - (routable_connector.to_string(), None) - } - api_models::admin::RoutableConnectorChoice::ConnectorId { - merchant_connector_id, - connector, - } => (connector.to_string(), Some(merchant_connector_id)), - } - } + #[cfg(feature = "business_profile_routing")] + let algorithm_ref: api::routing::RoutingAlgorithmRef = { + let profile_id = payment_data + .payment_intent + .profile_id + .as_ref() + .get_required_value("profile_id") + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("'profile_id' not set in payment intent")?; + + let business_profile = state + .store + .find_business_profile_by_profile_id(profile_id) + .await + .to_not_found_response(errors::ApiErrorResponse::BusinessProfileNotFound { + id: profile_id.to_string(), + })?; + + business_profile + .routing_algorithm + .clone() + .map(|ra| ra.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Could not decode merchant routing algorithm ref")? + .unwrap_or_default() }; - let connector_data = api::ConnectorData::get_connector_by_name( - &state.conf.connectors, - &connector_name, - api::GetToken::Connector, - merchant_connector_id.clone(), + let connectors = routing::perform_static_routing_v1( + state, + &merchant_account.merchant_id, + algorithm_ref, + payment_data, + ) + .await + .change_context(errors::ApiErrorResponse::InternalServerError)?; + + let connectors = routing::perform_eligibility_analysis_with_fallback( + &state.clone(), + key_store, + merchant_account.modified_at.assume_utc().unix_timestamp(), + connectors, + payment_data, + eligible_connectors, + #[cfg(feature = "business_profile_routing")] + payment_data.payment_intent.profile_id.clone(), ) + .await .change_context(errors::ApiErrorResponse::InternalServerError) - .attach_printable("Invalid connector name received in routing algorithm")?; + .attach_printable("failed eligibility analysis and fallback")?; - routing_data.routed_through = Some(connector_name); - Ok(( - api::ConnectorCallType::Single(connector_data), - merchant_connector_id, - )) -} + let first_connector_choice = connectors + .first() + .ok_or(errors::ApiErrorResponse::IncorrectPaymentMethodConfiguration) + .into_report() + .attach_printable("Empty connector list returned")? + .clone(); -pub fn should_add_task_to_process_tracker(payment_data: &PaymentData) -> bool { - let connector = payment_data.payment_attempt.connector.as_deref(); + routing_data.routed_through = Some(first_connector_choice.connector.to_string()); - !matches!( - (payment_data.payment_attempt.payment_method, connector), - ( - Some(storage_enums::PaymentMethod::BankTransfer), - Some("stripe") - ) - ) + #[cfg(feature = "connector_choice_mca_id")] + { + routing_data.merchant_connector_id = first_connector_choice.merchant_connector_id; + } + #[cfg(not(feature = "connector_choice_mca_id"))] + { + routing_data.business_sub_label = first_connector_choice.sub_label; + } + + let connector_data = connectors + .into_iter() + .map(|conn| { + api::ConnectorData::get_connector_by_name( + &state.conf.connectors, + &conn.connector.to_string(), + api::GetToken::Connector, + #[cfg(feature = "connector_choice_mca_id")] + conn.merchant_connector_id, + #[cfg(not(feature = "connector_choice_mca_id"))] + None, + ) + }) + .collect::, _>>() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Invalid connector name received")?; + + Ok(ConnectorCallType::Retryable(connector_data)) } diff --git a/crates/router/src/core/payments/routing.rs b/crates/router/src/core/payments/routing.rs new file mode 100644 index 000000000000..4134ddf65ea0 --- /dev/null +++ b/crates/router/src/core/payments/routing.rs @@ -0,0 +1,950 @@ +mod transformers; + +use std::{ + collections::hash_map, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use api_models::{ + admin as admin_api, + enums::{self as api_enums, CountryAlpha2}, + routing::ConnectorSelection, +}; +use common_utils::static_cache::StaticCache; +use diesel_models::enums as storage_enums; +use error_stack::{IntoReport, ResultExt}; +use euclid::{ + backend::{self, inputs as dsl_inputs, EuclidBackend}, + dssa::graph::{self as euclid_graph, Memoization}, + enums as euclid_enums, + frontend::ast, +}; +use kgraph_utils::{ + mca as mca_graph, + transformers::{IntoContext, IntoDirValue}, +}; +use masking::PeekInterface; +use rand::{ + distributions::{self, Distribution}, + SeedableRng, +}; +use rustc_hash::FxHashMap; + +#[cfg(not(feature = "business_profile_routing"))] +use crate::utils::StringExt; +use crate::{ + core::{ + errors as oss_errors, errors, payments as payments_oss, routing::helpers as routing_helpers, + }, + logger, + types::{ + api, api::routing as routing_types, domain, storage as oss_storage, + transformers::ForeignInto, + }, + utils::{OptionExt, ValueExt}, + AppState, +}; + +pub(super) enum CachedAlgorithm { + Single(Box), + Priority(Vec), + VolumeSplit(Vec), + Advanced(backend::VirInterpreterBackend), +} + +pub struct SessionFlowRoutingInput<'a> { + pub state: &'a AppState, + pub country: Option, + pub key_store: &'a domain::MerchantKeyStore, + pub merchant_account: &'a domain::MerchantAccount, + pub payment_attempt: &'a oss_storage::PaymentAttempt, + pub payment_intent: &'a oss_storage::PaymentIntent, + pub chosen: Vec, +} + +pub struct SessionRoutingPmTypeInput<'a> { + state: &'a AppState, + key_store: &'a domain::MerchantKeyStore, + merchant_last_modified: i64, + attempt_id: &'a str, + routing_algorithm: &'a MerchantAccountRoutingAlgorithm, + backend_input: dsl_inputs::BackendInput, + allowed_connectors: FxHashMap, + #[cfg(feature = "business_profile_routing")] + profile_id: Option, +} +static ROUTING_CACHE: StaticCache = StaticCache::new(); +static KGRAPH_CACHE: StaticCache> = StaticCache::new(); + +type RoutingResult = oss_errors::CustomResult; + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +enum MerchantAccountRoutingAlgorithm { + V1(routing_types::RoutingAlgorithmRef), +} + +impl Default for MerchantAccountRoutingAlgorithm { + fn default() -> Self { + Self::V1(routing_types::RoutingAlgorithmRef::default()) + } +} + +pub fn make_dsl_input( + payment_data: &payments_oss::PaymentData, +) -> RoutingResult +where + F: Clone, +{ + let mandate_data = dsl_inputs::MandateData { + mandate_acceptance_type: payment_data + .setup_mandate + .as_ref() + .and_then(|mandate_data| { + mandate_data + .customer_acceptance + .clone() + .map(|cat| match cat.acceptance_type { + data_models::mandates::AcceptanceType::Online => { + euclid_enums::MandateAcceptanceType::Online + } + data_models::mandates::AcceptanceType::Offline => { + euclid_enums::MandateAcceptanceType::Offline + } + }) + }), + mandate_type: payment_data + .setup_mandate + .as_ref() + .and_then(|mandate_data| { + mandate_data.mandate_type.clone().map(|mt| match mt { + data_models::mandates::MandateDataType::SingleUse(_) => { + euclid_enums::MandateType::SingleUse + } + data_models::mandates::MandateDataType::MultiUse(_) => { + euclid_enums::MandateType::MultiUse + } + }) + }), + payment_type: Some(payment_data.setup_mandate.clone().map_or_else( + || euclid_enums::PaymentType::NonMandate, + |_| euclid_enums::PaymentType::SetupMandate, + )), + }; + let payment_method_input = dsl_inputs::PaymentMethodInput { + payment_method: payment_data.payment_attempt.payment_method, + payment_method_type: payment_data.payment_attempt.payment_method_type, + card_network: payment_data + .payment_method_data + .as_ref() + .and_then(|pm_data| match pm_data { + api::PaymentMethodData::Card(card) => card.card_network.clone(), + + _ => None, + }), + }; + + let payment_input = dsl_inputs::PaymentInput { + amount: payment_data.payment_intent.amount, + card_bin: payment_data + .payment_method_data + .as_ref() + .and_then(|pm_data| match pm_data { + api::PaymentMethodData::Card(card) => { + Some(card.card_number.peek().chars().take(6).collect()) + } + _ => None, + }), + currency: payment_data.currency, + authentication_type: payment_data.payment_attempt.authentication_type, + capture_method: payment_data + .payment_attempt + .capture_method + .and_then(|cm| cm.foreign_into()), + business_country: payment_data + .payment_intent + .business_country + .map(api_enums::Country::from_alpha2), + billing_country: payment_data + .address + .billing + .as_ref() + .and_then(|bic| bic.address.as_ref()) + .and_then(|add| add.country) + .map(api_enums::Country::from_alpha2), + business_label: payment_data.payment_intent.business_label.clone(), + setup_future_usage: payment_data.payment_intent.setup_future_usage, + }; + + let metadata = payment_data + .payment_intent + .metadata + .clone() + .map(|val| val.parse_value("routing_parameters")) + .transpose() + .change_context(errors::RoutingError::MetadataParsingError) + .attach_printable("Unable to parse routing_parameters from metadata of payment_intent") + .unwrap_or_else(|err| { + logger::error!(error=?err); + None + }); + + Ok(dsl_inputs::BackendInput { + metadata, + payment: payment_input, + payment_method: payment_method_input, + mandate: mandate_data, + }) +} + +pub async fn perform_static_routing_v1( + state: &AppState, + merchant_id: &str, + algorithm_ref: routing_types::RoutingAlgorithmRef, + payment_data: &mut payments_oss::PaymentData, +) -> RoutingResult> { + let algorithm_id = if let Some(id) = algorithm_ref.algorithm_id { + id + } else { + let fallback_config = + routing_helpers::get_merchant_default_config(&*state.clone().store, merchant_id) + .await + .change_context(errors::RoutingError::FallbackConfigFetchFailed)?; + + return Ok(fallback_config); + }; + let key = ensure_algorithm_cached_v1( + state, + merchant_id, + algorithm_ref.timestamp, + &algorithm_id, + #[cfg(feature = "business_profile_routing")] + payment_data.payment_intent.profile_id.clone(), + ) + .await?; + let cached_algorithm: Arc = ROUTING_CACHE + .retrieve(&key) + .into_report() + .change_context(errors::RoutingError::CacheMiss) + .attach_printable("Unable to retrieve cached routing algorithm even after refresh")?; + + Ok(match cached_algorithm.as_ref() { + CachedAlgorithm::Single(conn) => vec![(**conn).clone()], + + CachedAlgorithm::Priority(plist) => plist.clone(), + + CachedAlgorithm::VolumeSplit(splits) => perform_volume_split(splits.to_vec(), None) + .change_context(errors::RoutingError::ConnectorSelectionFailed)?, + + CachedAlgorithm::Advanced(interpreter) => { + let backend_input = make_dsl_input(payment_data)?; + + execute_dsl_and_get_connector_v1(backend_input, interpreter)? + } + }) +} + +async fn ensure_algorithm_cached_v1( + state: &AppState, + merchant_id: &str, + timestamp: i64, + algorithm_id: &str, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult { + #[cfg(feature = "business_profile_routing")] + let key = { + let profile_id = profile_id + .clone() + .get_required_value("profile_id") + .change_context(errors::RoutingError::ProfileIdMissing)?; + + format!("routing_config_{merchant_id}_{profile_id}") + }; + + #[cfg(not(feature = "business_profile_routing"))] + let key = format!("dsl_{merchant_id}"); + + let present = ROUTING_CACHE + .present(&key) + .into_report() + .change_context(errors::RoutingError::DslCachePoisoned) + .attach_printable("Error checking presence of DSL")?; + + let expired = ROUTING_CACHE + .expired(&key, timestamp) + .into_report() + .change_context(errors::RoutingError::DslCachePoisoned) + .attach_printable("Error checking expiry of DSL in cache")?; + + if !present || expired { + refresh_routing_cache_v1( + state, + key.clone(), + algorithm_id, + timestamp, + #[cfg(feature = "business_profile_routing")] + profile_id, + ) + .await?; + }; + + Ok(key) +} + +pub fn perform_straight_through_routing( + algorithm: &routing_types::StraightThroughAlgorithm, + payment_data: &payments_oss::PaymentData, +) -> RoutingResult<(Vec, bool)> { + Ok(match algorithm { + routing_types::StraightThroughAlgorithm::Single(conn) => ( + vec![(**conn).clone()], + payment_data.creds_identifier.is_none(), + ), + + routing_types::StraightThroughAlgorithm::Priority(conns) => (conns.clone(), true), + + routing_types::StraightThroughAlgorithm::VolumeSplit(splits) => ( + perform_volume_split(splits.to_vec(), None) + .change_context(errors::RoutingError::ConnectorSelectionFailed) + .attach_printable( + "Volume Split connector selection error in straight through routing", + )?, + true, + ), + }) +} + +fn execute_dsl_and_get_connector_v1( + backend_input: dsl_inputs::BackendInput, + interpreter: &backend::VirInterpreterBackend, +) -> RoutingResult> { + let routing_output: routing_types::RoutingAlgorithm = interpreter + .execute(backend_input) + .map(|out| out.connector_selection.foreign_into()) + .into_report() + .change_context(errors::RoutingError::DslExecutionError)?; + + Ok(match routing_output { + routing_types::RoutingAlgorithm::Priority(plist) => plist, + + routing_types::RoutingAlgorithm::VolumeSplit(splits) => perform_volume_split(splits, None) + .change_context(errors::RoutingError::DslFinalConnectorSelectionFailed)?, + + _ => Err(errors::RoutingError::DslIncorrectSelectionAlgorithm) + .into_report() + .attach_printable("Unsupported algorithm received as a result of static routing")?, + }) +} + +pub async fn refresh_routing_cache_v1( + state: &AppState, + key: String, + algorithm_id: &str, + timestamp: i64, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult<()> { + #[cfg(feature = "business_profile_routing")] + let algorithm = { + let algorithm = state + .store + .find_routing_algorithm_by_profile_id_algorithm_id( + &profile_id.unwrap_or_default(), + algorithm_id, + ) + .await + .change_context(errors::RoutingError::DslMissingInDb)?; + let algorithm: routing_types::RoutingAlgorithm = algorithm + .algorithm_data + .parse_value("RoutingAlgorithm") + .change_context(errors::RoutingError::DslParsingError)?; + algorithm + }; + + #[cfg(not(feature = "business_profile_routing"))] + let algorithm = { + let config = state + .store + .find_config_by_key(algorithm_id) + .await + .change_context(errors::RoutingError::DslMissingInDb) + .attach_printable("DSL not found in DB")?; + + let algorithm: routing_types::RoutingAlgorithm = config + .config + .parse_struct("Program") + .change_context(errors::RoutingError::DslParsingError) + .attach_printable("Error parsing routing algorithm from configs")?; + algorithm + }; + let cached_algorithm = match algorithm { + routing_types::RoutingAlgorithm::Single(conn) => CachedAlgorithm::Single(conn), + routing_types::RoutingAlgorithm::Priority(plist) => CachedAlgorithm::Priority(plist), + routing_types::RoutingAlgorithm::VolumeSplit(splits) => { + CachedAlgorithm::VolumeSplit(splits) + } + routing_types::RoutingAlgorithm::Advanced(program) => { + let interpreter = backend::VirInterpreterBackend::with_program(program) + .into_report() + .change_context(errors::RoutingError::DslBackendInitError) + .attach_printable("Error initializing DSL interpreter backend")?; + + CachedAlgorithm::Advanced(interpreter) + } + }; + + ROUTING_CACHE + .save(key, cached_algorithm, timestamp) + .into_report() + .change_context(errors::RoutingError::DslCachePoisoned) + .attach_printable("Error saving DSL to cache")?; + + Ok(()) +} + +pub fn perform_volume_split( + mut splits: Vec, + rng_seed: Option<&str>, +) -> RoutingResult> { + let weights: Vec = splits.iter().map(|sp| sp.split).collect(); + let weighted_index = distributions::WeightedIndex::new(weights) + .into_report() + .change_context(errors::RoutingError::VolumeSplitFailed) + .attach_printable("Error creating weighted distribution for volume split")?; + + let idx = if let Some(seed) = rng_seed { + let mut hasher = hash_map::DefaultHasher::new(); + seed.hash(&mut hasher); + let hash = hasher.finish(); + + let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(hash); + weighted_index.sample(&mut rng) + } else { + let mut rng = rand::thread_rng(); + weighted_index.sample(&mut rng) + }; + + splits + .get(idx) + .ok_or(errors::RoutingError::VolumeSplitFailed) + .into_report() + .attach_printable("Volume split index lookup failed")?; + + // Panic Safety: We have performed a `get(idx)` operation just above which will + // ensure that the index is always present, else throw an error. + let removed = splits.remove(idx); + splits.insert(0, removed); + + Ok(splits.into_iter().map(|sp| sp.connector).collect()) +} + +pub async fn get_merchant_kgraph<'a>( + state: &AppState, + key_store: &domain::MerchantKeyStore, + merchant_last_modified: i64, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult>> { + #[cfg(feature = "business_profile_routing")] + let key = { + let profile_id = profile_id + .clone() + .get_required_value("profile_id") + .change_context(errors::RoutingError::ProfileIdMissing)?; + + format!("kgraph_{}_{profile_id}", key_store.merchant_id) + }; + + #[cfg(not(feature = "business_profile_routing"))] + let key = format!("kgraph_{}", key_store.merchant_id); + + let kgraph_present = KGRAPH_CACHE + .present(&key) + .into_report() + .change_context(errors::RoutingError::KgraphCacheFailure) + .attach_printable("when checking kgraph presence")?; + + let kgraph_expired = KGRAPH_CACHE + .expired(&key, merchant_last_modified) + .into_report() + .change_context(errors::RoutingError::KgraphCacheFailure) + .attach_printable("when checking kgraph expiry")?; + + if !kgraph_present || kgraph_expired { + refresh_kgraph_cache( + state, + key_store, + merchant_last_modified, + key.clone(), + #[cfg(feature = "business_profile_routing")] + profile_id, + ) + .await?; + } + + let cached_kgraph = KGRAPH_CACHE + .retrieve(&key) + .into_report() + .change_context(errors::RoutingError::CacheMiss) + .attach_printable("when retrieving kgraph")?; + + Ok(cached_kgraph) +} + +pub async fn refresh_kgraph_cache( + state: &AppState, + key_store: &domain::MerchantKeyStore, + timestamp: i64, + key: String, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult<()> { + let mut merchant_connector_accounts = state + .store + .find_merchant_connector_account_by_merchant_id_and_disabled_list( + &key_store.merchant_id, + false, + key_store, + ) + .await + .change_context(errors::RoutingError::KgraphCacheRefreshFailed)?; + + merchant_connector_accounts + .retain(|mca| mca.connector_type != storage_enums::ConnectorType::PaymentVas); + + #[cfg(feature = "business_profile_routing")] + let merchant_connector_accounts = payments_oss::helpers::filter_mca_based_on_business_profile( + merchant_connector_accounts, + profile_id, + ); + + let api_mcas: Vec = merchant_connector_accounts + .into_iter() + .map(|acct| acct.try_into()) + .collect::>() + .change_context(errors::RoutingError::KgraphCacheRefreshFailed)?; + + let kgraph = mca_graph::make_mca_graph(api_mcas) + .into_report() + .change_context(errors::RoutingError::KgraphCacheRefreshFailed) + .attach_printable("when construction kgraph")?; + + KGRAPH_CACHE + .save(key, kgraph, timestamp) + .into_report() + .change_context(errors::RoutingError::KgraphCacheRefreshFailed) + .attach_printable("when saving kgraph to cache")?; + + Ok(()) +} + +async fn perform_kgraph_filtering( + state: &AppState, + key_store: &domain::MerchantKeyStore, + merchant_last_modified: i64, + chosen: Vec, + backend_input: dsl_inputs::BackendInput, + eligible_connectors: Option<&Vec>, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult> { + let context = euclid_graph::AnalysisContext::from_dir_values( + backend_input + .into_context() + .into_report() + .change_context(errors::RoutingError::KgraphAnalysisError)?, + ); + let cached_kgraph = get_merchant_kgraph( + state, + key_store, + merchant_last_modified, + #[cfg(feature = "business_profile_routing")] + profile_id, + ) + .await?; + + let mut final_selection = Vec::::new(); + for choice in chosen { + let routable_connector = choice.connector; + let euclid_choice: ast::ConnectorChoice = choice.clone().foreign_into(); + let dir_val = euclid_choice + .into_dir_value() + .into_report() + .change_context(errors::RoutingError::KgraphAnalysisError)?; + let kgraph_eligible = cached_kgraph + .check_value_validity(dir_val, &context, &mut Memoization::new()) + .into_report() + .change_context(errors::RoutingError::KgraphAnalysisError)?; + + let filter_eligible = + eligible_connectors.map_or(true, |list| list.contains(&routable_connector)); + + if kgraph_eligible && filter_eligible { + final_selection.push(choice); + } + } + + Ok(final_selection) +} + +pub async fn perform_eligibility_analysis( + state: &AppState, + key_store: &domain::MerchantKeyStore, + merchant_last_modified: i64, + chosen: Vec, + payment_data: &payments_oss::PaymentData, + eligible_connectors: Option<&Vec>, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult> { + let backend_input = make_dsl_input(payment_data)?; + + perform_kgraph_filtering( + state, + key_store, + merchant_last_modified, + chosen, + backend_input, + eligible_connectors, + #[cfg(feature = "business_profile_routing")] + profile_id, + ) + .await +} + +pub async fn perform_fallback_routing( + state: &AppState, + key_store: &domain::MerchantKeyStore, + merchant_last_modified: i64, + payment_data: &payments_oss::PaymentData, + eligible_connectors: Option<&Vec>, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult> { + let fallback_config = + routing_helpers::get_merchant_default_config(&*state.store, &key_store.merchant_id) + .await + .change_context(errors::RoutingError::FallbackConfigFetchFailed)?; + let backend_input = make_dsl_input(payment_data)?; + + perform_kgraph_filtering( + state, + key_store, + merchant_last_modified, + fallback_config, + backend_input, + eligible_connectors, + #[cfg(feature = "business_profile_routing")] + profile_id, + ) + .await +} + +pub async fn perform_eligibility_analysis_with_fallback( + state: &AppState, + key_store: &domain::MerchantKeyStore, + merchant_last_modified: i64, + chosen: Vec, + payment_data: &payments_oss::PaymentData, + eligible_connectors: Option>, + #[cfg(feature = "business_profile_routing")] profile_id: Option, +) -> RoutingResult> { + let mut final_selection = perform_eligibility_analysis( + state, + key_store, + merchant_last_modified, + chosen, + payment_data, + eligible_connectors.as_ref(), + #[cfg(feature = "business_profile_routing")] + profile_id.clone(), + ) + .await?; + + let fallback_selection = perform_fallback_routing( + state, + key_store, + merchant_last_modified, + payment_data, + eligible_connectors.as_ref(), + #[cfg(feature = "business_profile_routing")] + profile_id, + ) + .await; + + final_selection.append( + &mut fallback_selection + .unwrap_or_default() + .iter() + .filter(|&routable_connector_choice| { + !final_selection.contains(routable_connector_choice) + }) + .cloned() + .collect::>(), + ); + + let final_selected_connectors = final_selection + .iter() + .map(|item| item.connector) + .collect::>(); + logger::debug!(final_selected_connectors_for_routing=?final_selected_connectors, "List of final selected connectors for routing"); + + Ok(final_selection) +} + +pub async fn perform_session_flow_routing( + session_input: SessionFlowRoutingInput<'_>, +) -> RoutingResult> { + let mut pm_type_map: FxHashMap> = + FxHashMap::default(); + let merchant_last_modified = session_input + .merchant_account + .modified_at + .assume_utc() + .unix_timestamp(); + + #[cfg(feature = "business_profile_routing")] + let routing_algorithm: MerchantAccountRoutingAlgorithm = { + let profile_id = session_input + .payment_intent + .profile_id + .clone() + .get_required_value("profile_id") + .change_context(errors::RoutingError::ProfileIdMissing)?; + + let business_profile = session_input + .state + .store + .find_business_profile_by_profile_id(&profile_id) + .await + .change_context(errors::RoutingError::ProfileNotFound)?; + + business_profile + .routing_algorithm + .clone() + .map(|val| val.parse_value("MerchantAccountRoutingAlgorithm")) + .transpose() + .change_context(errors::RoutingError::InvalidRoutingAlgorithmStructure)? + .unwrap_or_default() + }; + + #[cfg(not(feature = "business_profile_routing"))] + let routing_algorithm: MerchantAccountRoutingAlgorithm = { + session_input + .merchant_account + .routing_algorithm + .clone() + .map(|val| val.parse_value("MerchantAccountRoutingAlgorithm")) + .transpose() + .change_context(errors::RoutingError::InvalidRoutingAlgorithmStructure)? + .unwrap_or_default() + }; + + let payment_method_input = dsl_inputs::PaymentMethodInput { + payment_method: None, + payment_method_type: None, + card_network: None, + }; + + let payment_input = dsl_inputs::PaymentInput { + amount: session_input.payment_intent.amount, + currency: session_input + .payment_intent + .currency + .get_required_value("Currency") + .change_context(errors::RoutingError::DslMissingRequiredField { + field_name: "currency".to_string(), + })?, + authentication_type: session_input.payment_attempt.authentication_type, + card_bin: None, + capture_method: session_input + .payment_attempt + .capture_method + .and_then(|cm| cm.foreign_into()), + business_country: session_input + .payment_intent + .business_country + .map(api_enums::Country::from_alpha2), + billing_country: session_input + .country + .map(storage_enums::Country::from_alpha2), + business_label: session_input.payment_intent.business_label.clone(), + setup_future_usage: session_input.payment_intent.setup_future_usage, + }; + + let metadata = session_input + .payment_intent + .metadata + .clone() + .map(|val| val.parse_value("routing_parameters")) + .transpose() + .change_context(errors::RoutingError::MetadataParsingError) + .attach_printable("Unable to parse routing_parameters from metadata of payment_intent") + .unwrap_or_else(|err| { + logger::error!(?err); + None + }); + + let mut backend_input = dsl_inputs::BackendInput { + metadata, + payment: payment_input, + payment_method: payment_method_input, + mandate: dsl_inputs::MandateData { + mandate_acceptance_type: None, + mandate_type: None, + payment_type: None, + }, + }; + + for connector_data in session_input.chosen.iter() { + pm_type_map + .entry(connector_data.payment_method_type) + .or_default() + .insert( + connector_data.connector.connector_name.to_string(), + connector_data.connector.get_token.clone(), + ); + } + + let mut result: FxHashMap = + FxHashMap::default(); + + for (pm_type, allowed_connectors) in pm_type_map { + let euclid_pmt: euclid_enums::PaymentMethodType = pm_type; + let euclid_pm: euclid_enums::PaymentMethod = euclid_pmt.into(); + + backend_input.payment_method.payment_method = Some(euclid_pm); + backend_input.payment_method.payment_method_type = Some(euclid_pmt); + + let session_pm_input = SessionRoutingPmTypeInput { + state: session_input.state, + key_store: session_input.key_store, + merchant_last_modified, + attempt_id: &session_input.payment_attempt.attempt_id, + routing_algorithm: &routing_algorithm, + backend_input: backend_input.clone(), + allowed_connectors, + #[cfg(feature = "business_profile_routing")] + profile_id: session_input.payment_intent.clone().profile_id, + }; + let maybe_choice = perform_session_routing_for_pm_type(session_pm_input).await?; + + // (connector, sub_label) + if let Some(data) = maybe_choice { + result.insert( + pm_type, + routing_types::SessionRoutingChoice { + connector: data.0, + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: data.1, + payment_method_type: pm_type, + }, + ); + } + } + + Ok(result) +} + +async fn perform_session_routing_for_pm_type( + session_pm_input: SessionRoutingPmTypeInput<'_>, +) -> RoutingResult)>> { + let merchant_id = &session_pm_input.key_store.merchant_id; + + let chosen_connectors = match session_pm_input.routing_algorithm { + MerchantAccountRoutingAlgorithm::V1(algorithm_ref) => { + if let Some(ref algorithm_id) = algorithm_ref.algorithm_id { + let key = ensure_algorithm_cached_v1( + &session_pm_input.state.clone(), + merchant_id, + algorithm_ref.timestamp, + algorithm_id, + #[cfg(feature = "business_profile_routing")] + session_pm_input.profile_id.clone(), + ) + .await?; + + let cached_algorithm = ROUTING_CACHE + .retrieve(&key) + .into_report() + .change_context(errors::RoutingError::CacheMiss) + .attach_printable("unable to retrieve cached routing algorithm")?; + + match cached_algorithm.as_ref() { + CachedAlgorithm::Single(conn) => vec![(**conn).clone()], + CachedAlgorithm::Priority(plist) => plist.clone(), + CachedAlgorithm::VolumeSplit(splits) => { + perform_volume_split(splits.to_vec(), Some(session_pm_input.attempt_id)) + .change_context(errors::RoutingError::ConnectorSelectionFailed)? + } + CachedAlgorithm::Advanced(interpreter) => execute_dsl_and_get_connector_v1( + session_pm_input.backend_input.clone(), + interpreter, + )?, + } + } else { + routing_helpers::get_merchant_default_config( + &*session_pm_input.state.clone().store, + merchant_id, + ) + .await + .change_context(errors::RoutingError::FallbackConfigFetchFailed)? + } + } + }; + + let mut final_selection = perform_kgraph_filtering( + &session_pm_input.state.clone(), + session_pm_input.key_store, + session_pm_input.merchant_last_modified, + chosen_connectors, + session_pm_input.backend_input.clone(), + None, + #[cfg(feature = "business_profile_routing")] + session_pm_input.profile_id.clone(), + ) + .await?; + + if final_selection.is_empty() { + let fallback = routing_helpers::get_merchant_default_config( + &*session_pm_input.state.clone().store, + merchant_id, + ) + .await + .change_context(errors::RoutingError::FallbackConfigFetchFailed)?; + + final_selection = perform_kgraph_filtering( + &session_pm_input.state.clone(), + session_pm_input.key_store, + session_pm_input.merchant_last_modified, + fallback, + session_pm_input.backend_input, + None, + #[cfg(feature = "business_profile_routing")] + session_pm_input.profile_id.clone(), + ) + .await?; + } + + let mut final_choice: Option<(api::ConnectorData, Option)> = None; + + for selection in final_selection { + let connector_name = selection.connector.to_string(); + if let Some(get_token) = session_pm_input.allowed_connectors.get(&connector_name) { + let connector_data = api::ConnectorData::get_connector_by_name( + &session_pm_input.state.clone().conf.connectors, + &connector_name, + get_token.clone(), + #[cfg(feature = "connector_choice_mca_id")] + selection.merchant_connector_id, + #[cfg(not(feature = "connector_choice_mca_id"))] + None, + ) + .change_context(errors::RoutingError::InvalidConnectorName(connector_name))?; + #[cfg(not(feature = "connector_choice_mca_id"))] + let sub_label = selection.sub_label; + #[cfg(feature = "connector_choice_mca_id")] + let sub_label = None; + + final_choice = Some((connector_data, sub_label)); + break; + } + } + + Ok(final_choice) +} diff --git a/crates/router/src/core/payments/routing/transformers.rs b/crates/router/src/core/payments/routing/transformers.rs new file mode 100644 index 000000000000..de94a36248ff --- /dev/null +++ b/crates/router/src/core/payments/routing/transformers.rs @@ -0,0 +1,121 @@ +use api_models::{self, enums as api_enums, routing as routing_types}; +use diesel_models::enums as storage_enums; +use euclid::{enums as dsl_enums, frontend::ast as dsl_ast}; + +use crate::types::transformers::{ForeignFrom, ForeignInto}; + +impl ForeignFrom for dsl_ast::ConnectorChoice { + fn foreign_from(from: routing_types::RoutableConnectorChoice) -> Self { + Self { + // #[cfg(feature = "backwards_compatibility")] + // choice_kind: from.choice_kind.foreign_into(), + connector: from.connector.foreign_into(), + #[cfg(not(feature = "connector_choice_mca_id"))] + sub_label: from.sub_label, + } + } +} + +impl ForeignFrom for Option { + fn foreign_from(value: storage_enums::CaptureMethod) -> Self { + match value { + storage_enums::CaptureMethod::Automatic => Some(dsl_enums::CaptureMethod::Automatic), + storage_enums::CaptureMethod::Manual => Some(dsl_enums::CaptureMethod::Manual), + _ => None, + } + } +} + +impl ForeignFrom for dsl_enums::MandateAcceptanceType { + fn foreign_from(from: api_models::payments::AcceptanceType) -> Self { + match from { + api_models::payments::AcceptanceType::Online => Self::Online, + api_models::payments::AcceptanceType::Offline => Self::Offline, + } + } +} + +impl ForeignFrom for dsl_enums::MandateType { + fn foreign_from(from: api_models::payments::MandateType) -> Self { + match from { + api_models::payments::MandateType::MultiUse(_) => Self::MultiUse, + api_models::payments::MandateType::SingleUse(_) => Self::SingleUse, + } + } +} + +impl ForeignFrom for dsl_enums::MandateType { + fn foreign_from(from: storage_enums::MandateDataType) -> Self { + match from { + storage_enums::MandateDataType::MultiUse(_) => Self::MultiUse, + storage_enums::MandateDataType::SingleUse(_) => Self::SingleUse, + } + } +} + +impl ForeignFrom for dsl_enums::Connector { + fn foreign_from(from: api_enums::RoutableConnectors) -> Self { + match from { + #[cfg(feature = "dummy_connector")] + api_enums::RoutableConnectors::DummyConnector1 => Self::DummyConnector1, + #[cfg(feature = "dummy_connector")] + api_enums::RoutableConnectors::DummyConnector2 => Self::DummyConnector2, + #[cfg(feature = "dummy_connector")] + api_enums::RoutableConnectors::DummyConnector3 => Self::DummyConnector3, + #[cfg(feature = "dummy_connector")] + api_enums::RoutableConnectors::DummyConnector4 => Self::DummyConnector4, + #[cfg(feature = "dummy_connector")] + api_enums::RoutableConnectors::DummyConnector5 => Self::DummyConnector5, + #[cfg(feature = "dummy_connector")] + api_enums::RoutableConnectors::DummyConnector6 => Self::DummyConnector6, + #[cfg(feature = "dummy_connector")] + api_enums::RoutableConnectors::DummyConnector7 => Self::DummyConnector7, + api_enums::RoutableConnectors::Aci => Self::Aci, + api_enums::RoutableConnectors::Adyen => Self::Adyen, + api_enums::RoutableConnectors::Airwallex => Self::Airwallex, + api_enums::RoutableConnectors::Authorizedotnet => Self::Authorizedotnet, + api_enums::RoutableConnectors::Bitpay => Self::Bitpay, + api_enums::RoutableConnectors::Bambora => Self::Bambora, + api_enums::RoutableConnectors::Bluesnap => Self::Bluesnap, + api_enums::RoutableConnectors::Boku => Self::Boku, + api_enums::RoutableConnectors::Braintree => Self::Braintree, + api_enums::RoutableConnectors::Cashtocode => Self::Cashtocode, + api_enums::RoutableConnectors::Checkout => Self::Checkout, + api_enums::RoutableConnectors::Coinbase => Self::Coinbase, + api_enums::RoutableConnectors::Cryptopay => Self::Cryptopay, + api_enums::RoutableConnectors::Cybersource => Self::Cybersource, + api_enums::RoutableConnectors::Dlocal => Self::Dlocal, + api_enums::RoutableConnectors::Fiserv => Self::Fiserv, + api_enums::RoutableConnectors::Forte => Self::Forte, + api_enums::RoutableConnectors::Globalpay => Self::Globalpay, + api_enums::RoutableConnectors::Globepay => Self::Globepay, + api_enums::RoutableConnectors::Gocardless => Self::Gocardless, + api_enums::RoutableConnectors::Helcim => Self::Helcim, + api_enums::RoutableConnectors::Iatapay => Self::Iatapay, + api_enums::RoutableConnectors::Klarna => Self::Klarna, + api_enums::RoutableConnectors::Mollie => Self::Mollie, + api_enums::RoutableConnectors::Multisafepay => Self::Multisafepay, + api_enums::RoutableConnectors::Nexinets => Self::Nexinets, + api_enums::RoutableConnectors::Nmi => Self::Nmi, + api_enums::RoutableConnectors::Noon => Self::Noon, + api_enums::RoutableConnectors::Nuvei => Self::Nuvei, + api_enums::RoutableConnectors::Opennode => Self::Opennode, + api_enums::RoutableConnectors::Payme => Self::Payme, + api_enums::RoutableConnectors::Paypal => Self::Paypal, + api_enums::RoutableConnectors::Payu => Self::Payu, + api_enums::RoutableConnectors::Powertranz => Self::Powertranz, + api_enums::RoutableConnectors::Rapyd => Self::Rapyd, + api_enums::RoutableConnectors::Shift4 => Self::Shift4, + api_enums::RoutableConnectors::Square => Self::Square, + api_enums::RoutableConnectors::Stax => Self::Stax, + api_enums::RoutableConnectors::Stripe => Self::Stripe, + api_enums::RoutableConnectors::Trustpay => Self::Trustpay, + api_enums::RoutableConnectors::Tsys => Self::Tsys, + api_enums::RoutableConnectors::Volt => Self::Volt, + api_enums::RoutableConnectors::Wise => Self::Wise, + api_enums::RoutableConnectors::Worldline => Self::Worldline, + api_enums::RoutableConnectors::Worldpay => Self::Worldpay, + api_enums::RoutableConnectors::Zen => Self::Zen, + } + } +} diff --git a/crates/router/src/core/routing.rs b/crates/router/src/core/routing.rs new file mode 100644 index 000000000000..8033cc792b54 --- /dev/null +++ b/crates/router/src/core/routing.rs @@ -0,0 +1,713 @@ +pub mod helpers; +pub mod transformers; + +use api_models::routing as routing_types; +#[cfg(feature = "business_profile_routing")] +use api_models::routing::{RoutingRetrieveLinkQuery, RoutingRetrieveQuery}; +#[cfg(not(feature = "business_profile_routing"))] +use common_utils::ext_traits::{Encode, StringExt}; +#[cfg(not(feature = "business_profile_routing"))] +use diesel_models::configs; +#[cfg(feature = "business_profile_routing")] +use diesel_models::routing_algorithm::RoutingAlgorithm; +use error_stack::{IntoReport, ResultExt}; +use rustc_hash::FxHashSet; + +#[cfg(feature = "business_profile_routing")] +use crate::core::utils::validate_and_get_business_profile; +#[cfg(feature = "business_profile_routing")] +use crate::types::transformers::{ForeignInto, ForeignTryInto}; +use crate::{ + consts, + core::errors::{RouterResponse, StorageErrorExt}, + routes::AppState, + types::domain, + utils::{self, OptionExt, ValueExt}, +}; +#[cfg(not(feature = "business_profile_routing"))] +use crate::{core::errors, services::api as service_api, types::storage}; +#[cfg(feature = "business_profile_routing")] +use crate::{errors, services::api as service_api}; + +pub async fn retrieve_merchant_routing_dictionary( + state: AppState, + merchant_account: domain::MerchantAccount, + #[cfg(feature = "business_profile_routing")] query_params: RoutingRetrieveQuery, +) -> RouterResponse { + #[cfg(feature = "business_profile_routing")] + { + let routing_metadata = state + .store + .list_routing_algorithm_metadata_by_merchant_id( + &merchant_account.merchant_id, + i64::from(query_params.limit.unwrap_or_default()), + i64::from(query_params.offset.unwrap_or_default()), + ) + .await + .to_not_found_response(errors::ApiErrorResponse::ResourceIdNotFound)?; + let result = routing_metadata + .into_iter() + .map(ForeignInto::foreign_into) + .collect::>(); + + Ok(service_api::ApplicationResponse::Json( + routing_types::RoutingKind::RoutingAlgorithm(result), + )) + } + #[cfg(not(feature = "business_profile_routing"))] + Ok(service_api::ApplicationResponse::Json( + routing_types::RoutingKind::Config( + helpers::get_merchant_routing_dictionary( + state.store.as_ref(), + &merchant_account.merchant_id, + ) + .await?, + ), + )) +} + +pub async fn create_routing_config( + state: AppState, + merchant_account: domain::MerchantAccount, + key_store: domain::MerchantKeyStore, + request: routing_types::RoutingConfigRequest, +) -> RouterResponse { + let db = state.store.as_ref(); + + let name = request + .name + .get_required_value("name") + .change_context(errors::ApiErrorResponse::MissingRequiredField { field_name: "name" }) + .attach_printable("Name of config not given")?; + + let description = request + .description + .get_required_value("description") + .change_context(errors::ApiErrorResponse::MissingRequiredField { + field_name: "description", + }) + .attach_printable("Description of config not given")?; + + let algorithm = request + .algorithm + .get_required_value("algorithm") + .change_context(errors::ApiErrorResponse::MissingRequiredField { + field_name: "algorithm", + }) + .attach_printable("Algorithm of config not given")?; + + let algorithm_id = common_utils::generate_id( + consts::ROUTING_CONFIG_ID_LENGTH, + &format!("routing_{}", &merchant_account.merchant_id), + ); + + #[cfg(feature = "business_profile_routing")] + { + let profile_id = request + .profile_id + .get_required_value("profile_id") + .change_context(errors::ApiErrorResponse::MissingRequiredField { + field_name: "profile_id", + }) + .attach_printable("Profile_id not provided")?; + + validate_and_get_business_profile(db, Some(&profile_id), &merchant_account.merchant_id) + .await?; + + helpers::validate_connectors_in_routing_config( + db, + &key_store, + &merchant_account.merchant_id, + &profile_id, + &algorithm, + ) + .await?; + + let timestamp = common_utils::date_time::now(); + let algo = RoutingAlgorithm { + algorithm_id: algorithm_id.clone(), + profile_id, + merchant_id: merchant_account.merchant_id, + name: name.clone(), + description: Some(description.clone()), + kind: algorithm.get_kind().foreign_into(), + algorithm_data: serde_json::json!(algorithm), + created_at: timestamp, + modified_at: timestamp, + }; + let record = db + .insert_routing_algorithm(algo) + .await + .to_not_found_response(errors::ApiErrorResponse::ResourceIdNotFound)?; + + let new_record = record.foreign_into(); + + Ok(service_api::ApplicationResponse::Json(new_record)) + } + + #[cfg(not(feature = "business_profile_routing"))] + { + let algorithm_str = + utils::Encode::::encode_to_string_of_json(&algorithm) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Unable to serialize routing algorithm to string")?; + + let mut algorithm_ref: routing_types::RoutingAlgorithmRef = merchant_account + .routing_algorithm + .clone() + .map(|val| val.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("unable to deserialize routing algorithm ref from merchant account")? + .unwrap_or_default(); + let mut merchant_dictionary = + helpers::get_merchant_routing_dictionary(db, &merchant_account.merchant_id).await?; + + utils::when( + merchant_dictionary.records.len() >= consts::MAX_ROUTING_CONFIGS_PER_MERCHANT, + || { + Err(errors::ApiErrorResponse::PreconditionFailed { + message: format!("Reached the maximum number of routing configs ({}), please delete some to create new ones", consts::MAX_ROUTING_CONFIGS_PER_MERCHANT), + }) + .into_report() + }, + )?; + let timestamp = common_utils::date_time::now_unix_timestamp(); + let records_are_empty = merchant_dictionary.records.is_empty(); + + let new_record = routing_types::RoutingDictionaryRecord { + id: algorithm_id.clone(), + name: name.clone(), + kind: algorithm.get_kind(), + description: description.clone(), + created_at: timestamp, + modified_at: timestamp, + }; + merchant_dictionary.records.push(new_record.clone()); + + let new_algorithm_config = configs::ConfigNew { + key: algorithm_id.clone(), + config: algorithm_str, + }; + + db.insert_config(new_algorithm_config) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed to save new routing algorithm config to DB")?; + + if records_are_empty { + merchant_dictionary.active_id = Some(algorithm_id.clone()); + algorithm_ref.update_algorithm_id(algorithm_id); + helpers::update_merchant_active_algorithm_ref(db, &key_store, algorithm_ref).await?; + } + + helpers::update_merchant_routing_dictionary( + db, + &merchant_account.merchant_id, + merchant_dictionary, + ) + .await?; + + Ok(service_api::ApplicationResponse::Json(new_record)) + } +} + +pub async fn link_routing_config( + state: AppState, + merchant_account: domain::MerchantAccount, + #[cfg(not(feature = "business_profile_routing"))] key_store: domain::MerchantKeyStore, + algorithm_id: String, +) -> RouterResponse { + let db = state.store.as_ref(); + #[cfg(feature = "business_profile_routing")] + { + let routing_algorithm = db + .find_routing_algorithm_by_algorithm_id_merchant_id( + &algorithm_id, + &merchant_account.merchant_id, + ) + .await + .change_context(errors::ApiErrorResponse::ResourceIdNotFound)?; + + let business_profile = validate_and_get_business_profile( + db, + Some(&routing_algorithm.profile_id), + &merchant_account.merchant_id, + ) + .await? + .get_required_value("BusinessProfile") + .change_context(errors::ApiErrorResponse::BusinessProfileNotFound { + id: routing_algorithm.profile_id.clone(), + })?; + + let mut routing_ref: routing_types::RoutingAlgorithmRef = business_profile + .routing_algorithm + .clone() + .map(|val| val.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("unable to deserialize routing algorithm ref from merchant account")? + .unwrap_or_default(); + + utils::when( + routing_ref.algorithm_id == Some(algorithm_id.clone()), + || { + Err(errors::ApiErrorResponse::PreconditionFailed { + message: "Algorithm is already active".to_string(), + }) + .into_report() + }, + )?; + + routing_ref.update_algorithm_id(algorithm_id); + helpers::update_business_profile_active_algorithm_ref(db, business_profile, routing_ref) + .await?; + + Ok(service_api::ApplicationResponse::Json( + routing_algorithm.foreign_into(), + )) + } + + #[cfg(not(feature = "business_profile_routing"))] + { + let mut routing_ref: routing_types::RoutingAlgorithmRef = merchant_account + .routing_algorithm + .clone() + .map(|val| val.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("unable to deserialize routing algorithm ref from merchant account")? + .unwrap_or_default(); + + utils::when( + routing_ref.algorithm_id == Some(algorithm_id.clone()), + || { + Err(errors::ApiErrorResponse::PreconditionFailed { + message: "Algorithm is already active".to_string(), + }) + .into_report() + }, + )?; + let mut merchant_dictionary = + helpers::get_merchant_routing_dictionary(db, &merchant_account.merchant_id).await?; + + let modified_at = common_utils::date_time::now_unix_timestamp(); + let record = merchant_dictionary + .records + .iter_mut() + .find(|rec| rec.id == algorithm_id) + .ok_or(errors::ApiErrorResponse::ResourceIdNotFound) + .into_report() + .attach_printable("Record with given ID not found for routing config activation")?; + + record.modified_at = modified_at; + merchant_dictionary.active_id = Some(record.id.clone()); + let response = record.clone(); + routing_ref.update_algorithm_id(algorithm_id); + helpers::update_merchant_routing_dictionary( + db, + &merchant_account.merchant_id, + merchant_dictionary, + ) + .await?; + helpers::update_merchant_active_algorithm_ref(db, &key_store, routing_ref).await?; + + Ok(service_api::ApplicationResponse::Json(response)) + } +} + +pub async fn retrieve_routing_config( + state: AppState, + merchant_account: domain::MerchantAccount, + algorithm_id: String, +) -> RouterResponse { + let db = state.store.as_ref(); + #[cfg(feature = "business_profile_routing")] + { + let routing_algorithm = db + .find_routing_algorithm_by_algorithm_id_merchant_id( + &algorithm_id, + &merchant_account.merchant_id, + ) + .await + .to_not_found_response(errors::ApiErrorResponse::ResourceIdNotFound)?; + + validate_and_get_business_profile( + db, + Some(&routing_algorithm.profile_id), + &merchant_account.merchant_id, + ) + .await? + .get_required_value("BusinessProfile") + .change_context(errors::ApiErrorResponse::ResourceIdNotFound)?; + + let response = routing_algorithm + .foreign_try_into() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("unable to parse routing algorithm")?; + Ok(service_api::ApplicationResponse::Json(response)) + } + + #[cfg(not(feature = "business_profile_routing"))] + { + let merchant_dictionary = + helpers::get_merchant_routing_dictionary(db, &merchant_account.merchant_id).await?; + + let record = merchant_dictionary + .records + .into_iter() + .find(|rec| rec.id == algorithm_id) + .ok_or(errors::ApiErrorResponse::ResourceIdNotFound) + .into_report() + .attach_printable("Algorithm with the given ID not found in the merchant dictionary")?; + + let algorithm_config = db + .find_config_by_key(&algorithm_id) + .await + .change_context(errors::ApiErrorResponse::ResourceIdNotFound) + .attach_printable("Routing config not found in DB")?; + + let algorithm: routing_types::RoutingAlgorithm = algorithm_config + .config + .parse_struct("RoutingAlgorithm") + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error deserializing routing algorithm config")?; + + let response = routing_types::MerchantRoutingAlgorithm { + id: record.id, + name: record.name, + description: record.description, + algorithm, + created_at: record.created_at, + modified_at: record.modified_at, + }; + + Ok(service_api::ApplicationResponse::Json(response)) + } +} +pub async fn unlink_routing_config( + state: AppState, + merchant_account: domain::MerchantAccount, + #[cfg(not(feature = "business_profile_routing"))] key_store: domain::MerchantKeyStore, + #[cfg(feature = "business_profile_routing")] request: routing_types::RoutingConfigRequest, +) -> RouterResponse { + let db = state.store.as_ref(); + #[cfg(feature = "business_profile_routing")] + { + let profile_id = request + .profile_id + .get_required_value("profile_id") + .change_context(errors::ApiErrorResponse::MissingRequiredField { + field_name: "profile_id", + }) + .attach_printable("Profile_id not provided")?; + let business_profile = + validate_and_get_business_profile(db, Some(&profile_id), &merchant_account.merchant_id) + .await?; + match business_profile { + Some(business_profile) => { + let routing_algo_ref: routing_types::RoutingAlgorithmRef = business_profile + .routing_algorithm + .clone() + .map(|val| val.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable( + "unable to deserialize routing algorithm ref from merchant account", + )? + .unwrap_or_default(); + + let timestamp = common_utils::date_time::now_unix_timestamp(); + + match routing_algo_ref.algorithm_id { + Some(algorithm_id) => { + let routing_algorithm: routing_types::RoutingAlgorithmRef = + routing_types::RoutingAlgorithmRef { + algorithm_id: None, + timestamp, + config_algo_id: routing_algo_ref.config_algo_id.clone(), + surcharge_config_algo_id: routing_algo_ref.surcharge_config_algo_id, + }; + + let record = db + .find_routing_algorithm_by_profile_id_algorithm_id( + &profile_id, + &algorithm_id, + ) + .await + .to_not_found_response(errors::ApiErrorResponse::ResourceIdNotFound)?; + let response = record.foreign_into(); + helpers::update_business_profile_active_algorithm_ref( + db, + business_profile, + routing_algorithm, + ) + .await?; + Ok(service_api::ApplicationResponse::Json(response)) + } + None => Err(errors::ApiErrorResponse::PreconditionFailed { + message: "Algorithm is already inactive".to_string(), + }) + .into_report()?, + } + } + None => Err(errors::ApiErrorResponse::InvalidRequestData { + message: "The business_profile is not present".to_string(), + } + .into()), + } + } + + #[cfg(not(feature = "business_profile_routing"))] + { + let mut merchant_dictionary = + helpers::get_merchant_routing_dictionary(db, &merchant_account.merchant_id).await?; + + let routing_algo_ref: routing_types::RoutingAlgorithmRef = merchant_account + .routing_algorithm + .clone() + .map(|val| val.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("unable to deserialize routing algorithm ref from merchant account")? + .unwrap_or_default(); + let timestamp = common_utils::date_time::now_unix_timestamp(); + + utils::when(routing_algo_ref.algorithm_id.is_none(), || { + Err(errors::ApiErrorResponse::PreconditionFailed { + message: "Algorithm is already inactive".to_string(), + }) + .into_report() + })?; + let routing_algorithm: routing_types::RoutingAlgorithmRef = + routing_types::RoutingAlgorithmRef { + algorithm_id: None, + timestamp, + config_algo_id: routing_algo_ref.config_algo_id.clone(), + surcharge_config_algo_id: routing_algo_ref.surcharge_config_algo_id, + }; + + let active_algorithm_id = merchant_dictionary + .active_id + .or(routing_algo_ref.algorithm_id.clone()) + .ok_or(errors::ApiErrorResponse::PreconditionFailed { + // When the merchant_dictionary doesn't have any active algorithm and merchant_account doesn't have any routing_algorithm configured + message: "Algorithm is already inactive".to_string(), + }) + .into_report()?; + + let record = merchant_dictionary + .records + .iter_mut() + .find(|rec| rec.id == active_algorithm_id) + .ok_or(errors::ApiErrorResponse::ResourceIdNotFound) + .into_report() + .attach_printable("Record with the given ID not found for de-activation")?; + + let response = record.clone(); + + merchant_dictionary.active_id = None; + + helpers::update_merchant_routing_dictionary( + db, + &merchant_account.merchant_id, + merchant_dictionary, + ) + .await?; + + let ref_value = + Encode::::encode_to_value(&routing_algorithm) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed converting routing algorithm ref to json value")?; + + let merchant_account_update = storage::MerchantAccountUpdate::Update { + merchant_name: None, + merchant_details: None, + return_url: None, + webhook_details: None, + sub_merchants_enabled: None, + parent_merchant_id: None, + enable_payment_response_hash: None, + payment_response_hash_key: None, + redirect_to_merchant_with_http_post: None, + publishable_key: None, + locker_id: None, + metadata: None, + routing_algorithm: Some(ref_value), + primary_business_details: None, + intent_fulfillment_time: None, + frm_routing_algorithm: None, + payout_routing_algorithm: None, + default_profile: None, + payment_link_config: None, + }; + + db.update_specific_fields_in_merchant( + &key_store.merchant_id, + merchant_account_update, + &key_store, + ) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed to update routing algorithm ref in merchant account")?; + + Ok(service_api::ApplicationResponse::Json(response)) + } +} + +pub async fn update_default_routing_config( + state: AppState, + merchant_account: domain::MerchantAccount, + updated_config: Vec, +) -> RouterResponse> { + let db = state.store.as_ref(); + let default_config = + helpers::get_merchant_default_config(db, &merchant_account.merchant_id).await?; + + utils::when(default_config.len() != updated_config.len(), || { + Err(errors::ApiErrorResponse::PreconditionFailed { + message: "current config and updated config have different lengths".to_string(), + }) + .into_report() + })?; + + let existing_set: FxHashSet = + FxHashSet::from_iter(default_config.iter().map(|c| c.to_string())); + let updated_set: FxHashSet = + FxHashSet::from_iter(updated_config.iter().map(|c| c.to_string())); + + let symmetric_diff: Vec = existing_set + .symmetric_difference(&updated_set) + .cloned() + .collect(); + + utils::when(!symmetric_diff.is_empty(), || { + Err(errors::ApiErrorResponse::InvalidRequestData { + message: format!( + "connector mismatch between old and new configs ({})", + symmetric_diff.join(", ") + ), + }) + .into_report() + })?; + + helpers::update_merchant_default_config( + db, + &merchant_account.merchant_id, + updated_config.clone(), + ) + .await?; + + Ok(service_api::ApplicationResponse::Json(updated_config)) +} + +pub async fn retrieve_default_routing_config( + state: AppState, + merchant_account: domain::MerchantAccount, +) -> RouterResponse> { + let db = state.store.as_ref(); + + helpers::get_merchant_default_config(db, &merchant_account.merchant_id) + .await + .map(service_api::ApplicationResponse::Json) +} + +pub async fn retrieve_linked_routing_config( + state: AppState, + merchant_account: domain::MerchantAccount, + #[cfg(feature = "business_profile_routing")] query_params: RoutingRetrieveLinkQuery, +) -> RouterResponse { + let db = state.store.as_ref(); + + #[cfg(feature = "business_profile_routing")] + { + let business_profiles = if let Some(profile_id) = query_params.profile_id { + validate_and_get_business_profile(db, Some(&profile_id), &merchant_account.merchant_id) + .await? + .map(|profile| vec![profile]) + .get_required_value("BusinessProfile") + .change_context(errors::ApiErrorResponse::BusinessProfileNotFound { + id: profile_id, + })? + } else { + db.list_business_profile_by_merchant_id(&merchant_account.merchant_id) + .await + .to_not_found_response(errors::ApiErrorResponse::ResourceIdNotFound)? + }; + + let mut active_algorithms = Vec::new(); + + for business_profile in business_profiles { + let routing_ref: routing_types::RoutingAlgorithmRef = business_profile + .routing_algorithm + .clone() + .map(|val| val.parse_value("RoutingAlgorithmRef")) + .transpose() + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable( + "unable to deserialize routing algorithm ref from merchant account", + )? + .unwrap_or_default(); + + if let Some(algorithm_id) = routing_ref.algorithm_id { + let record = db + .find_routing_algorithm_metadata_by_algorithm_id_profile_id( + &algorithm_id, + &business_profile.profile_id, + ) + .await + .to_not_found_response(errors::ApiErrorResponse::ResourceIdNotFound)?; + + active_algorithms.push(record.foreign_into()); + } + } + + Ok(service_api::ApplicationResponse::Json( + routing_types::LinkedRoutingConfigRetrieveResponse::ProfileBased(active_algorithms), + )) + } + #[cfg(not(feature = "business_profile_routing"))] + { + let merchant_dictionary = + helpers::get_merchant_routing_dictionary(db, &merchant_account.merchant_id).await?; + + let algorithm = if let Some(algorithm_id) = merchant_dictionary.active_id { + let record = merchant_dictionary + .records + .into_iter() + .find(|rec| rec.id == algorithm_id) + .ok_or(errors::ApiErrorResponse::ResourceIdNotFound) + .into_report() + .attach_printable("record for active algorithm not found in merchant dictionary")?; + + let config = db + .find_config_by_key(&algorithm_id) + .await + .to_not_found_response(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error finding routing config in db")?; + + let the_algorithm: routing_types::RoutingAlgorithm = config + .config + .parse_struct("RoutingAlgorithm") + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("unable to parse routing algorithm")?; + + Some(routing_types::MerchantRoutingAlgorithm { + id: record.id, + name: record.name, + description: record.description, + algorithm: the_algorithm, + created_at: record.created_at, + modified_at: record.modified_at, + }) + } else { + None + }; + + let response = routing_types::LinkedRoutingConfigRetrieveResponse::MerchantAccountBased( + routing_types::RoutingRetrieveResponse { algorithm }, + ); + + Ok(service_api::ApplicationResponse::Json(response)) + } +} diff --git a/crates/router/src/core/routing/helpers.rs b/crates/router/src/core/routing/helpers.rs new file mode 100644 index 000000000000..6eec39f53bc6 --- /dev/null +++ b/crates/router/src/core/routing/helpers.rs @@ -0,0 +1,479 @@ +//! Analysis for usage of all helper functions for use case of routing +//! +//! Functions that are used to perform the retrieval of merchant's +//! routing dict, configs, defaults +use api_models::routing as routing_types; +use common_utils::ext_traits::Encode; +use diesel_models::{ + business_profile::{BusinessProfile, BusinessProfileUpdateInternal}, + configs, +}; +use error_stack::ResultExt; +use rustc_hash::FxHashSet; + +use crate::{ + core::errors::{self, RouterResult}, + db::StorageInterface, + types::{domain, storage}, + utils::{self, StringExt}, +}; + +/// provides the complete merchant routing dictionary that is basically a list of all the routing +/// configs a merchant configured with an active_id field that specifies the current active routing +/// config +pub async fn get_merchant_routing_dictionary( + db: &dyn StorageInterface, + merchant_id: &str, +) -> RouterResult { + let key = get_routing_dictionary_key(merchant_id); + let maybe_dict = db.find_config_by_key(&key).await; + + match maybe_dict { + Ok(config) => config + .config + .parse_struct("RoutingDictionary") + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Merchant routing dictionary has invalid structure"), + + Err(e) if e.current_context().is_db_not_found() => { + let new_dictionary = routing_types::RoutingDictionary { + merchant_id: merchant_id.to_string(), + active_id: None, + records: Vec::new(), + }; + + let serialized = + utils::Encode::::encode_to_string_of_json( + &new_dictionary, + ) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error serializing newly created merchant dictionary")?; + + let new_config = configs::ConfigNew { + key, + config: serialized, + }; + + db.insert_config(new_config) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error inserting new routing dictionary for merchant")?; + + Ok(new_dictionary) + } + + Err(e) => Err(e) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error fetching routing dictionary for merchant"), + } +} + +/// Provides us with all the configured configs of the Merchant in the ascending time configured +/// manner and chooses the first of them +pub async fn get_merchant_default_config( + db: &dyn StorageInterface, + merchant_id: &str, +) -> RouterResult> { + let key = get_default_config_key(merchant_id); + let maybe_config = db.find_config_by_key(&key).await; + + match maybe_config { + Ok(config) => config + .config + .parse_struct("Vec") + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Merchant default config has invalid structure"), + + Err(e) if e.current_context().is_db_not_found() => { + let new_config_conns = Vec::::new(); + let serialized = + utils::Encode::>::encode_to_string_of_json( + &new_config_conns, + ) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable( + "Error while creating and serializing new merchant default config", + )?; + + let new_config = configs::ConfigNew { + key, + config: serialized, + }; + + db.insert_config(new_config) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error inserting new default routing config into DB")?; + + Ok(new_config_conns) + } + + Err(e) => Err(e) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error fetching default config for merchant"), + } +} + +/// Merchant's already created config can be updated and this change will be reflected +/// in DB as well for the particular updated config +pub async fn update_merchant_default_config( + db: &dyn StorageInterface, + merchant_id: &str, + connectors: Vec, +) -> RouterResult<()> { + let key = get_default_config_key(merchant_id); + let config_str = + Encode::>::encode_to_string_of_json( + &connectors, + ) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Unable to serialize merchant default routing config during update")?; + + let config_update = configs::ConfigUpdate::Update { + config: Some(config_str), + }; + + db.update_config_by_key(&key, config_update) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error updating the default routing config in DB")?; + + Ok(()) +} + +pub async fn update_merchant_routing_dictionary( + db: &dyn StorageInterface, + merchant_id: &str, + dictionary: routing_types::RoutingDictionary, +) -> RouterResult<()> { + let key = get_routing_dictionary_key(merchant_id); + let dictionary_str = + Encode::::encode_to_string_of_json(&dictionary) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Unable to serialize routing dictionary during update")?; + + let config_update = configs::ConfigUpdate::Update { + config: Some(dictionary_str), + }; + + db.update_config_by_key(&key, config_update) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error saving routing dictionary to DB")?; + + Ok(()) +} + +pub async fn update_routing_algorithm( + db: &dyn StorageInterface, + algorithm_id: String, + algorithm: routing_types::RoutingAlgorithm, +) -> RouterResult<()> { + let algorithm_str = + Encode::::encode_to_string_of_json(&algorithm) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Unable to serialize routing algorithm to string")?; + + let config_update = configs::ConfigUpdate::Update { + config: Some(algorithm_str), + }; + + db.update_config_by_key(&algorithm_id, config_update) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Error updating the routing algorithm in DB")?; + + Ok(()) +} + +/// This will help make one of all configured algorithms to be in active state for a particular +/// merchant +pub async fn update_merchant_active_algorithm_ref( + db: &dyn StorageInterface, + key_store: &domain::MerchantKeyStore, + algorithm_id: routing_types::RoutingAlgorithmRef, +) -> RouterResult<()> { + let ref_value = Encode::::encode_to_value(&algorithm_id) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed converting routing algorithm ref to json value")?; + + let merchant_account_update = storage::MerchantAccountUpdate::Update { + merchant_name: None, + merchant_details: None, + return_url: None, + webhook_details: None, + sub_merchants_enabled: None, + parent_merchant_id: None, + enable_payment_response_hash: None, + payment_response_hash_key: None, + redirect_to_merchant_with_http_post: None, + publishable_key: None, + locker_id: None, + metadata: None, + routing_algorithm: Some(ref_value), + primary_business_details: None, + intent_fulfillment_time: None, + frm_routing_algorithm: None, + payout_routing_algorithm: None, + default_profile: None, + payment_link_config: None, + }; + + db.update_specific_fields_in_merchant( + &key_store.merchant_id, + merchant_account_update, + key_store, + ) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed to update routing algorithm ref in merchant account")?; + + Ok(()) +} + +pub async fn update_business_profile_active_algorithm_ref( + db: &dyn StorageInterface, + current_business_profile: BusinessProfile, + algorithm_id: routing_types::RoutingAlgorithmRef, +) -> RouterResult<()> { + let ref_val = Encode::::encode_to_value(&algorithm_id) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed to convert routing ref to value")?; + + let business_profile_update = BusinessProfileUpdateInternal { + profile_name: None, + return_url: None, + enable_payment_response_hash: None, + payment_response_hash_key: None, + redirect_to_merchant_with_http_post: None, + webhook_details: None, + metadata: None, + routing_algorithm: Some(ref_val), + intent_fulfillment_time: None, + frm_routing_algorithm: None, + payout_routing_algorithm: None, + applepay_verified_domains: None, + modified_at: None, + is_recon_enabled: None, + }; + db.update_business_profile_by_profile_id(current_business_profile, business_profile_update) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("Failed to update routing algorithm ref in business profile")?; + Ok(()) +} + +pub async fn get_merchant_connector_agnostic_mandate_config( + db: &dyn StorageInterface, + merchant_id: &str, +) -> RouterResult> { + let key = get_pg_agnostic_mandate_config_key(merchant_id); + let maybe_config = db.find_config_by_key(&key).await; + + match maybe_config { + Ok(config) => config + .config + .parse_struct("Vec") + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("pg agnostic mandate config has invalid structure"), + + Err(e) if e.current_context().is_db_not_found() => { + let new_mandate_config: Vec = Vec::new(); + + let serialized = + utils::Encode::>::encode_to_string_of_json( + &new_mandate_config, + ) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error serializing newly created pg agnostic mandate config")?; + + let new_config = configs::ConfigNew { + key, + config: serialized, + }; + + db.insert_config(new_config) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error inserting new pg agnostic mandate config in db")?; + + Ok(new_mandate_config) + } + + Err(e) => Err(e) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error fetching pg agnostic mandate config for merchant from db"), + } +} + +pub async fn update_merchant_connector_agnostic_mandate_config( + db: &dyn StorageInterface, + merchant_id: &str, + mandate_config: Vec, +) -> RouterResult> { + let key = get_pg_agnostic_mandate_config_key(merchant_id); + let mandate_config_str = + Encode::>::encode_to_string_of_json( + &mandate_config, + ) + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("unable to serialize pg agnostic mandate config during update")?; + + let config_update = configs::ConfigUpdate::Update { + config: Some(mandate_config_str), + }; + + db.update_config_by_key(&key, config_update) + .await + .change_context(errors::ApiErrorResponse::InternalServerError) + .attach_printable("error saving pg agnostic mandate config to db")?; + + Ok(mandate_config) +} + +pub async fn validate_connectors_in_routing_config( + db: &dyn StorageInterface, + key_store: &domain::MerchantKeyStore, + merchant_id: &str, + profile_id: &str, + routing_algorithm: &routing_types::RoutingAlgorithm, +) -> RouterResult<()> { + let all_mcas = db + .find_merchant_connector_account_by_merchant_id_and_disabled_list( + merchant_id, + true, + key_store, + ) + .await + .change_context(errors::ApiErrorResponse::MerchantConnectorAccountNotFound { + id: merchant_id.to_string(), + })?; + + #[cfg(feature = "connector_choice_mca_id")] + let name_mca_id_set = all_mcas + .iter() + .filter(|mca| mca.profile_id.as_deref() == Some(profile_id)) + .map(|mca| (&mca.connector_name, &mca.merchant_connector_id)) + .collect::>(); + + let name_set = all_mcas + .iter() + .filter(|mca| mca.profile_id.as_deref() == Some(profile_id)) + .map(|mca| &mca.connector_name) + .collect::>(); + + #[cfg(feature = "connector_choice_mca_id")] + let check_connector_choice = |choice: &routing_types::RoutableConnectorChoice| { + if let Some(ref mca_id) = choice.merchant_connector_id { + error_stack::ensure!( + name_mca_id_set.contains(&(&choice.connector.to_string(), mca_id)), + errors::ApiErrorResponse::InvalidRequestData { + message: format!( + "connector with name '{}' and merchant connector account id '{}' not found for the given profile", + choice.connector, + mca_id, + ) + } + ); + } else { + error_stack::ensure!( + name_set.contains(&choice.connector.to_string()), + errors::ApiErrorResponse::InvalidRequestData { + message: format!( + "connector with name '{}' not found for the given profile", + choice.connector, + ) + } + ); + } + + Ok(()) + }; + + #[cfg(not(feature = "connector_choice_mca_id"))] + let check_connector_choice = |choice: &routing_types::RoutableConnectorChoice| { + error_stack::ensure!( + name_set.contains(&choice.connector.to_string()), + errors::ApiErrorResponse::InvalidRequestData { + message: format!( + "connector with name '{}' not found for the given profile", + choice.connector, + ) + } + ); + + Ok(()) + }; + + match routing_algorithm { + routing_types::RoutingAlgorithm::Single(choice) => { + check_connector_choice(choice)?; + } + + routing_types::RoutingAlgorithm::Priority(list) => { + for choice in list { + check_connector_choice(choice)?; + } + } + + routing_types::RoutingAlgorithm::VolumeSplit(splits) => { + for split in splits { + check_connector_choice(&split.connector)?; + } + } + + routing_types::RoutingAlgorithm::Advanced(program) => { + let check_connector_selection = + |selection: &routing_types::ConnectorSelection| -> RouterResult<()> { + match selection { + routing_types::ConnectorSelection::VolumeSplit(splits) => { + for split in splits { + check_connector_choice(&split.connector)?; + } + } + + routing_types::ConnectorSelection::Priority(list) => { + for choice in list { + check_connector_choice(choice)?; + } + } + } + + Ok(()) + }; + + check_connector_selection(&program.default_selection)?; + + for rule in &program.rules { + check_connector_selection(&rule.connector_selection)?; + } + } + } + + Ok(()) +} + +/// Provides the identifier for the specific merchant's routing_dictionary_key +#[inline(always)] +pub fn get_routing_dictionary_key(merchant_id: &str) -> String { + format!("routing_dict_{merchant_id}") +} + +/// Provides the identifier for the specific merchant's agnostic_mandate_config +#[inline(always)] +pub fn get_pg_agnostic_mandate_config_key(merchant_id: &str) -> String { + format!("pg_agnostic_mandate_{merchant_id}") +} + +/// Provides the identifier for the specific merchant's default_config +#[inline(always)] +pub fn get_default_config_key(merchant_id: &str) -> String { + format!("routing_default_{merchant_id}") +} +pub fn get_payment_config_routing_id(merchant_id: &str) -> String { + format!("payment_config_id_{merchant_id}") +} + +pub fn get_payment_method_surcharge_routing_id(merchant_id: &str) -> String { + format!("payment_method_surcharge_id_{merchant_id}") +} diff --git a/crates/router/src/core/routing/transformers.rs b/crates/router/src/core/routing/transformers.rs new file mode 100644 index 000000000000..e5f1f1e1d5f0 --- /dev/null +++ b/crates/router/src/core/routing/transformers.rs @@ -0,0 +1,86 @@ +use api_models::routing::{ + MerchantRoutingAlgorithm, RoutingAlgorithm as Algorithm, RoutingAlgorithmKind, + RoutingDictionaryRecord, +}; +use common_utils::ext_traits::ValueExt; +use diesel_models::{ + enums as storage_enums, + routing_algorithm::{RoutingAlgorithm, RoutingProfileMetadata}, +}; + +use crate::{ + core::errors, + types::transformers::{ForeignFrom, ForeignInto, ForeignTryFrom}, +}; + +impl ForeignFrom for RoutingDictionaryRecord { + fn foreign_from(value: RoutingProfileMetadata) -> Self { + Self { + id: value.algorithm_id, + #[cfg(feature = "business_profile_routing")] + profile_id: value.profile_id, + name: value.name, + + kind: value.kind.foreign_into(), + description: value.description.unwrap_or_default(), + created_at: value.created_at.assume_utc().unix_timestamp(), + modified_at: value.modified_at.assume_utc().unix_timestamp(), + } + } +} + +impl ForeignFrom for RoutingDictionaryRecord { + fn foreign_from(value: RoutingAlgorithm) -> Self { + Self { + id: value.algorithm_id, + #[cfg(feature = "business_profile_routing")] + profile_id: value.profile_id, + name: value.name, + kind: value.kind.foreign_into(), + description: value.description.unwrap_or_default(), + created_at: value.created_at.assume_utc().unix_timestamp(), + modified_at: value.modified_at.assume_utc().unix_timestamp(), + } + } +} + +impl ForeignTryFrom for MerchantRoutingAlgorithm { + type Error = error_stack::Report; + + fn foreign_try_from(value: RoutingAlgorithm) -> Result { + Ok(Self { + id: value.algorithm_id, + name: value.name, + #[cfg(feature = "business_profile_routing")] + profile_id: value.profile_id, + description: value.description.unwrap_or_default(), + algorithm: value + .algorithm_data + .parse_value::("RoutingAlgorithm")?, + created_at: value.created_at.assume_utc().unix_timestamp(), + modified_at: value.modified_at.assume_utc().unix_timestamp(), + }) + } +} + +impl ForeignFrom for RoutingAlgorithmKind { + fn foreign_from(value: storage_enums::RoutingAlgorithmKind) -> Self { + match value { + storage_enums::RoutingAlgorithmKind::Single => Self::Single, + storage_enums::RoutingAlgorithmKind::Priority => Self::Priority, + storage_enums::RoutingAlgorithmKind::VolumeSplit => Self::VolumeSplit, + storage_enums::RoutingAlgorithmKind::Advanced => Self::Advanced, + } + } +} + +impl ForeignFrom for storage_enums::RoutingAlgorithmKind { + fn foreign_from(value: RoutingAlgorithmKind) -> Self { + match value { + RoutingAlgorithmKind::Single => Self::Single, + RoutingAlgorithmKind::Priority => Self::Priority, + RoutingAlgorithmKind::VolumeSplit => Self::VolumeSplit, + RoutingAlgorithmKind::Advanced => Self::Advanced, + } + } +} diff --git a/crates/router/src/core/webhooks.rs b/crates/router/src/core/webhooks.rs index eb2e19081ff3..8b7df2a14be7 100644 --- a/crates/router/src/core/webhooks.rs +++ b/crates/router/src/core/webhooks.rs @@ -98,6 +98,7 @@ pub async fn payments_incoming_webhook_flow< }, services::AuthFlow::Merchant, consume_or_trigger_flow, + None, HeaderPayload::default(), ) .await; @@ -579,6 +580,7 @@ async fn bank_transfer_webhook_flow Box; diff --git a/crates/router/src/db/routing_algorithm.rs b/crates/router/src/db/routing_algorithm.rs new file mode 100644 index 000000000000..58550b2f01fa --- /dev/null +++ b/crates/router/src/db/routing_algorithm.rs @@ -0,0 +1,199 @@ +use diesel_models::routing_algorithm as routing_storage; +use error_stack::IntoReport; +use storage_impl::mock_db::MockDb; + +use crate::{ + connection, + core::errors::{self, CustomResult}, + services::Store, +}; + +type StorageResult = CustomResult; + +#[async_trait::async_trait] +pub trait RoutingAlgorithmInterface { + async fn insert_routing_algorithm( + &self, + routing_algorithm: routing_storage::RoutingAlgorithm, + ) -> StorageResult; + + async fn find_routing_algorithm_by_profile_id_algorithm_id( + &self, + profile_id: &str, + algorithm_id: &str, + ) -> StorageResult; + + async fn find_routing_algorithm_by_algorithm_id_merchant_id( + &self, + algorithm_id: &str, + merchant_id: &str, + ) -> StorageResult; + + async fn find_routing_algorithm_metadata_by_algorithm_id_profile_id( + &self, + algorithm_id: &str, + profile_id: &str, + ) -> StorageResult; + + async fn list_routing_algorithm_metadata_by_profile_id( + &self, + profile_id: &str, + limit: i64, + offset: i64, + ) -> StorageResult>; + + async fn list_routing_algorithm_metadata_by_merchant_id( + &self, + merchant_id: &str, + limit: i64, + offset: i64, + ) -> StorageResult>; +} + +#[async_trait::async_trait] +impl RoutingAlgorithmInterface for Store { + async fn insert_routing_algorithm( + &self, + routing_algorithm: routing_storage::RoutingAlgorithm, + ) -> StorageResult { + let conn = connection::pg_connection_write(self).await?; + routing_algorithm + .insert(&conn) + .await + .map_err(Into::into) + .into_report() + } + + async fn find_routing_algorithm_by_profile_id_algorithm_id( + &self, + profile_id: &str, + algorithm_id: &str, + ) -> StorageResult { + let conn = connection::pg_connection_write(self).await?; + routing_storage::RoutingAlgorithm::find_by_algorithm_id_profile_id( + &conn, + algorithm_id, + profile_id, + ) + .await + .map_err(Into::into) + .into_report() + } + + async fn find_routing_algorithm_by_algorithm_id_merchant_id( + &self, + algorithm_id: &str, + merchant_id: &str, + ) -> StorageResult { + let conn = connection::pg_connection_write(self).await?; + routing_storage::RoutingAlgorithm::find_by_algorithm_id_merchant_id( + &conn, + algorithm_id, + merchant_id, + ) + .await + .map_err(Into::into) + .into_report() + } + + async fn find_routing_algorithm_metadata_by_algorithm_id_profile_id( + &self, + algorithm_id: &str, + profile_id: &str, + ) -> StorageResult { + let conn = connection::pg_connection_write(self).await?; + routing_storage::RoutingAlgorithm::find_metadata_by_algorithm_id_profile_id( + &conn, + algorithm_id, + profile_id, + ) + .await + .map_err(Into::into) + .into_report() + } + + async fn list_routing_algorithm_metadata_by_profile_id( + &self, + profile_id: &str, + limit: i64, + offset: i64, + ) -> StorageResult> { + let conn = connection::pg_connection_write(self).await?; + routing_storage::RoutingAlgorithm::list_metadata_by_profile_id( + &conn, profile_id, limit, offset, + ) + .await + .map_err(Into::into) + .into_report() + } + + async fn list_routing_algorithm_metadata_by_merchant_id( + &self, + merchant_id: &str, + limit: i64, + offset: i64, + ) -> StorageResult> { + let conn = connection::pg_connection_write(self).await?; + routing_storage::RoutingAlgorithm::list_metadata_by_merchant_id( + &conn, + merchant_id, + limit, + offset, + ) + .await + .map_err(Into::into) + .into_report() + } +} + +#[async_trait::async_trait] +impl RoutingAlgorithmInterface for MockDb { + async fn insert_routing_algorithm( + &self, + _routing_algorithm: routing_storage::RoutingAlgorithm, + ) -> StorageResult { + Err(errors::StorageError::MockDbError)? + } + + async fn find_routing_algorithm_by_profile_id_algorithm_id( + &self, + _profile_id: &str, + _algorithm_id: &str, + ) -> StorageResult { + Err(errors::StorageError::MockDbError)? + } + + async fn find_routing_algorithm_by_algorithm_id_merchant_id( + &self, + _algorithm_id: &str, + _merchant_id: &str, + ) -> StorageResult { + Err(errors::StorageError::MockDbError)? + } + + async fn find_routing_algorithm_metadata_by_algorithm_id_profile_id( + &self, + _algorithm_id: &str, + _profile_id: &str, + ) -> StorageResult { + Err(errors::StorageError::MockDbError)? + } + + async fn list_routing_algorithm_metadata_by_profile_id( + &self, + _profile_id: &str, + _limit: i64, + _offset: i64, + ) -> StorageResult> { + Err(errors::StorageError::MockDbError)? + } + + async fn list_routing_algorithm_metadata_by_merchant_id( + &self, + _merchant_id: &str, + _limit: i64, + _offset: i64, + ) -> StorageResult> { + Err(errors::StorageError::MockDbError)? + } +} diff --git a/crates/router/src/lib.rs b/crates/router/src/lib.rs index 11efec64055b..21ebfc06137b 100644 --- a/crates/router/src/lib.rs +++ b/crates/router/src/lib.rs @@ -141,6 +141,7 @@ pub fn mk_app( .service(routes::ApiKeys::server(state.clone())) .service(routes::Files::server(state.clone())) .service(routes::Disputes::server(state.clone())) + .service(routes::Routing::server(state.clone())) } #[cfg(all(feature = "olap", feature = "kms"))] diff --git a/crates/router/src/routes.rs b/crates/router/src/routes.rs index 307797e8ac9d..38f95c4cdda8 100644 --- a/crates/router/src/routes.rs +++ b/crates/router/src/routes.rs @@ -20,6 +20,8 @@ pub mod payments; #[cfg(feature = "payouts")] pub mod payouts; pub mod refunds; +#[cfg(feature = "olap")] +pub mod routing; #[cfg(all(feature = "olap", feature = "kms"))] pub mod verification; pub mod webhooks; @@ -28,6 +30,8 @@ pub mod webhooks; pub use self::app::DummyConnector; #[cfg(feature = "payouts")] pub use self::app::Payouts; +#[cfg(feature = "olap")] +pub use self::app::Routing; #[cfg(all(feature = "olap", feature = "kms"))] pub use self::app::Verify; pub use self::app::{ diff --git a/crates/router/src/routes/app.rs b/crates/router/src/routes/app.rs index 5b16e93404ae..0369bb612668 100644 --- a/crates/router/src/routes/app.rs +++ b/crates/router/src/routes/app.rs @@ -14,6 +14,8 @@ use tokio::sync::oneshot; use super::dummy_connector::*; #[cfg(feature = "payouts")] use super::payouts::*; +#[cfg(feature = "olap")] +use super::routing as cloud_routing; #[cfg(all(feature = "olap", feature = "kms"))] use super::verification::{apple_pay_merchant_registration, retrieve_apple_pay_verified_domains}; #[cfg(feature = "olap")] @@ -274,6 +276,43 @@ impl Payments { } } +#[cfg(feature = "olap")] +pub struct Routing; + +#[cfg(feature = "olap")] +impl Routing { + pub fn server(state: AppState) -> Scope { + web::scope("/routing") + .app_data(web::Data::new(state.clone())) + .service( + web::resource("/active") + .route(web::get().to(cloud_routing::routing_retrieve_linked_config)), + ) + .service( + web::resource("") + .route(web::get().to(cloud_routing::routing_retrieve_dictionary)) + .route(web::post().to(cloud_routing::routing_create_config)), + ) + .service( + web::resource("/default") + .route(web::get().to(cloud_routing::routing_retrieve_default_config)) + .route(web::post().to(cloud_routing::routing_update_default_config)), + ) + .service( + web::resource("/deactivate") + .route(web::post().to(cloud_routing::routing_unlink_config)), + ) + .service( + web::resource("/{algorithm_id}") + .route(web::get().to(cloud_routing::routing_retrieve_config)), + ) + .service( + web::resource("/{algorithm_id}/activate") + .route(web::post().to(cloud_routing::routing_link_config)), + ) + } +} + pub struct Customers; #[cfg(any(feature = "olap", feature = "oltp"))] diff --git a/crates/router/src/routes/lock_utils.rs b/crates/router/src/routes/lock_utils.rs index 5be361098bcc..14614268d79d 100644 --- a/crates/router/src/routes/lock_utils.rs +++ b/crates/router/src/routes/lock_utils.rs @@ -22,6 +22,7 @@ pub enum ApiIdentifier { Verification, ApiKeys, PaymentLink, + Routing, } impl From for ApiIdentifier { @@ -33,6 +34,17 @@ impl From for ApiIdentifier { | Flow::MerchantsAccountDelete | Flow::MerchantAccountList => Self::MerchantAccount, + Flow::RoutingCreateConfig + | Flow::RoutingLinkConfig + | Flow::RoutingUnlinkConfig + | Flow::RoutingRetrieveConfig + | Flow::RoutingRetrieveActiveConfig + | Flow::RoutingRetrieveDefaultConfig + | Flow::RoutingRetrieveDictionary + | Flow::RoutingUpdateConfig + | Flow::RoutingUpdateDefaultConfig + | Flow::RoutingDeleteConfig => Self::Routing, + Flow::MerchantConnectorsCreate | Flow::MerchantConnectorsRetrieve | Flow::MerchantConnectorsUpdate diff --git a/crates/router/src/routes/payments.rs b/crates/router/src/routes/payments.rs index 4bc05826a3e4..5ed73df1c175 100644 --- a/crates/router/src/routes/payments.rs +++ b/crates/router/src/routes/payments.rs @@ -178,6 +178,7 @@ pub async fn payments_start( req, api::AuthFlow::Client, payments::CallConnectorAction::Trigger, + None, HeaderPayload::default(), ) }, @@ -244,6 +245,7 @@ pub async fn payments_retrieve( req, auth_flow, payments::CallConnectorAction::Trigger, + None, HeaderPayload::default(), ) }, @@ -305,6 +307,7 @@ pub async fn payments_retrieve_with_gateway_creds( req, api::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, HeaderPayload::default(), ) }, @@ -509,6 +512,7 @@ pub async fn payments_capture( payload, api::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, HeaderPayload::default(), ) }, @@ -564,6 +568,7 @@ pub async fn payments_connector_session( payload, api::AuthFlow::Client, payments::CallConnectorAction::Trigger, + None, HeaderPayload::default(), ) }, @@ -774,6 +779,7 @@ pub async fn payments_cancel( req, api::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, HeaderPayload::default(), ) }, @@ -897,6 +903,7 @@ where // the operation are flow agnostic, and the flow is only required in the post_update_tracker // Thus the flow can be generated just before calling the connector instead of explicitly passing it here. + let eligible_connectors = req.connector.clone(); match req.payment_type.unwrap_or_default() { api_models::enums::PaymentType::Normal | api_models::enums::PaymentType::RecurringMandate @@ -916,6 +923,7 @@ where req, auth_flow, payments::CallConnectorAction::Trigger, + eligible_connectors, header_payload, ) .await @@ -936,6 +944,7 @@ where req, auth_flow, payments::CallConnectorAction::Trigger, + eligible_connectors, header_payload, ) .await diff --git a/crates/router/src/routes/routing.rs b/crates/router/src/routes/routing.rs new file mode 100644 index 000000000000..1d5ccdf502fc --- /dev/null +++ b/crates/router/src/routes/routing.rs @@ -0,0 +1,298 @@ +//! Analysis for usage of Routing in Payment flows +//! +//! Functions that are used to perform the api level configuration, retrieval, updation +//! of Routing configs. +use actix_web::{web, HttpRequest, Responder}; +use api_models::routing as routing_types; +#[cfg(feature = "business_profile_routing")] +use api_models::routing::{RoutingRetrieveLinkQuery, RoutingRetrieveQuery}; +use router_env::{ + tracing::{self, instrument}, + Flow, +}; + +use crate::{ + core::{api_locking, routing}, + routes::AppState, + services::{api as oss_api, authentication as oss_auth, authentication as auth}, +}; + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_create_config( + state: web::Data, + req: HttpRequest, + json_payload: web::Json, +) -> impl Responder { + let flow = Flow::RoutingCreateConfig; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + json_payload.into_inner(), + |state, auth: oss_auth::AuthenticationData, payload| { + routing::create_routing_config(state, auth.merchant_account, auth.key_store, payload) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await +} + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_link_config( + state: web::Data, + req: HttpRequest, + path: web::Path, +) -> impl Responder { + let flow = Flow::RoutingLinkConfig; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + path.into_inner(), + |state, auth: oss_auth::AuthenticationData, algorithm_id| { + routing::link_routing_config( + state, + auth.merchant_account, + #[cfg(not(feature = "business_profile_routing"))] + auth.key_store, + algorithm_id, + ) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await +} + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_retrieve_config( + state: web::Data, + req: HttpRequest, + path: web::Path, +) -> impl Responder { + let algorithm_id = path.into_inner(); + let flow = Flow::RoutingRetrieveConfig; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + algorithm_id, + |state, auth: oss_auth::AuthenticationData, algorithm_id| { + routing::retrieve_routing_config(state, auth.merchant_account, algorithm_id) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await +} + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_retrieve_dictionary( + state: web::Data, + req: HttpRequest, + #[cfg(feature = "business_profile_routing")] query: web::Query, +) -> impl Responder { + #[cfg(feature = "business_profile_routing")] + { + let flow = Flow::RoutingRetrieveDictionary; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + query.into_inner(), + |state, auth: oss_auth::AuthenticationData, query_params| { + routing::retrieve_merchant_routing_dictionary( + state, + auth.merchant_account, + query_params, + ) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await + } + + #[cfg(not(feature = "business_profile_routing"))] + { + let flow = Flow::RoutingRetrieveDictionary; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + (), + |state, auth: oss_auth::AuthenticationData, _| { + routing::retrieve_merchant_routing_dictionary(state, auth.merchant_account) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await + } +} + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_unlink_config( + state: web::Data, + req: HttpRequest, + #[cfg(feature = "business_profile_routing")] payload: web::Json< + routing_types::RoutingConfigRequest, + >, +) -> impl Responder { + #[cfg(feature = "business_profile_routing")] + { + let flow = Flow::RoutingUnlinkConfig; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + payload.into_inner(), + |state, auth: oss_auth::AuthenticationData, payload_req| { + routing::unlink_routing_config(state, auth.merchant_account, payload_req) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await + } + + #[cfg(not(feature = "business_profile_routing"))] + { + let flow = Flow::RoutingUnlinkConfig; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + (), + |state, auth: oss_auth::AuthenticationData, _| { + routing::unlink_routing_config(state, auth.merchant_account, auth.key_store) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await + } +} + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_update_default_config( + state: web::Data, + req: HttpRequest, + json_payload: web::Json>, +) -> impl Responder { + oss_api::server_wrap( + Flow::RoutingUpdateDefaultConfig, + state, + &req, + json_payload.into_inner(), + |state, auth: oss_auth::AuthenticationData, updated_config| { + routing::update_default_routing_config(state, auth.merchant_account, updated_config) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + ) + .await +} + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_retrieve_default_config( + state: web::Data, + req: HttpRequest, +) -> impl Responder { + oss_api::server_wrap( + Flow::RoutingRetrieveDefaultConfig, + state, + &req, + (), + |state, auth: oss_auth::AuthenticationData, _| { + routing::retrieve_default_routing_config(state, auth.merchant_account) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + ) + .await +} + +#[cfg(feature = "olap")] +#[instrument(skip_all)] +pub async fn routing_retrieve_linked_config( + state: web::Data, + req: HttpRequest, + #[cfg(feature = "business_profile_routing")] query: web::Query, +) -> impl Responder { + #[cfg(feature = "business_profile_routing")] + { + use crate::services::authentication::AuthenticationData; + let flow = Flow::RoutingRetrieveActiveConfig; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + query.into_inner(), + |state, auth: AuthenticationData, query_params| { + routing::retrieve_linked_routing_config(state, auth.merchant_account, query_params) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await + } + + #[cfg(not(feature = "business_profile_routing"))] + { + let flow = Flow::RoutingRetrieveActiveConfig; + Box::pin(oss_api::server_wrap( + flow, + state, + &req, + (), + |state, auth: oss_auth::AuthenticationData, _| { + routing::retrieve_linked_routing_config(state, auth.merchant_account) + }, + #[cfg(not(feature = "release"))] + auth::auth_type(&oss_auth::ApiKeyAuth, &auth::JWTAuth, req.headers()), + #[cfg(feature = "release")] + &auth::JWTAuth, + api_locking::LockAction::NotApplicable, + )) + .await + } +} diff --git a/crates/router/src/types/api.rs b/crates/router/src/types/api.rs index 8f5a0f8a59f2..69e7f8898d15 100644 --- a/crates/router/src/types/api.rs +++ b/crates/router/src/types/api.rs @@ -11,6 +11,7 @@ pub mod payment_methods; pub mod payments; pub mod payouts; pub mod refunds; +pub mod routing; pub mod webhooks; use std::{fmt::Debug, str::FromStr}; @@ -38,6 +39,13 @@ pub trait ConnectorAccessToken: { } +#[derive(Clone)] +pub enum ConnectorCallType { + PreDetermined(ConnectorData), + Retryable(Vec), + SessionMultiple(Vec), +} + #[derive(Clone, Debug)] pub struct VerifyWebhookSource; @@ -218,12 +226,6 @@ pub enum PayoutConnectorChoice { Decide, } -#[derive(Clone)] -pub enum ConnectorCallType { - Multiple(Vec), - Single(ConnectorData), -} - #[cfg(feature = "payouts")] #[derive(Clone)] pub enum PayoutConnectorCallType { @@ -231,12 +233,6 @@ pub enum PayoutConnectorCallType { Single(PayoutConnectorData), } -impl ConnectorCallType { - pub fn is_single(&self) -> bool { - matches!(self, Self::Single(_)) - } -} - #[cfg(feature = "payouts")] impl PayoutConnectorData { pub fn get_connector_by_name( diff --git a/crates/router/src/types/api/admin.rs b/crates/router/src/types/api/admin.rs index 258a3d566dde..6bbe9149f4d7 100644 --- a/crates/router/src/types/api/admin.rs +++ b/crates/router/src/types/api/admin.rs @@ -4,8 +4,8 @@ pub use api_models::admin::{ MerchantAccountResponse, MerchantAccountUpdate, MerchantConnectorCreate, MerchantConnectorDeleteResponse, MerchantConnectorDetails, MerchantConnectorDetailsWrap, MerchantConnectorId, MerchantConnectorResponse, MerchantDetails, MerchantId, - PaymentMethodsEnabled, PayoutRoutingAlgorithm, PayoutStraightThroughAlgorithm, - RoutingAlgorithm, StraightThroughAlgorithm, ToggleKVRequest, ToggleKVResponse, WebhookDetails, + PaymentMethodsEnabled, PayoutRoutingAlgorithm, PayoutStraightThroughAlgorithm, ToggleKVRequest, + ToggleKVResponse, WebhookDetails, }; use common_utils::ext_traits::ValueExt; use error_stack::ResultExt; diff --git a/crates/router/src/types/api/routing.rs b/crates/router/src/types/api/routing.rs new file mode 100644 index 000000000000..faafac76e3dc --- /dev/null +++ b/crates/router/src/types/api/routing.rs @@ -0,0 +1,41 @@ +#[cfg(feature = "backwards_compatibility")] +pub use api_models::routing::RoutableChoiceKind; +pub use api_models::{ + enums as api_enums, + routing::{ + ConnectorVolumeSplit, DetailedConnectorChoice, RoutableConnectorChoice, RoutingAlgorithm, + RoutingAlgorithmKind, RoutingAlgorithmRef, RoutingConfigRequest, RoutingDictionary, + RoutingDictionaryRecord, StraightThroughAlgorithm, + }, +}; + +use super::types::api as api_oss; + +pub struct SessionRoutingChoice { + pub connector: api_oss::ConnectorData, + #[cfg(not(feature = "connector_choice_mca_id"))] + pub sub_label: Option, + pub payment_method_type: api_enums::PaymentMethodType, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ConnectorVolumeSplitV0 { + pub connector: RoutableConnectorChoice, + pub split: u8, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(tag = "type", content = "data", rename_all = "snake_case")] +pub enum RoutingAlgorithmV0 { + Single(Box), + Priority(Vec), + VolumeSplit(Vec), + Custom { timestamp: i64 }, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct FrmRoutingAlgorithm { + pub data: String, + #[serde(rename = "type")] + pub algorithm_type: String, +} diff --git a/crates/router/src/types/storage.rs b/crates/router/src/types/storage.rs index 92ead76e9137..00a5e07a30e8 100644 --- a/crates/router/src/types/storage.rs +++ b/crates/router/src/types/storage.rs @@ -21,6 +21,9 @@ pub mod merchant_key_store; pub mod payment_attempt; pub mod payment_link; pub mod payment_method; +pub mod routing_algorithm; +use std::collections::HashMap; + pub use diesel_models::{ProcessTracker, ProcessTrackerNew, ProcessTrackerUpdate}; pub use scheduler::db::process_tracker; pub mod reverse_lookup; @@ -41,11 +44,63 @@ pub use self::{ customers::*, dispute::*, ephemeral_key::*, events::*, file::*, locker_mock_up::*, mandate::*, merchant_account::*, merchant_connector_account::*, merchant_key_store::*, payment_link::*, payment_method::*, payout_attempt::*, payouts::*, process_tracker::*, refund::*, - reverse_lookup::*, + reverse_lookup::*, routing_algorithm::*, }; +use crate::types::api::routing; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct RoutingData { pub routed_through: Option, - pub algorithm: Option, + #[cfg(feature = "connector_choice_mca_id")] + pub merchant_connector_id: Option, + #[cfg(not(feature = "connector_choice_mca_id"))] + pub business_sub_label: Option, + pub routing_info: PaymentRoutingInfo, + pub algorithm: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(from = "PaymentRoutingInfoSerde", into = "PaymentRoutingInfoSerde")] +pub struct PaymentRoutingInfo { + pub algorithm: Option, + pub pre_routing_results: + Option>, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct PaymentRoutingInfoInner { + pub algorithm: Option, + pub pre_routing_results: + Option>, +} + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub enum PaymentRoutingInfoSerde { + OnlyAlgorithm(Box), + WithDetails(Box), +} + +impl From for PaymentRoutingInfo { + fn from(value: PaymentRoutingInfoSerde) -> Self { + match value { + PaymentRoutingInfoSerde::OnlyAlgorithm(algo) => Self { + algorithm: Some(*algo), + pre_routing_results: None, + }, + PaymentRoutingInfoSerde::WithDetails(details) => Self { + algorithm: details.algorithm, + pre_routing_results: details.pre_routing_results, + }, + } + } +} + +impl From for PaymentRoutingInfoSerde { + fn from(value: PaymentRoutingInfo) -> Self { + Self::WithDetails(Box::new(PaymentRoutingInfoInner { + algorithm: value.algorithm, + pre_routing_results: value.pre_routing_results, + })) + } } diff --git a/crates/router/src/types/storage/routing_algorithm.rs b/crates/router/src/types/storage/routing_algorithm.rs new file mode 100644 index 000000000000..8022ab075ec4 --- /dev/null +++ b/crates/router/src/types/storage/routing_algorithm.rs @@ -0,0 +1,3 @@ +pub use diesel_models::routing_algorithm::{ + RoutingAlgorithm, RoutingAlgorithmMetadata, RoutingProfileMetadata, +}; diff --git a/crates/router/src/types/transformers.rs b/crates/router/src/types/transformers.rs index d38497c7100a..83ca0d014dc8 100644 --- a/crates/router/src/types/transformers.rs +++ b/crates/router/src/types/transformers.rs @@ -1,6 +1,6 @@ // use actix_web::HttpMessage; use actix_web::http::header::HeaderMap; -use api_models::{enums as api_enums, payments}; +use api_models::{enums as api_enums, payments, routing::ConnectorSelection}; use common_utils::{ consts::X_HS_LATENCY, crypto::Encryptable, @@ -8,14 +8,15 @@ use common_utils::{ pii, }; use diesel_models::enums as storage_enums; -use error_stack::ResultExt; +use error_stack::{IntoReport, ResultExt}; +use euclid::enums as dsl_enums; use masking::{ExposeInterface, PeekInterface}; use super::domain; use crate::{ core::errors, services::authentication::get_header_value_by_key, - types::{api as api_types, storage}, + types::{api as api_types, api::routing as routing_types, storage}, }; pub trait ForeignInto { @@ -169,6 +170,154 @@ impl ForeignFrom for api_models::payments::Manda } } +impl ForeignTryFrom for api_enums::RoutableConnectors { + type Error = error_stack::Report; + + fn foreign_try_from(from: api_enums::Connector) -> Result { + Ok(match from { + #[cfg(feature = "dummy_connector")] + api_enums::Connector::DummyConnector1 => Self::DummyConnector1, + #[cfg(feature = "dummy_connector")] + api_enums::Connector::DummyConnector2 => Self::DummyConnector2, + #[cfg(feature = "dummy_connector")] + api_enums::Connector::DummyConnector3 => Self::DummyConnector3, + #[cfg(feature = "dummy_connector")] + api_enums::Connector::DummyConnector4 => Self::DummyConnector4, + #[cfg(feature = "dummy_connector")] + api_enums::Connector::DummyConnector5 => Self::DummyConnector5, + #[cfg(feature = "dummy_connector")] + api_enums::Connector::DummyConnector6 => Self::DummyConnector6, + #[cfg(feature = "dummy_connector")] + api_enums::Connector::DummyConnector7 => Self::DummyConnector7, + api_enums::Connector::Aci => Self::Aci, + api_enums::Connector::Adyen => Self::Adyen, + api_enums::Connector::Airwallex => Self::Airwallex, + api_enums::Connector::Authorizedotnet => Self::Authorizedotnet, + api_enums::Connector::Bitpay => Self::Bitpay, + api_enums::Connector::Bambora => Self::Bambora, + api_enums::Connector::Bluesnap => Self::Bluesnap, + api_enums::Connector::Boku => Self::Boku, + api_enums::Connector::Braintree => Self::Braintree, + api_enums::Connector::Cashtocode => Self::Cashtocode, + api_enums::Connector::Checkout => Self::Checkout, + api_enums::Connector::Coinbase => Self::Coinbase, + api_enums::Connector::Cryptopay => Self::Cryptopay, + api_enums::Connector::Cybersource => Self::Cybersource, + api_enums::Connector::Dlocal => Self::Dlocal, + api_enums::Connector::Fiserv => Self::Fiserv, + api_enums::Connector::Forte => Self::Forte, + api_enums::Connector::Globalpay => Self::Globalpay, + api_enums::Connector::Globepay => Self::Globepay, + api_enums::Connector::Gocardless => Self::Gocardless, + api_enums::Connector::Helcim => Self::Helcim, + api_enums::Connector::Iatapay => Self::Iatapay, + api_enums::Connector::Klarna => Self::Klarna, + api_enums::Connector::Mollie => Self::Mollie, + api_enums::Connector::Multisafepay => Self::Multisafepay, + api_enums::Connector::Nexinets => Self::Nexinets, + api_enums::Connector::Nmi => Self::Nmi, + api_enums::Connector::Noon => Self::Noon, + api_enums::Connector::Nuvei => Self::Nuvei, + api_enums::Connector::Opennode => Self::Opennode, + api_enums::Connector::Payme => Self::Payme, + api_enums::Connector::Paypal => Self::Paypal, + api_enums::Connector::Payu => Self::Payu, + api_enums::Connector::Plaid => { + Err(common_utils::errors::ValidationError::InvalidValue { + message: "plaid is not a routable connector".to_string(), + }) + .into_report()? + } + api_enums::Connector::Powertranz => Self::Powertranz, + api_enums::Connector::Rapyd => Self::Rapyd, + api_enums::Connector::Shift4 => Self::Shift4, + api_enums::Connector::Signifyd => { + Err(common_utils::errors::ValidationError::InvalidValue { + message: "signifyd is not a routable connector".to_string(), + }) + .into_report()? + } + api_enums::Connector::Square => Self::Square, + api_enums::Connector::Stax => Self::Stax, + api_enums::Connector::Stripe => Self::Stripe, + api_enums::Connector::Trustpay => Self::Trustpay, + api_enums::Connector::Tsys => Self::Tsys, + api_enums::Connector::Volt => Self::Volt, + api_enums::Connector::Wise => Self::Wise, + api_enums::Connector::Worldline => Self::Worldline, + api_enums::Connector::Worldpay => Self::Worldpay, + api_enums::Connector::Zen => Self::Zen, + }) + } +} + +impl ForeignFrom for api_enums::RoutableConnectors { + fn foreign_from(from: dsl_enums::Connector) -> Self { + match from { + #[cfg(feature = "dummy_connector")] + dsl_enums::Connector::DummyConnector1 => Self::DummyConnector1, + #[cfg(feature = "dummy_connector")] + dsl_enums::Connector::DummyConnector2 => Self::DummyConnector2, + #[cfg(feature = "dummy_connector")] + dsl_enums::Connector::DummyConnector3 => Self::DummyConnector3, + #[cfg(feature = "dummy_connector")] + dsl_enums::Connector::DummyConnector4 => Self::DummyConnector4, + #[cfg(feature = "dummy_connector")] + dsl_enums::Connector::DummyConnector5 => Self::DummyConnector5, + #[cfg(feature = "dummy_connector")] + dsl_enums::Connector::DummyConnector6 => Self::DummyConnector6, + #[cfg(feature = "dummy_connector")] + dsl_enums::Connector::DummyConnector7 => Self::DummyConnector7, + dsl_enums::Connector::Aci => Self::Aci, + dsl_enums::Connector::Adyen => Self::Adyen, + dsl_enums::Connector::Airwallex => Self::Airwallex, + dsl_enums::Connector::Authorizedotnet => Self::Authorizedotnet, + dsl_enums::Connector::Bitpay => Self::Bitpay, + dsl_enums::Connector::Bambora => Self::Bambora, + dsl_enums::Connector::Bluesnap => Self::Bluesnap, + dsl_enums::Connector::Boku => Self::Boku, + dsl_enums::Connector::Braintree => Self::Braintree, + dsl_enums::Connector::Cashtocode => Self::Cashtocode, + dsl_enums::Connector::Checkout => Self::Checkout, + dsl_enums::Connector::Coinbase => Self::Coinbase, + dsl_enums::Connector::Cryptopay => Self::Cryptopay, + dsl_enums::Connector::Cybersource => Self::Cybersource, + dsl_enums::Connector::Dlocal => Self::Dlocal, + dsl_enums::Connector::Fiserv => Self::Fiserv, + dsl_enums::Connector::Forte => Self::Forte, + dsl_enums::Connector::Globalpay => Self::Globalpay, + dsl_enums::Connector::Globepay => Self::Globepay, + dsl_enums::Connector::Gocardless => Self::Gocardless, + dsl_enums::Connector::Helcim => Self::Helcim, + dsl_enums::Connector::Iatapay => Self::Iatapay, + dsl_enums::Connector::Klarna => Self::Klarna, + dsl_enums::Connector::Mollie => Self::Mollie, + dsl_enums::Connector::Multisafepay => Self::Multisafepay, + dsl_enums::Connector::Nexinets => Self::Nexinets, + dsl_enums::Connector::Nmi => Self::Nmi, + dsl_enums::Connector::Noon => Self::Noon, + dsl_enums::Connector::Nuvei => Self::Nuvei, + dsl_enums::Connector::Opennode => Self::Opennode, + dsl_enums::Connector::Payme => Self::Payme, + dsl_enums::Connector::Paypal => Self::Paypal, + dsl_enums::Connector::Payu => Self::Payu, + dsl_enums::Connector::Powertranz => Self::Powertranz, + dsl_enums::Connector::Rapyd => Self::Rapyd, + dsl_enums::Connector::Shift4 => Self::Shift4, + dsl_enums::Connector::Square => Self::Square, + dsl_enums::Connector::Stax => Self::Stax, + dsl_enums::Connector::Stripe => Self::Stripe, + dsl_enums::Connector::Trustpay => Self::Trustpay, + dsl_enums::Connector::Tsys => Self::Tsys, + dsl_enums::Connector::Volt => Self::Volt, + dsl_enums::Connector::Wise => Self::Wise, + dsl_enums::Connector::Worldline => Self::Worldline, + dsl_enums::Connector::Worldpay => Self::Worldpay, + dsl_enums::Connector::Zen => Self::Zen, + } + } +} + impl ForeignFrom for api_models::payments::MandateAmountData { fn foreign_from(from: storage_enums::MandateAmountData) -> Self { Self { @@ -862,6 +1011,16 @@ impl From for payments::AddressDetails { } } +impl ForeignFrom for routing_types::RoutingAlgorithm { + fn foreign_from(value: ConnectorSelection) -> Self { + match value { + ConnectorSelection::Priority(connectors) => Self::Priority(connectors), + + ConnectorSelection::VolumeSplit(splits) => Self::VolumeSplit(splits), + } + } +} + impl ForeignFrom for diesel_models::organization::OrganizationNew { diff --git a/crates/router/src/workflows/payment_sync.rs b/crates/router/src/workflows/payment_sync.rs index 540f2d68dd61..f41b300c5127 100644 --- a/crates/router/src/workflows/payment_sync.rs +++ b/crates/router/src/workflows/payment_sync.rs @@ -69,6 +69,7 @@ impl ProcessTrackerWorkflow for PaymentsSyncWorkflow { tracking_data.clone(), payment_flows::CallConnectorAction::Trigger, services::AuthFlow::Client, + None, api::HeaderPayload::default(), ) .await?; diff --git a/crates/router/tests/payments.rs b/crates/router/tests/payments.rs index 551960ac1380..d2d6c48507e5 100644 --- a/crates/router/tests/payments.rs +++ b/crates/router/tests/payments.rs @@ -369,6 +369,7 @@ async fn payments_create_core() { req, services::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, api::HeaderPayload::default(), ) .await @@ -539,6 +540,7 @@ async fn payments_create_core_adyen_no_redirect() { req, services::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, api::HeaderPayload::default(), ) .await diff --git a/crates/router/tests/payments2.rs b/crates/router/tests/payments2.rs index 96ed131dc6f8..ed8827a910be 100644 --- a/crates/router/tests/payments2.rs +++ b/crates/router/tests/payments2.rs @@ -135,6 +135,7 @@ async fn payments_create_core() { req, services::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, api::HeaderPayload::default(), ) .await @@ -313,6 +314,7 @@ async fn payments_create_core_adyen_no_redirect() { req, services::AuthFlow::Merchant, payments::CallConnectorAction::Trigger, + None, api::HeaderPayload::default(), ) .await diff --git a/crates/router_env/src/logger/types.rs b/crates/router_env/src/logger/types.rs index d63ddce58f30..9822432115b0 100644 --- a/crates/router_env/src/logger/types.rs +++ b/crates/router_env/src/logger/types.rs @@ -163,6 +163,26 @@ pub enum Flow { RefundsUpdate, /// Refunds list flow. RefundsList, + /// Routing create flow, + RoutingCreateConfig, + /// Routing link config + RoutingLinkConfig, + /// Routing link config + RoutingUnlinkConfig, + /// Routing retrieve config + RoutingRetrieveConfig, + /// Routing retrieve active config + RoutingRetrieveActiveConfig, + /// Routing retrieve default config + RoutingRetrieveDefaultConfig, + /// Routing retrieve dictionary + RoutingRetrieveDictionary, + /// Routing update config + RoutingUpdateConfig, + /// Routing update default config + RoutingUpdateDefaultConfig, + /// Routing delete config + RoutingDeleteConfig, /// Incoming Webhook Receive IncomingWebhookReceive, /// Validate payment method flow diff --git a/migrations/2023-10-19-101558_create_routing_algorithm_table/down.sql b/migrations/2023-10-19-101558_create_routing_algorithm_table/down.sql new file mode 100644 index 000000000000..2cace88297db --- /dev/null +++ b/migrations/2023-10-19-101558_create_routing_algorithm_table/down.sql @@ -0,0 +1,4 @@ +-- This file should undo anything in `up.sql` + +DROP TABLE routing_algorithm; +DROP TYPE "RoutingAlgorithmKind"; diff --git a/migrations/2023-10-19-101558_create_routing_algorithm_table/up.sql b/migrations/2023-10-19-101558_create_routing_algorithm_table/up.sql new file mode 100644 index 000000000000..361194561227 --- /dev/null +++ b/migrations/2023-10-19-101558_create_routing_algorithm_table/up.sql @@ -0,0 +1,19 @@ +-- Your SQL goes here + +CREATE TYPE "RoutingAlgorithmKind" AS ENUM ('single', 'priority', 'volume_split', 'advanced'); + +CREATE TABLE routing_algorithm ( + algorithm_id VARCHAR(64) PRIMARY KEY, + profile_id VARCHAR(64) NOT NULL, + merchant_id VARCHAR(64) NOT NULL, + name VARCHAR(64) NOT NULL, + description VARCHAR(256), + kind "RoutingAlgorithmKind" NOT NULL, + algorithm_data JSONB NOT NULL, + created_at TIMESTAMP NOT NULL, + modified_at TIMESTAMP NOT NULL +); + +CREATE INDEX routing_algorithm_profile_id_modified_at ON routing_algorithm (profile_id, modified_at DESC); + +CREATE INDEX routing_algorithm_merchant_id_modified_at ON routing_algorithm (merchant_id, modified_at DESC);