diff --git a/Cargo.lock b/Cargo.lock index 4f1ea22f..e5448649 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -399,6 +399,7 @@ dependencies = [ "cfg-if", "getrandom", "once_cell", + "serde", "version_check", "zerocopy", ] @@ -1182,6 +1183,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "bytecount" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" + [[package]] name = "byteorder" version = "1.5.0" @@ -1863,6 +1870,31 @@ dependencies = [ "serde", ] +[[package]] +name = "cyclonedx-bom" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a114dd99ed051f1481d8d35acd455f77469f026b783fe08074763fdf1701506" +dependencies = [ + "base64 0.21.7", + "cyclonedx-bom-macros", + "fluent-uri 0.1.4", + "indexmap 2.6.0", + "jsonschema", + "once_cell", + "ordered-float 4.3.0", + "packageurl", + "regex", + "serde", + "serde_json", + "spdx", + "strum 0.26.3", + "thiserror", + "time", + "uuid", + "xml-rs", +] + [[package]] name = "cyclonedx-bom" version = "0.7.0" @@ -2354,6 +2386,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" dependencies = [ "log", + "regex", ] [[package]] @@ -2365,6 +2398,7 @@ dependencies = [ "anstream", "anstyle", "env_filter", + "humantime", "log", ] @@ -2433,6 +2467,16 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fancy-regex" version = "0.12.0" @@ -2582,6 +2626,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fraction" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3027ae1df8d41b4bed2241c8fdad4acc1e7af60c8e17743534b545e77182d678" +dependencies = [ + "lazy_static", + "num", +] + [[package]] name = "funty" version = "2.0.0" @@ -3673,6 +3727,15 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "iso8601" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924e5d73ea28f59011fec52a0d12185d496a9b075d360657aed2a5707f701153" +dependencies = [ + "nom", +] + [[package]] name = "itertools" version = "0.10.5" @@ -3765,6 +3828,34 @@ dependencies = [ "thiserror", ] +[[package]] +name = "jsonschema" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a071f4f7efc9a9118dfb627a0a94ef247986e1ab8606a4c806ae2b3aa3b6978" +dependencies = [ + "ahash 0.8.11", + "anyhow", + "base64 0.21.7", + "bytecount", + "fancy-regex 0.11.0", + "fraction", + "getrandom", + "iso8601", + "itoa", + "memchr", + "num-cmp", + "once_cell", + "parking_lot 0.12.3", + "percent-encoding", + "regex", + "serde", + "serde_json", + "time", + "url", + "uuid", +] + [[package]] name = "lalrpop" version = "0.20.2" @@ -4397,6 +4488,20 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -4424,6 +4529,21 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -4460,6 +4580,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -6418,7 +6549,7 @@ dependencies = [ "bytes", "chrono", "csv", - "cyclonedx-bom", + "cyclonedx-bom 0.7.0", "digest", "filetime", "fluent-uri 0.2.0", @@ -6440,6 +6571,27 @@ dependencies = [ "walker-common", ] +[[package]] +name = "sbomsleuth" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9bdf732fb14c2e4914aaa154b909968c3beabb4dccd005e09dfab8e06a9617a" +dependencies = [ + "anyhow", + "colored", + "cyclonedx-bom 0.6.2", + "env_logger", + "log", + "reqwest 0.12.8", + "sbom-walker", + "serde", + "serde_json", + "spdx-rs", + "tokio", + "walker-common", + "walker-extras", +] + [[package]] name = "schannel" version = "0.1.24" @@ -7955,7 +8107,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bstr", - "fancy-regex", + "fancy-regex 0.12.0", "lazy_static", "parking_lot 0.12.3", "rustc-hash 1.1.0", @@ -8490,7 +8642,7 @@ dependencies = [ "chrono", "criterion", "csaf", - "cyclonedx-bom", + "cyclonedx-bom 0.7.0", "hex", "humantime", "jsonpath-rust", @@ -8535,7 +8687,7 @@ dependencies = [ "criterion", "csaf", "cve", - "cyclonedx-bom", + "cyclonedx-bom 0.7.0", "futures-util", "hex", "humantime", @@ -8660,7 +8812,7 @@ dependencies = [ "cpe", "csaf", "cve", - "cyclonedx-bom", + "cyclonedx-bom 0.7.0", "hex", "humantime", "jsn", @@ -8675,6 +8827,7 @@ dependencies = [ "roxmltree", "rstest", "sbom-walker", + "sbomsleuth", "sea-orm", "sea-query", "semver", @@ -9221,6 +9374,7 @@ dependencies = [ "bytes", "bzip2", "chrono", + "clap", "csv", "digest", "filetime", @@ -9250,6 +9404,26 @@ dependencies = [ "xattr", ] +[[package]] +name = "walker-extras" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451ae6af73f431c5c002d8961ad9b28eaf8411f983b32cf3eceb67022cf8b15f" +dependencies = [ + "anyhow", + "async-trait", + "bytes", + "clap", + "csaf-walker", + "humantime", + "log", + "reqwest 0.12.8", + "sbom-walker", + "thiserror", + "tokio", + "walker-common", +] + [[package]] name = "want" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index 14cc5fe2..de84c412 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ publish = false license = "Apache-2.0" [workspace.dependencies] -actix = "0.13.3" actix-cors = "0.7" actix-http = "3.3.1" actix-tls = "3" @@ -40,13 +39,11 @@ actix-web-static-files = "4.0.1" anyhow = "1.0.72" async-graphql = "7.0.5" async-graphql-actix-web = "7.0.5" -async-std = "1" async-trait = "0.1.74" base64 = "0.22" biscuit = "0.7" build-info = "0.0.38" build-info-build = "0.0.38" -build-info-common = "0.0.38" bytes = "1.5" bytesize = "1.3" criterion = "0.5.1" @@ -58,19 +55,15 @@ csaf = { version = "0.5.0", default-features = false } csaf-walker = { version = "0.9.0", default-features = false } cve = "0.3.1" cyclonedx-bom = "0.7.0" -env_logger = "0.11.0" futures = "0.3.30" futures-util = "0.3" garage-door = "0.1.1" git2 = { version = "0.19.0", features = ["ssh"] } hex = "0.4.3" -hide = "0.1.5" http = "1" human-date-parser = "0.2" humantime = "2" humantime-serde = "1" -indicatif = "0.17.8" -indicatif-log-bridge = "0.2" itertools = "0.13" jsn = "0.14" json-merge-patch = "0.0.1" @@ -83,7 +76,6 @@ log = "0.4.19" mime = "0.3.17" native-tls = "0.2" nu-ansi-term = "0.46" -once_cell = "1.19.0" openid = "0.15" openssl = "0.10" opentelemetry = "0.24" @@ -107,6 +99,7 @@ ring = "0.17.8" roxmltree = "0.20.0" rstest = "0.22" rust-s3 = "0.35" +sbomsleuth = { version = "0.1.9"} sbom-walker = { version = "0.9.0", default-features = false, features = ["crypto-openssl", "cyclonedx-bom", "spdx-rs"] } schemars = "0.8" sea-orm = "~1.0" # See https://www.sea-ql.org/blog/2024-08-04-sea-orm-1.0/#release-planning @@ -121,7 +114,6 @@ spdx = "0.10.6" spdx-expression = "0.5.2" spdx-rs = "0.5.3" sqlx = "0.7" -static-files = "0.2.3" strum = "0.26.3" temp-env = "0.3" tempfile = "3" @@ -130,10 +122,8 @@ test-log = "0.2.16" thiserror = "1.0.58" time = "0.3" tokio = "1.30.0" -tokio-stream = "0.1.15" tokio-util = "0.7" tracing = "0.1" -tracing-bunyan-formatter = "0.3.7" # Note: This uses OTEL 0.24 https://crates.io/crates/tracing-opentelemetry/0.25.0/dependencies tracing-opentelemetry = "0.25" tracing-subscriber = { version = "0.3.18", default-features = false } @@ -146,7 +136,6 @@ utoipa-swagger-ui = "7.1.0" uuid = "1.7.0" walkdir = "2.5" walker-common = "0.9.3" -walker-extras = "0.9.0" zip = "2.2.0" trustify-auth = { path = "common/auth", features = ["actix", "swagger"] } @@ -165,12 +154,6 @@ trustify-module-storage = { path = "modules/storage" } trustify-module-graphql = { path = "modules/graphql" } trustify-test-context = { path = "test-context" } trustify-module-analysis = { path = "modules/analysis" } - -# These dependencies are active during both the build time and the run time. So they are normal dependencies -# as well as build-dependencies. However, we can't control feature flags for build dependencies the way we do -# it for normal dependencies. So enabling the vendor feature for openssl-sys doesn't work for the build-dependencies. -# This will fail the build on targets where we need vendoring for openssl. Using rustls instead works around this issue. -postgresql_archive = { version = "0.16.3", default-features = false, features = ["theseus", "rustls-tls"] } postgresql_embedded = { version = "0.16.3", default-features = false, features = ["theseus", "rustls-tls"] } postgresql_commands = { version = "0.16.3", default-features = false, features = ["tokio"] } diff --git a/entity/src/source_document.rs b/entity/src/source_document.rs index 447e9b1b..3103099d 100644 --- a/entity/src/source_document.rs +++ b/entity/src/source_document.rs @@ -1,4 +1,5 @@ use sea_orm::entity::prelude::*; +use sea_orm::JsonValue; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "source_document")] @@ -8,6 +9,8 @@ pub struct Model { pub sha256: String, pub sha384: String, pub sha512: String, + #[sea_orm(column_type = "JsonBinary")] + pub meta: JsonValue, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/migration/src/lib.rs b/migration/src/lib.rs index 1d22f26e..3002cf05 100644 --- a/migration/src/lib.rs +++ b/migration/src/lib.rs @@ -83,6 +83,7 @@ mod m0000631_alter_product_cpe_key; mod m0000640_create_product_status; mod m0000650_alter_advisory_tracking; mod m0000660_purl_id_indexes; +mod m0000680_add_meta_report_source_document; pub struct Migrator; @@ -173,6 +174,7 @@ impl MigratorTrait for Migrator { Box::new(m0000640_create_product_status::Migration), Box::new(m0000650_alter_advisory_tracking::Migration), Box::new(m0000660_purl_id_indexes::Migration), + Box::new(m0000680_add_meta_report_source_document::Migration), ] } } diff --git a/migration/src/m0000680_add_meta_report_source_document.rs b/migration/src/m0000680_add_meta_report_source_document.rs new file mode 100644 index 00000000..801dd59f --- /dev/null +++ b/migration/src/m0000680_add_meta_report_source_document.rs @@ -0,0 +1,41 @@ +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + //create jsonb purl column + manager + .alter_table( + Table::alter() + .table(SourceDocument::Table) + .add_column(ColumnDef::new(SourceDocument::Meta).json_binary()) + .to_owned(), + ) + .await?; + + Ok(()) + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + // Drop column + manager + .alter_table( + Table::alter() + .table(SourceDocument::Table) + .drop_column(SourceDocument::Meta) + .to_owned(), + ) + .await?; + + Ok(()) + } +} + +#[derive(DeriveIden)] +enum SourceDocument { + Table, + Meta, +} diff --git a/modules/fundamental/src/ai/endpoints/test.rs b/modules/fundamental/src/ai/endpoints/test.rs index 637e90bd..dc3870c7 100644 --- a/modules/fundamental/src/ai/endpoints/test.rs +++ b/modules/fundamental/src/ai/endpoints/test.rs @@ -84,12 +84,89 @@ async fn tools(ctx: &TrustifyContext) -> anyhow::Result<()> { assert_eq!( result, json!([ - {"name":"ProductInfo","description":"This tool can be used to get information about a product.\nThe input should be the name of the product to search for.\nWhen the input is a full name, the tool will provide information about the product.\nWhen the input is a partial name, the tool will provide a list of possible matches.","parameters":{"type":"object","properties":{"input":{"type":"string","description":"This tool can be used to get information about a product.\nThe input should be the name of the product to search for.\nWhen the input is a full name, the tool will provide information about the product.\nWhen the input is a partial name, the tool will provide a list of possible matches."}},"required":["input"]}}, - {"name":"CVEInfo","description":"This tool can be used to get information about a Vulnerability.\nThe input should be the partial name of the Vulnerability to search for.\nWhen the input is a full CVE ID, the tool will provide information about the vulnerability.\nWhen the input is a partial name, the tool will provide a list of possible matches.","parameters":{"type":"object","properties":{"input":{"type":"string","description":"This tool can be used to get information about a Vulnerability.\nThe input should be the partial name of the Vulnerability to search for.\nWhen the input is a full CVE ID, the tool will provide information about the vulnerability.\nWhen the input is a partial name, the tool will provide a list of possible matches."}},"required":["input"]}}, - {"name":"AdvisoryInfo","description":"This tool can be used to get information about an Advisory.\nThe input should be the name of the Advisory to search for.\nWhen the input is a full name, the tool will provide information about the Advisory.\nWhen the input is a partial name, the tool will provide a list of possible matches.","parameters":{"type":"object","properties":{"input":{"type":"string","description":"This tool can be used to get information about an Advisory.\nThe input should be the name of the Advisory to search for.\nWhen the input is a full name, the tool will provide information about the Advisory.\nWhen the input is a partial name, the tool will provide a list of possible matches."}},"required":["input"]}}, - {"name":"PackageInfo","description":"This tool can be used to get information about a Package.\nThe input should be the name of the package, it's Identifier uri or internal UUID.","parameters":{"type":"object","properties":{"input":{"type":"string","description":"This tool can be used to get information about a Package.\nThe input should be the name of the package, it's Identifier uri or internal UUID."}},"required":["input"]}}, - {"name":"SbomInfo","description":"This tool can be used to get information about an SBOM.\nThe input should be the SBOM Identifier.","parameters":{"type":"object","properties":{"input":{"type":"string","description":"This tool can be used to get information about an SBOM.\nThe input should be the SBOM Identifier."}},"required":["input"]}} - ]) + { + "name": "product-info", + "description": "This tool can be used to get information about a product.\nThe input should be the name of the product to search for.\nWhen the input is a full name, the tool will provide information about the product.\nWhen the input is a partial name, the tool will provide a list of possible matches.", + "parameters": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "This tool can be used to get information about a product.\nThe input should be the name of the product to search for.\nWhen the input is a full name, the tool will provide information about the product.\nWhen the input is a partial name, the tool will provide a list of possible matches." + } + }, + "required": [ + "input" + ] + } + }, + { + "name": "cve-info", + "description": "This tool can be used to get information about a Vulnerability.\nThe input should be the partial name of the Vulnerability to search for.\nWhen the input is a full CVE ID, the tool will provide information about the vulnerability.\nWhen the input is a partial name, the tool will provide a list of possible matches.", + "parameters": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "This tool can be used to get information about a Vulnerability.\nThe input should be the partial name of the Vulnerability to search for.\nWhen the input is a full CVE ID, the tool will provide information about the vulnerability.\nWhen the input is a partial name, the tool will provide a list of possible matches." + } + }, + "required": [ + "input" + ] + } + }, + { + "name": "advisory-info", + "description": "This tool can be used to get information about an Advisory.\nThe input should be the name of the Advisory to search for.\nWhen the input is a full name, the tool will provide information about the Advisory.\nWhen the input is a partial name, the tool will provide a list of possible matches.", + "parameters": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "This tool can be used to get information about an Advisory.\nThe input should be the name of the Advisory to search for.\nWhen the input is a full name, the tool will provide information about the Advisory.\nWhen the input is a partial name, the tool will provide a list of possible matches." + } + }, + "required": [ + "input" + ] + } + }, + { + "name": "package-info", + "description": "This tool can be used to get information about a Package.\nThe input should be the name of the package, it's Identifier uri or internal UUID.", + "parameters": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "This tool can be used to get information about a Package.\nThe input should be the name of the package, it's Identifier uri or internal UUID." + } + }, + "required": [ + "input" + ] + } + }, + { + "name": "sbom-info", + "description": "This tool can be used to get information about an SBOM.\nThe input should be the SBOM Identifier.", + "parameters": { + "type": "object", + "properties": { + "input": { + "type": "string", + "description": "This tool can be used to get information about an SBOM.\nThe input should be the SBOM Identifier." + } + }, + "required": [ + "input" + ] + } + } + ]), + "result:\n{}", + serde_json::to_string_pretty(&result)? ); Ok(()) @@ -109,7 +186,15 @@ async fn tools_call(ctx: &TrustifyContext) -> anyhow::Result<()> { let app = caller(ctx).await?; let request = TestRequest::post() - .uri("/api/v1/ai/tools/ProductInfo") + .uri("/api/v1/ai/tools/unknown") + .set_json(json!("bad tool call")) + .to_request(); + + let response = app.call_service(request).await; + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let request = TestRequest::post() + .uri("/api/v1/ai/tools/product-info") .set_json(json!("Trusted Profile Analyzer")) .to_request(); diff --git a/modules/fundamental/src/ai/service/mod.rs b/modules/fundamental/src/ai/service/mod.rs index 03bbd5a5..91690cb9 100644 --- a/modules/fundamental/src/ai/service/mod.rs +++ b/modules/fundamental/src/ai/service/mod.rs @@ -1,22 +1,14 @@ pub mod tools; -use crate::Error; - -use trustify_common::db::{Database, Transactional}; - use crate::ai::model::{ChatMessage, ChatState, LLMInfo, MessageType}; -use crate::ai::service::tools::{ - AdvisoryInfo, CVEInfo, PackageInfo, ProductInfo, SbomInfo, ToolLogger, -}; -use crate::product::service::ProductService; -use crate::vulnerability::service::VulnerabilityService; - +use crate::Error; use base64::engine::general_purpose::STANDARD; use base64::engine::Engine as _; use langchain_rust::chain::options::ChainCallOptions; use langchain_rust::chain::Chain; use langchain_rust::language_models::options::CallOptions; +use langchain_rust::schemas::{BaseMemory, Message}; use langchain_rust::tools::OpenAIConfig; use langchain_rust::{ agent::{AgentExecutor, OpenAiToolAgentBuilder}, @@ -25,14 +17,9 @@ use langchain_rust::{ prompt_args, tools::Tool, }; - use std::env; - -use crate::advisory::service::AdvisoryService; -use crate::purl::service::PurlService; -use crate::sbom::service::SbomService; -use langchain_rust::schemas::{BaseMemory, Message}; use std::sync::Arc; +use trustify_common::db::{Database, Transactional}; pub const PREFIX: &str = include_str!("prefix.txt"); @@ -85,13 +72,7 @@ impl AiService { /// ``` /// pub fn new(db: Database) -> Self { - let tools: Vec> = vec![ - Arc::new(ToolLogger(ProductInfo(ProductService::new(db.clone())))), - Arc::new(ToolLogger(CVEInfo(VulnerabilityService::new(db.clone())))), - Arc::new(ToolLogger(AdvisoryInfo(AdvisoryService::new(db.clone())))), - Arc::new(ToolLogger(PackageInfo(PurlService::new(db.clone())))), - Arc::new(ToolLogger(SbomInfo(SbomService::new(db.clone())))), - ]; + let tools = tools::new(db.clone()); let api_key = env::var("OPENAI_API_KEY"); let api_key = match api_key { diff --git a/modules/fundamental/src/ai/service/test.rs b/modules/fundamental/src/ai/service/test.rs index 3b2d468f..ee388023 100644 --- a/modules/fundamental/src/ai/service/test.rs +++ b/modules/fundamental/src/ai/service/test.rs @@ -1,15 +1,6 @@ -use crate::advisory::service::AdvisoryService; use crate::ai::model::ChatState; -use crate::ai::service::tools::{AdvisoryInfo, CVEInfo, PackageInfo, ProductInfo, SbomInfo}; use crate::ai::service::AiService; -use crate::product::service::ProductService; -use crate::purl::service::PurlService; -use crate::sbom::service::SbomService; -use crate::vulnerability::service::VulnerabilityService; -use langchain_rust::tools::Tool; -use serde_json::Value; -use std::error::Error; -use std::rc::Rc; + use test_context::test_context; use test_log::test; use trustify_common::db::Transactional; @@ -49,10 +40,6 @@ pub async fn ingest_fixtures(ctx: &TrustifyContext) -> Result<(), anyhow::Error> Ok(()) } -fn cleanup_tool_result(s: Result>) -> String { - sanitize_uuid(s.unwrap().trim().to_string()) -} - pub fn sanitize_uuid(value: String) -> String { let re = regex::Regex::new(r#""uuid": "\b[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}\b""#).unwrap(); re.replace_all( @@ -62,21 +49,6 @@ pub fn sanitize_uuid(value: String) -> String { .to_string() } -async fn assert_tool_contains( - tool: Rc, - input: &str, - expected: &str, -) -> Result<(), anyhow::Error> { - let actual = cleanup_tool_result(tool.run(Value::String(input.to_string())).await); - assert!( - actual.contains(expected.trim()), - "actual:\n{}\nexpected:\n{}\n", - actual, - expected - ); - Ok(()) -} - #[test_context(TrustifyContext)] #[test(actix_web::test)] async fn completions(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { @@ -97,241 +69,3 @@ async fn completions(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { Ok(()) } - -#[test_context(TrustifyContext)] -#[test(actix_web::test)] -async fn cve_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - ingest_fixtures(ctx).await?; - let tool = Rc::new(CVEInfo(VulnerabilityService::new(ctx.db.clone()))); - assert_tool_contains( - tool.clone(), - "CVE-2021-32714", - r#" -{ - "title": "Integer Overflow in Chunked Transfer-Encoding", - "description": "hyper is an HTTP library for Rust. In versions prior to 0.14.10, hyper's HTTP server and client code had a flaw that could trigger an integer overflow when decoding chunk sizes that are too big. This allows possible data loss, or if combined with an upstream HTTP proxy that allows chunk sizes larger than hyper does, can result in \"request smuggling\" or \"desync attacks.\" The vulnerability is patched in version 0.14.10. Two possible workarounds exist. One may reject requests manually that contain a `Transfer-Encoding` header or ensure any upstream proxy rejects `Transfer-Encoding` chunk sizes greater than what fits in 64-bit unsigned integers.", - "severity": 9.1, - "score": 9.1, - "released": null, - "affected_packages": [ - { - "name": "pkg://cargo/hyper", - "version": "[0.0.0-0,0.14.10)" - } - ] -} -"#).await -} - -#[test_context(TrustifyContext)] -#[test(actix_web::test)] -async fn product_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - ingest_fixtures(ctx).await?; - let tool = Rc::new(ProductInfo(ProductService::new(ctx.db.clone()))); - assert_tool_contains( - tool.clone(), - "Trusted Profile Analyzer", - r#" -{ - "items": [ - { - "name": "Trusted Profile Analyzer", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "vendor": "Red Hat", - "versions": [ - "37.17.9" - ] - } - ], - "total": 1 -} -"#, - ) - .await -} - -#[test_context(TrustifyContext)] -#[test(actix_web::test)] -async fn advisory_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - crate::advisory::service::test::ingest_and_link_advisory(ctx).await?; - crate::advisory::service::test::ingest_sample_advisory(ctx, "RHSA-2").await?; - - let tool = Rc::new(AdvisoryInfo(AdvisoryService::new(ctx.db.clone()))); - - assert_tool_contains( - tool.clone(), - "RHSA-1", - r#" -{ - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "identifier": "RHSA-1", - "issuer": null, - "title": "RHSA-1", - "score": 9.1, - "severity": "critical", - "vulnerabilities": [ - { - "identifier": "CVE-123", - "title": null, - "description": null, - "released": null - } - ] -} -"#, - ) - .await -} - -#[test_context(TrustifyContext)] -#[test(actix_web::test)] -async fn package_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - ctx.ingest_document("ubi9-9.2-755.1697625012.json").await?; - ctx.ingest_document("quarkus-bom-2.13.8.Final-redhat-00004.json") - .await?; - - let tool = Rc::new(PackageInfo(PurlService::new(ctx.db.clone()))); - - assert_tool_contains( - tool.clone(), - "pkg:rpm/redhat/libsepol@3.5-1.el9?arch=s390x", - r#" -{ - "identifier": "pkg://rpm/redhat/libsepol@3.5-1.el9?arch=ppc64le", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "libsepol", - "version": "3.5-1.el9", - "advisories": [], - "licenses": [ - "LGPLV2+" - ] -} -"#, - ) - .await?; - - assert_tool_contains( - tool.clone(), - "1ca731c3-9596-534c-98eb-8dcc6ff7fef9", - r#" -{ - "identifier": "pkg://rpm/redhat/libsepol@3.5-1.el9?arch=ppc64le", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "libsepol", - "version": "3.5-1.el9", - "advisories": [], - "licenses": [ - "LGPLV2+" - ] -} -"#, - ) - .await?; - - assert_tool_contains( - tool.clone(), - "pkg:maven/org.jboss.logging/commons-logging-jboss-logging@1.0.0.Final-redhat-1?repository_url=https://maven.repository.redhat.com/ga/&type=jar", - r#" -{ - "identifier": "pkg://maven/org.jboss.logging/commons-logging-jboss-logging@1.0.0.Final-redhat-1?repository_url=https://maven.repository.redhat.com/ga/&type=jar", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "commons-logging-jboss-logging", - "version": "1.0.0.Final-redhat-1", - "advisories": [], - "licenses": [ - "APACHE-2.0" - ] -} -"#).await?; - - assert_tool_contains( - tool.clone(), - "commons-logging-jboss-logging", - r#" -{ - "identifier": "pkg://maven/org.jboss.logging/commons-logging-jboss-logging@1.0.0.Final-redhat-1?repository_url=https://maven.repository.redhat.com/ga/&type=jar", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "commons-logging-jboss-logging", - "version": "1.0.0.Final-redhat-1", - "advisories": [], - "licenses": [ - "APACHE-2.0" - ] -} -"#).await?; - - assert_tool_contains( - tool.clone(), - "quarkus-resteasy-reactive-json", - r#" -There are multiple that match: - -{ - "items": [ - { - "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb-common@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "quarkus-resteasy-reactive-jsonb-common", - "version": "2.13.8.Final-redhat-00004" - }, - { - "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "quarkus-resteasy-reactive-jsonb", - "version": "2.13.8.Final-redhat-00004" - }, - { - "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb-common-deployment@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "quarkus-resteasy-reactive-jsonb-common-deployment", - "version": "2.13.8.Final-redhat-00004" - }, - { - "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb-deployment@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "name": "quarkus-resteasy-reactive-jsonb-deployment", - "version": "2.13.8.Final-redhat-00004" - } - ], - "total": 4 -} -"#).await -} - -#[test_context(TrustifyContext)] -#[test(actix_web::test)] -async fn sbom_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { - ctx.ingest_document("ubi9-9.2-755.1697625012.json").await?; - ctx.ingest_document("quarkus-bom-2.13.8.Final-redhat-00004.json") - .await?; - - let tool = Rc::new(SbomInfo(SbomService::new(ctx.db.clone()))); - - assert_tool_contains( - tool.clone(), - "quarkus", - r#" -{ - "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", - "source_document_sha256": "sha256:5a370574a991aa42f7ecc5b7d88754b258f81c230a73bea247c0a6fcc6f608ab", - "name": "quarkus-bom", - "published": "2023-11-13T00:10:00Z", - "authors": [ - "Organization: Red Hat Product Security (secalert@redhat.com)" - ], - "labels": [ - [ - "source", - "TrustifyContext" - ], - [ - "type", - "spdx" - ] - ], - "advisories": [] -} -"#, - ) - .await -} diff --git a/modules/fundamental/src/ai/service/tools.rs b/modules/fundamental/src/ai/service/tools.rs deleted file mode 100644 index 20876f72..00000000 --- a/modules/fundamental/src/ai/service/tools.rs +++ /dev/null @@ -1,660 +0,0 @@ -use crate::{ - advisory::service::AdvisoryService, product::service::ProductService, - purl::service::PurlService, sbom::service::SbomService, - vulnerability::service::VulnerabilityService, -}; -use anyhow::anyhow; -use async_trait::async_trait; -use itertools::Itertools; -use langchain_rust::tools::Tool; -use serde::Serialize; -use serde_json::Value; -use std::{error::Error, fmt::Write, str::FromStr}; -use time::OffsetDateTime; -use trustify_common::model::PaginatedResults; -use trustify_common::{db::query::Query, id::Id, purl::Purl}; -use trustify_module_ingestor::common::Deprecation; -use uuid::Uuid; - -fn to_json(value: &T) -> Result> -where - T: Serialize, -{ - #[cfg(test)] - { - serde_json::to_string_pretty(&value).map_err(|e| e.into()) - } - #[cfg(not(test))] - { - serde_json::to_string(&value).map_err(|e| e.into()) - } -} - -fn paginated_to_json(p: PaginatedResults, f: fn(&A) -> T) -> Result> -where - T: Serialize, -{ - to_json(&PaginatedResults { - items: p.items.iter().map(f).collect(), - total: p.total, - }) -} - -pub struct ToolLogger(pub T); - -#[async_trait] -impl Tool for ToolLogger { - fn name(&self) -> String { - self.0.name() - } - - fn description(&self) -> String { - self.0.description() - } - - fn parameters(&self) -> Value { - self.0.parameters() - } - - async fn call(&self, input: &str) -> Result> { - log::info!(" tool call: {}, input: {}", self.name(), input); - let result = self.0.call(input).await; - match &result { - Ok(result) => { - log::info!(" ok: {}", result); - } - Err(err) => { - log::info!(" err: {}", err); - } - } - result - } - - async fn run(&self, input: Value) -> Result> { - self.0.run(input).await - } - - async fn parse_input(&self, input: &str) -> Value { - self.0.parse_input(input).await - } -} - -pub struct ProductInfo(pub ProductService); - -#[async_trait] -impl Tool for ProductInfo { - fn name(&self) -> String { - String::from("ProductInfo") - } - - fn description(&self) -> String { - String::from( - r##" -This tool can be used to get information about a product. -The input should be the name of the product to search for. -When the input is a full name, the tool will provide information about the product. -When the input is a partial name, the tool will provide a list of possible matches. -"## - .trim(), - ) - } - - async fn run(&self, input: Value) -> Result> { - let service = &self.0; - let input = input - .as_str() - .ok_or("Input should be a string")? - .to_string(); - - let results = service - .fetch_products( - Query { - q: input, - ..Default::default() - }, - Default::default(), - (), - ) - .await?; - - if results.items.is_empty() { - return Err(anyhow!("I don't know").into()); - } - - #[derive(Serialize)] - struct Product { - name: String, - uuid: Uuid, - vendor: Option, - versions: Vec, - } - paginated_to_json(results, |item| Product { - name: item.head.name.clone(), - uuid: item.head.id, - vendor: item.vendor.clone().map(|v| v.head.name), - versions: item.versions.iter().map(|v| v.version.clone()).collect(), - }) - } -} - -pub struct CVEInfo(pub VulnerabilityService); - -#[async_trait] -impl Tool for CVEInfo { - fn name(&self) -> String { - String::from("CVEInfo") - } - - fn description(&self) -> String { - String::from( - r##" -This tool can be used to get information about a Vulnerability. -The input should be the partial name of the Vulnerability to search for. -When the input is a full CVE ID, the tool will provide information about the vulnerability. -When the input is a partial name, the tool will provide a list of possible matches. -"## - .trim(), - ) - } - - async fn run(&self, input: Value) -> Result> { - let service = &self.0; - - let input = input - .as_str() - .ok_or("Input should be a string")? - .to_string(); - - let item = match service - .fetch_vulnerability(input.as_str(), Deprecation::Ignore, ()) - .await? - { - Some(v) => v, - None => { - // search for possible matches - let results = service - .fetch_vulnerabilities( - Query { - q: input.clone(), - ..Default::default() - }, - Default::default(), - Deprecation::Ignore, - (), - ) - .await?; - - if results.items.is_empty() { - return Err(anyhow!("I don't know").into()); - } - - // let the caller know what the possible matches are - if results.items.len() > 1 { - #[derive(Serialize)] - struct Item { - identifier: String, - name: Option, - } - - let json = paginated_to_json(results, |item| Item { - identifier: item.head.identifier.clone(), - name: item.head.title.clone(), - })?; - return Ok(format!("There are multiple that match:\n\n{}", json)); - } - - // let's show the details for the one that matched. - if let Some(v) = service - .fetch_vulnerability( - results.items[0].head.identifier.as_str(), - Deprecation::Ignore, - (), - ) - .await? - { - v - } else { - return Err(anyhow!("I don't know").into()); - } - } - }; - - #[derive(Serialize)] - struct Item { - title: Option, - description: Option, - severity: Option, - score: Option, - #[serde(with = "time::serde::rfc3339::option")] - released: Option, - affected_packages: Vec, - } - #[derive(Serialize)] - struct Package { - name: Purl, - version: String, - } - - let affected_packages = item - .advisories - .iter() - .flat_map(|v| { - v.purls - .get("affected") - .into_iter() - .flatten() - .map(|v| Package { - name: v.base_purl.purl.clone(), - version: v.version.clone(), - }) - }) - .collect(); - let json = to_json(&Item { - title: item.head.title.clone(), - description: item.head.description.clone(), - severity: item.average_score, - score: item.average_score, - released: item.head.released, - affected_packages, - })?; - - let mut result = "".to_string(); - if item.head.identifier != input { - writeln!(result, "There is one match, but it had a different identifier. Inform the user that that you are providing information on: {}\n", item.head.identifier)?; - } - writeln!(result, "{}", json)?; - Ok(result) - } -} - -pub struct AdvisoryInfo(pub AdvisoryService); - -#[async_trait] -impl Tool for AdvisoryInfo { - fn name(&self) -> String { - String::from("AdvisoryInfo") - } - - fn description(&self) -> String { - String::from( - r##" -This tool can be used to get information about an Advisory. -The input should be the name of the Advisory to search for. -When the input is a full name, the tool will provide information about the Advisory. -When the input is a partial name, the tool will provide a list of possible matches. -"## - .trim(), - ) - } - - async fn run(&self, input: Value) -> Result> { - let service = &self.0; - - let input = input - .as_str() - .ok_or("Input should be a string")? - .to_string(); - - // search for possible matches - let results = service - .fetch_advisories( - Query { - q: input, - ..Default::default() - }, - Default::default(), - Deprecation::Ignore, - (), - ) - .await?; - - if results.items.is_empty() { - return Err(anyhow!("I don't know").into()); - } - - // let the caller know what the possible matches are - if results.items.len() > 1 { - #[derive(Serialize)] - struct Item { - identifier: String, - title: Option, - } - - let json = paginated_to_json(results, |item| Item { - identifier: item.head.identifier.clone(), - title: item.head.title.clone(), - })?; - return Ok(format!("There are multiple that match:\n\n{}", json)); - } - - // let's show the details - let item = match service - .fetch_advisory(Id::Uuid(results.items[0].head.uuid), ()) - .await? - { - Some(v) => v, - None => return Err(anyhow!("I don't know").into()), - }; - - #[derive(Serialize)] - struct Item { - uuid: Uuid, - identifier: String, - issuer: Option, - title: Option, - score: Option, - severity: Option, - vulnerabilities: Vec, - } - #[derive(Serialize)] - struct Vulnerability { - identifier: String, - title: Option, - description: Option, - #[serde(with = "time::serde::rfc3339::option")] - released: Option, - } - - let vulnerabilities = item - .vulnerabilities - .iter() - .map(|v| Vulnerability { - identifier: v.head.head.identifier.clone(), - title: v.head.head.title.clone(), - description: v.head.head.description.clone(), - released: v.head.head.released, - }) - .collect(); - - to_json(&Item { - uuid: item.head.uuid, - - identifier: item.head.identifier.clone(), - issuer: item.head.issuer.clone().map(|v| v.head.name), - title: item.head.title.clone(), - score: item.average_score, - severity: item.average_severity.map(|v| v.to_string()), - vulnerabilities, - }) - } -} - -pub struct PackageInfo(pub PurlService); - -#[async_trait] -impl Tool for PackageInfo { - fn name(&self) -> String { - String::from("PackageInfo") - } - - fn description(&self) -> String { - String::from( - r##" -This tool can be used to get information about a Package. -The input should be the name of the package, it's Identifier uri or internal UUID. -"## - .trim(), - ) - } - - async fn run(&self, input: Value) -> Result> { - let service = &self.0; - - let input = input - .as_str() - .ok_or("Input should be a string")? - .to_string(); - - // Try lookup as a PURL - let mut purl_details = match Purl::try_from(input.clone()) { - Err(_) => None, - Ok(purl) => service.purl_by_purl(&purl, Deprecation::Ignore, ()).await?, - }; - - // Try lookup as a UUID - if purl_details.is_none() { - purl_details = match Uuid::parse_str(input.as_str()) { - Err(_) => None, - Ok(uuid) => service.purl_by_uuid(&uuid, Deprecation::Ignore, ()).await?, - }; - } - - // Fallback to search - if purl_details.is_none() { - // try to search for possible matches - let results = service - .purls( - Query { - q: input, - ..Default::default() - }, - Default::default(), - (), - ) - .await?; - - purl_details = match results.items.len() { - 0 => None, - 1 => { - service - .purl_by_uuid(&results.items[0].head.uuid, Deprecation::Ignore, ()) - .await? - } - _ => { - #[derive(Serialize)] - struct Item { - identifier: Purl, - uuid: Uuid, - name: String, - version: Option, - } - - let json = paginated_to_json(results, |item| Item { - identifier: item.head.purl.clone(), - uuid: item.head.uuid, - name: item.head.purl.name.clone(), - version: item.head.purl.version.clone(), - })?; - return Ok(format!("There are multiple that match:\n\n{}", json)); - } - }; - } - - let item = match purl_details { - Some(v) => v, - None => return Err(anyhow!("I don't know").into()), - }; - - #[derive(Serialize)] - struct Item { - identifier: Purl, - uuid: Uuid, - name: String, - version: Option, - advisories: Vec, - licenses: Vec, - } - - #[derive(Serialize)] - struct Advisory { - uuid: Uuid, - identifier: String, - issuer: Option, - vulnerabilities: Vec, - } - - #[derive(Serialize)] - struct Vulnerability { - identifier: String, - title: Option, - status: String, - } - - to_json(&Item { - identifier: item.head.purl.clone(), - uuid: item.head.uuid, - name: item.head.purl.name.clone(), - version: item.head.purl.version.clone(), - - advisories: item - .advisories - .iter() - .map(|advisory| Advisory { - uuid: advisory.head.uuid, - identifier: advisory.head.identifier.clone(), - issuer: advisory.head.issuer.clone().map(|v| v.head.name.clone()), - vulnerabilities: advisory - .status - .iter() - .map(|status| Vulnerability { - identifier: status.vulnerability.identifier.clone(), - title: status.vulnerability.title.clone(), - status: status.status.clone(), - }) - .collect(), - }) - .collect(), - - licenses: item - .licenses - .iter() - .flat_map(|v| v.licenses.iter()) - .cloned() - .collect(), - }) - } -} - -pub struct SbomInfo(pub SbomService); - -#[async_trait] -impl Tool for SbomInfo { - fn name(&self) -> String { - String::from("SbomInfo") - } - - fn description(&self) -> String { - String::from( - r##" -This tool can be used to get information about an SBOM. -The input should be the SBOM Identifier. -"## - .trim(), - ) - } - - async fn run(&self, input: Value) -> Result> { - let service = &self.0; - - let input = input - .as_str() - .ok_or("Input should be a string")? - .to_string(); - - // Try lookup as a UUID - let mut sbom_details = match Id::from_str(input.as_str()) { - Err(_) => None, - Ok(id) => service.fetch_sbom_details(id, ()).await?, - }; - - // Fallback to search - if sbom_details.is_none() { - // try to search for possible matches - let results = service - .fetch_sboms( - Query { - q: input, - ..Default::default() - }, - Default::default(), - (), - (), - ) - .await?; - - sbom_details = match results.items.len() { - 0 => None, - 1 => { - service - .fetch_sbom_details(Id::Uuid(results.items[0].head.id), ()) - .await? - } - _ => { - #[derive(Serialize)] - struct Item { - uuid: Uuid, - source_document_sha256: String, - name: String, - #[serde(with = "time::serde::rfc3339::option")] - published: Option, - } - - let json = paginated_to_json(results, |item| Item { - uuid: item.head.id, - source_document_sha256: item - .source_document - .as_ref() - .map(|v| v.sha256.clone()) - .unwrap_or_default(), - name: item.head.name.clone(), - published: item.head.published, - })?; - return Ok(format!("There are multiple that match:\n\n{}", json)); - } - }; - } - - let item = match sbom_details { - Some(v) => v, - None => return Err(anyhow!("I don't know").into()), - }; - - #[derive(Serialize)] - struct Item { - uuid: Uuid, - source_document_sha256: String, - name: String, - #[serde(with = "time::serde::rfc3339::option")] - published: Option, - authors: Vec, - labels: Vec<(String, String)>, - advisories: Vec, - } - - #[derive(Serialize)] - struct Advisory { - uuid: Uuid, - identifier: String, - issuer: Option, - } - - let mut labels = item.summary.head.labels.iter().collect_vec(); - labels.sort_by(|a, b| a.0.cmp(b.0)); - - to_json(&Item { - uuid: item.summary.head.id, - source_document_sha256: item - .summary - .source_document - .as_ref() - .map(|v| v.sha256.clone()) - .unwrap_or_default(), - name: item.summary.head.name.clone(), - published: item.summary.head.published, - authors: item.summary.head.authors.clone(), - labels: labels - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(), - advisories: item - .advisories - .iter() - .map(|advisory| Advisory { - uuid: advisory.head.uuid, - identifier: advisory.head.identifier.clone(), - issuer: advisory.head.issuer.clone().map(|v| v.head.name.clone()), - }) - .collect(), - }) - } -} diff --git a/modules/fundamental/src/ai/service/tools/advisory_info.rs b/modules/fundamental/src/ai/service/tools/advisory_info.rs new file mode 100644 index 00000000..bc995e72 --- /dev/null +++ b/modules/fundamental/src/ai/service/tools/advisory_info.rs @@ -0,0 +1,168 @@ +use crate::advisory::service::AdvisoryService; +use crate::ai::service::tools; +use anyhow::anyhow; +use async_trait::async_trait; +use langchain_rust::tools::Tool; +use serde::Serialize; +use serde_json::Value; +use std::error::Error; +use time::OffsetDateTime; +use trustify_common::db::query::Query; +use trustify_common::id::Id; +use trustify_module_ingestor::common::Deprecation; +use uuid::Uuid; + +pub struct AdvisoryInfo(pub AdvisoryService); + +#[async_trait] +impl Tool for AdvisoryInfo { + fn name(&self) -> String { + String::from("advisory-info") + } + + fn description(&self) -> String { + String::from( + r##" +This tool can be used to get information about an Advisory. +The input should be the name of the Advisory to search for. +When the input is a full name, the tool will provide information about the Advisory. +When the input is a partial name, the tool will provide a list of possible matches. +"## + .trim(), + ) + } + + async fn run(&self, input: Value) -> Result> { + let service = &self.0; + + let input = input + .as_str() + .ok_or("Input should be a string")? + .to_string(); + + // search for possible matches + let results = service + .fetch_advisories( + Query { + q: input, + ..Default::default() + }, + Default::default(), + Deprecation::Ignore, + (), + ) + .await?; + + if results.items.is_empty() { + return Err(anyhow!("I don't know").into()); + } + + // let the caller know what the possible matches are + if results.items.len() > 1 { + #[derive(Serialize)] + struct Item { + identifier: String, + title: Option, + } + + let json = tools::paginated_to_json(results, |item| Item { + identifier: item.head.identifier.clone(), + title: item.head.title.clone(), + })?; + return Ok(format!("There are multiple that match:\n\n{}", json)); + } + + // let's show the details + let item = match service + .fetch_advisory(Id::Uuid(results.items[0].head.uuid), ()) + .await? + { + Some(v) => v, + None => return Err(anyhow!("I don't know").into()), + }; + + #[derive(Serialize)] + struct Item { + uuid: Uuid, + identifier: String, + issuer: Option, + title: Option, + score: Option, + severity: Option, + vulnerabilities: Vec, + } + #[derive(Serialize)] + struct Vulnerability { + identifier: String, + title: Option, + description: Option, + #[serde(with = "time::serde::rfc3339::option")] + released: Option, + } + + let vulnerabilities = item + .vulnerabilities + .iter() + .map(|v| Vulnerability { + identifier: v.head.head.identifier.clone(), + title: v.head.head.title.clone(), + description: v.head.head.description.clone(), + released: v.head.head.released, + }) + .collect(); + + tools::to_json(&Item { + uuid: item.head.uuid, + + identifier: item.head.identifier.clone(), + issuer: item.head.issuer.clone().map(|v| v.head.name), + title: item.head.title.clone(), + score: item.average_score, + severity: item.average_severity.map(|v| v.to_string()), + vulnerabilities, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::service::tools::tests::assert_tool_contains; + use std::rc::Rc; + use test_context::test_context; + use test_log::test; + use trustify_test_context::TrustifyContext; + + #[test_context(TrustifyContext)] + #[test(actix_web::test)] + async fn advisory_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { + crate::advisory::service::test::ingest_and_link_advisory(ctx).await?; + crate::advisory::service::test::ingest_sample_advisory(ctx, "RHSA-2").await?; + + let tool = Rc::new(AdvisoryInfo(AdvisoryService::new(ctx.db.clone()))); + + assert_tool_contains( + tool.clone(), + "RHSA-1", + r#" +{ + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "identifier": "RHSA-1", + "issuer": null, + "title": "RHSA-1", + "score": 9.1, + "severity": "critical", + "vulnerabilities": [ + { + "identifier": "CVE-123", + "title": null, + "description": null, + "released": null + } + ] +} +"#, + ) + .await + } +} diff --git a/modules/fundamental/src/ai/service/tools/cve_info.rs b/modules/fundamental/src/ai/service/tools/cve_info.rs new file mode 100644 index 00000000..5b8d825a --- /dev/null +++ b/modules/fundamental/src/ai/service/tools/cve_info.rs @@ -0,0 +1,179 @@ +use crate::ai::service::tools; +use crate::vulnerability::service::VulnerabilityService; +use anyhow::anyhow; +use async_trait::async_trait; +use langchain_rust::tools::Tool; +use serde::Serialize; +use serde_json::Value; +use std::error::Error; +use std::fmt::Write; +use time::OffsetDateTime; +use trustify_common::db::query::Query; +use trustify_common::purl::Purl; +use trustify_module_ingestor::common::Deprecation; + +pub struct CVEInfo(pub VulnerabilityService); + +#[async_trait] +impl Tool for CVEInfo { + fn name(&self) -> String { + String::from("cve-info") + } + + fn description(&self) -> String { + String::from( + r##" +This tool can be used to get information about a Vulnerability. +The input should be the partial name of the Vulnerability to search for. +When the input is a full CVE ID, the tool will provide information about the vulnerability. +When the input is a partial name, the tool will provide a list of possible matches. +"## + .trim(), + ) + } + + async fn run(&self, input: Value) -> Result> { + let service = &self.0; + + let input = input + .as_str() + .ok_or("Input should be a string")? + .to_string(); + + let item = match service + .fetch_vulnerability(input.as_str(), Deprecation::Ignore, ()) + .await? + { + Some(v) => v, + None => { + // search for possible matches + let results = service + .fetch_vulnerabilities( + Query { + q: input.clone(), + ..Default::default() + }, + Default::default(), + Deprecation::Ignore, + (), + ) + .await?; + + if results.items.is_empty() { + return Err(anyhow!("I don't know").into()); + } + + // let the caller know what the possible matches are + if results.items.len() > 1 { + #[derive(Serialize)] + struct Item { + identifier: String, + name: Option, + } + + let json = tools::paginated_to_json(results, |item| Item { + identifier: item.head.identifier.clone(), + name: item.head.title.clone(), + })?; + return Ok(format!("There are multiple that match:\n\n{}", json)); + } + + // let's show the details for the one that matched. + if let Some(v) = service + .fetch_vulnerability( + results.items[0].head.identifier.as_str(), + Deprecation::Ignore, + (), + ) + .await? + { + v + } else { + return Err(anyhow!("I don't know").into()); + } + } + }; + + #[derive(Serialize)] + struct Item { + title: Option, + description: Option, + severity: Option, + score: Option, + #[serde(with = "time::serde::rfc3339::option")] + released: Option, + affected_packages: Vec, + } + #[derive(Serialize)] + struct Package { + name: Purl, + version: String, + } + + let affected_packages = item + .advisories + .iter() + .flat_map(|v| { + v.purls + .get("affected") + .into_iter() + .flatten() + .map(|v| Package { + name: v.base_purl.purl.clone(), + version: v.version.clone(), + }) + }) + .collect(); + let json = tools::to_json(&Item { + title: item.head.title.clone(), + description: item.head.description.clone(), + severity: item.average_score, + score: item.average_score, + released: item.head.released, + affected_packages, + })?; + + let mut result = "".to_string(); + if item.head.identifier != input { + writeln!(result, "There is one match, but it had a different identifier. Inform the user that that you are providing information on: {}\n", item.head.identifier)?; + } + writeln!(result, "{}", json)?; + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::service::test::ingest_fixtures; + use crate::ai::service::tools::tests::assert_tool_contains; + use std::rc::Rc; + use test_context::test_context; + use test_log::test; + use trustify_test_context::TrustifyContext; + + #[test_context(TrustifyContext)] + #[test(actix_web::test)] + async fn cve_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { + ingest_fixtures(ctx).await?; + let tool = Rc::new(CVEInfo(VulnerabilityService::new(ctx.db.clone()))); + assert_tool_contains( + tool.clone(), + "CVE-2021-32714", + r#" +{ + "title": "Integer Overflow in Chunked Transfer-Encoding", + "description": "hyper is an HTTP library for Rust. In versions prior to 0.14.10, hyper's HTTP server and client code had a flaw that could trigger an integer overflow when decoding chunk sizes that are too big. This allows possible data loss, or if combined with an upstream HTTP proxy that allows chunk sizes larger than hyper does, can result in \"request smuggling\" or \"desync attacks.\" The vulnerability is patched in version 0.14.10. Two possible workarounds exist. One may reject requests manually that contain a `Transfer-Encoding` header or ensure any upstream proxy rejects `Transfer-Encoding` chunk sizes greater than what fits in 64-bit unsigned integers.", + "severity": 9.1, + "score": 9.1, + "released": null, + "affected_packages": [ + { + "name": "pkg://cargo/hyper", + "version": "[0.0.0-0,0.14.10)" + } + ] +} +"#).await + } +} diff --git a/modules/fundamental/src/ai/service/tools/logger.rs b/modules/fundamental/src/ai/service/tools/logger.rs new file mode 100644 index 00000000..b89118c2 --- /dev/null +++ b/modules/fundamental/src/ai/service/tools/logger.rs @@ -0,0 +1,43 @@ +use async_trait::async_trait; +use langchain_rust::tools::Tool; +use serde_json::Value; +use std::error::Error; + +pub struct ToolLogger(pub T); + +#[async_trait] +impl Tool for ToolLogger { + fn name(&self) -> String { + self.0.name() + } + + fn description(&self) -> String { + self.0.description() + } + + fn parameters(&self) -> Value { + self.0.parameters() + } + + async fn call(&self, input: &str) -> Result> { + log::info!(" tool call: {}, input: {}", self.name(), input); + let result = self.0.call(input).await; + match &result { + Ok(result) => { + log::info!(" ok: {}", result); + } + Err(err) => { + log::info!(" err: {}", err); + } + } + result + } + + async fn run(&self, input: Value) -> Result> { + self.0.run(input).await + } + + async fn parse_input(&self, input: &str) -> Value { + self.0.parse_input(input).await + } +} diff --git a/modules/fundamental/src/ai/service/tools/mod.rs b/modules/fundamental/src/ai/service/tools/mod.rs new file mode 100644 index 00000000..9d15787d --- /dev/null +++ b/modules/fundamental/src/ai/service/tools/mod.rs @@ -0,0 +1,90 @@ +use crate::advisory::service::AdvisoryService; +use crate::ai::service::tools::advisory_info::AdvisoryInfo; +use crate::ai::service::tools::cve_info::CVEInfo; +use crate::ai::service::tools::logger::ToolLogger; +use crate::ai::service::tools::package_info::PackageInfo; +use crate::ai::service::tools::product_info::ProductInfo; +use crate::ai::service::tools::sbom_info::SbomInfo; +use crate::product::service::ProductService; +use crate::purl::service::PurlService; +use crate::sbom::service::SbomService; +use crate::vulnerability::service::VulnerabilityService; +use langchain_rust::tools::Tool; +use serde::Serialize; +use std::error::Error; +use std::sync::Arc; +use trustify_common::db::Database; +use trustify_common::model::PaginatedResults; + +pub mod advisory_info; +pub mod cve_info; +pub mod logger; +pub mod package_info; +pub mod product_info; +pub mod sbom_info; + +pub fn new(db: Database) -> Vec> { + vec![ + Arc::new(ToolLogger(ProductInfo(ProductService::new(db.clone())))), + Arc::new(ToolLogger(CVEInfo(VulnerabilityService::new(db.clone())))), + Arc::new(ToolLogger(AdvisoryInfo(AdvisoryService::new(db.clone())))), + Arc::new(ToolLogger(PackageInfo(PurlService::new(db.clone())))), + Arc::new(ToolLogger(SbomInfo(SbomService::new(db.clone())))), + ] +} + +pub fn to_json(value: &T) -> Result> +where + T: Serialize, +{ + #[cfg(test)] + { + serde_json::to_string_pretty(&value).map_err(|e| e.into()) + } + + #[cfg(not(test))] + { + serde_json::to_string(&value).map_err(|e| e.into()) + } +} + +pub fn paginated_to_json( + p: PaginatedResults, + f: fn(&A) -> T, +) -> Result> +where + T: Serialize, +{ + to_json(&PaginatedResults { + items: p.items.iter().map(f).collect(), + total: p.total, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::service::test::sanitize_uuid; + use langchain_rust::tools::Tool; + use serde_json::Value; + use std::rc::Rc; + + pub fn cleanup_tool_result(s: Result>) -> String { + sanitize_uuid(s.unwrap().trim().to_string()) + } + + pub async fn assert_tool_contains( + tool: Rc, + input: &str, + expected: &str, + ) -> Result<(), anyhow::Error> { + let actual = cleanup_tool_result(tool.run(Value::String(input.to_string())).await); + assert!( + actual.contains(expected.trim()), + "actual:\n{}\nexpected:\n{}\n", + actual, + expected + ); + Ok(()) + } +} diff --git a/modules/fundamental/src/ai/service/tools/package_info.rs b/modules/fundamental/src/ai/service/tools/package_info.rs new file mode 100644 index 00000000..68b6d860 --- /dev/null +++ b/modules/fundamental/src/ai/service/tools/package_info.rs @@ -0,0 +1,283 @@ +use crate::ai::service::tools; +use crate::purl::service::PurlService; +use anyhow::anyhow; +use async_trait::async_trait; +use langchain_rust::tools::Tool; +use serde::Serialize; +use serde_json::Value; +use std::error::Error; +use trustify_common::db::query::Query; +use trustify_common::purl::Purl; +use trustify_module_ingestor::common::Deprecation; +use uuid::Uuid; + +pub struct PackageInfo(pub PurlService); + +#[async_trait] +impl Tool for PackageInfo { + fn name(&self) -> String { + String::from("package-info") + } + + fn description(&self) -> String { + String::from( + r##" +This tool can be used to get information about a Package. +The input should be the name of the package, it's Identifier uri or internal UUID. +"## + .trim(), + ) + } + + async fn run(&self, input: Value) -> Result> { + let service = &self.0; + + let input = input + .as_str() + .ok_or("Input should be a string")? + .to_string(); + + // Try lookup as a PURL + let mut purl_details = match Purl::try_from(input.clone()) { + Err(_) => None, + Ok(purl) => service.purl_by_purl(&purl, Deprecation::Ignore, ()).await?, + }; + + // Try lookup as a UUID + if purl_details.is_none() { + purl_details = match Uuid::parse_str(input.as_str()) { + Err(_) => None, + Ok(uuid) => service.purl_by_uuid(&uuid, Deprecation::Ignore, ()).await?, + }; + } + + // Fallback to search + if purl_details.is_none() { + // try to search for possible matches + let results = service + .purls( + Query { + q: input, + ..Default::default() + }, + Default::default(), + (), + ) + .await?; + + purl_details = match results.items.len() { + 0 => None, + 1 => { + service + .purl_by_uuid(&results.items[0].head.uuid, Deprecation::Ignore, ()) + .await? + } + _ => { + #[derive(Serialize)] + struct Item { + identifier: Purl, + uuid: Uuid, + name: String, + version: Option, + } + + let json = tools::paginated_to_json(results, |item| Item { + identifier: item.head.purl.clone(), + uuid: item.head.uuid, + name: item.head.purl.name.clone(), + version: item.head.purl.version.clone(), + })?; + return Ok(format!("There are multiple that match:\n\n{}", json)); + } + }; + } + + let item = match purl_details { + Some(v) => v, + None => return Err(anyhow!("I don't know").into()), + }; + + #[derive(Serialize)] + struct Item { + identifier: Purl, + uuid: Uuid, + name: String, + version: Option, + advisories: Vec, + licenses: Vec, + } + + #[derive(Serialize)] + struct Advisory { + uuid: Uuid, + identifier: String, + issuer: Option, + vulnerabilities: Vec, + } + + #[derive(Serialize)] + struct Vulnerability { + identifier: String, + title: Option, + status: String, + } + + tools::to_json(&Item { + identifier: item.head.purl.clone(), + uuid: item.head.uuid, + name: item.head.purl.name.clone(), + version: item.head.purl.version.clone(), + + advisories: item + .advisories + .iter() + .map(|advisory| Advisory { + uuid: advisory.head.uuid, + identifier: advisory.head.identifier.clone(), + issuer: advisory.head.issuer.clone().map(|v| v.head.name.clone()), + vulnerabilities: advisory + .status + .iter() + .map(|status| Vulnerability { + identifier: status.vulnerability.identifier.clone(), + title: status.vulnerability.title.clone(), + status: status.status.clone(), + }) + .collect(), + }) + .collect(), + + licenses: item + .licenses + .iter() + .flat_map(|v| v.licenses.iter()) + .cloned() + .collect(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::service::tools::tests::assert_tool_contains; + use std::rc::Rc; + use test_context::test_context; + use test_log::test; + use trustify_test_context::TrustifyContext; + + #[test_context(TrustifyContext)] + #[test(actix_web::test)] + async fn package_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { + ctx.ingest_document("ubi9-9.2-755.1697625012.json").await?; + ctx.ingest_document("quarkus-bom-2.13.8.Final-redhat-00004.json") + .await?; + + let tool = Rc::new(PackageInfo(PurlService::new(ctx.db.clone()))); + + assert_tool_contains( + tool.clone(), + "pkg:rpm/redhat/libsepol@3.5-1.el9?arch=s390x", + r#" +{ + "identifier": "pkg://rpm/redhat/libsepol@3.5-1.el9?arch=ppc64le", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "libsepol", + "version": "3.5-1.el9", + "advisories": [], + "licenses": [ + "LGPLV2+" + ] +} +"#, + ) + .await?; + + assert_tool_contains( + tool.clone(), + "1ca731c3-9596-534c-98eb-8dcc6ff7fef9", + r#" +{ + "identifier": "pkg://rpm/redhat/libsepol@3.5-1.el9?arch=ppc64le", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "libsepol", + "version": "3.5-1.el9", + "advisories": [], + "licenses": [ + "LGPLV2+" + ] +} +"#, + ) + .await?; + + assert_tool_contains( + tool.clone(), + "pkg:maven/org.jboss.logging/commons-logging-jboss-logging@1.0.0.Final-redhat-1?repository_url=https://maven.repository.redhat.com/ga/&type=jar", + r#" +{ + "identifier": "pkg://maven/org.jboss.logging/commons-logging-jboss-logging@1.0.0.Final-redhat-1?repository_url=https://maven.repository.redhat.com/ga/&type=jar", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "commons-logging-jboss-logging", + "version": "1.0.0.Final-redhat-1", + "advisories": [], + "licenses": [ + "APACHE-2.0" + ] +} +"#).await?; + + assert_tool_contains( + tool.clone(), + "commons-logging-jboss-logging", + r#" +{ + "identifier": "pkg://maven/org.jboss.logging/commons-logging-jboss-logging@1.0.0.Final-redhat-1?repository_url=https://maven.repository.redhat.com/ga/&type=jar", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "commons-logging-jboss-logging", + "version": "1.0.0.Final-redhat-1", + "advisories": [], + "licenses": [ + "APACHE-2.0" + ] +} +"#).await?; + + assert_tool_contains( + tool.clone(), + "quarkus-resteasy-reactive-json", + r#" +There are multiple that match: + +{ + "items": [ + { + "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb-common@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "quarkus-resteasy-reactive-jsonb-common", + "version": "2.13.8.Final-redhat-00004" + }, + { + "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "quarkus-resteasy-reactive-jsonb", + "version": "2.13.8.Final-redhat-00004" + }, + { + "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb-common-deployment@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "quarkus-resteasy-reactive-jsonb-common-deployment", + "version": "2.13.8.Final-redhat-00004" + }, + { + "identifier": "pkg://maven/io.quarkus/quarkus-resteasy-reactive-jsonb-deployment@2.13.8.Final-redhat-00004?repository_url=https://maven.repository.redhat.com/ga/&type=jar", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "name": "quarkus-resteasy-reactive-jsonb-deployment", + "version": "2.13.8.Final-redhat-00004" + } + ], + "total": 4 +} +"#).await + } +} diff --git a/modules/fundamental/src/ai/service/tools/product_info.rs b/modules/fundamental/src/ai/service/tools/product_info.rs new file mode 100644 index 00000000..edbe0c3d --- /dev/null +++ b/modules/fundamental/src/ai/service/tools/product_info.rs @@ -0,0 +1,106 @@ +use crate::ai::service::tools; +use crate::product::service::ProductService; +use anyhow::anyhow; +use async_trait::async_trait; +use langchain_rust::tools::Tool; +use serde::Serialize; +use serde_json::Value; +use std::error::Error; +use trustify_common::db::query::Query; +use uuid::Uuid; + +pub struct ProductInfo(pub ProductService); + +#[async_trait] +impl Tool for ProductInfo { + fn name(&self) -> String { + String::from("product-info") + } + + fn description(&self) -> String { + String::from( + r##" +This tool can be used to get information about a product. +The input should be the name of the product to search for. +When the input is a full name, the tool will provide information about the product. +When the input is a partial name, the tool will provide a list of possible matches. +"## + .trim(), + ) + } + + async fn run(&self, input: Value) -> Result> { + let service = &self.0; + let input = input + .as_str() + .ok_or("Input should be a string")? + .to_string(); + + let results = service + .fetch_products( + Query { + q: input, + ..Default::default() + }, + Default::default(), + (), + ) + .await?; + + if results.items.is_empty() { + return Err(anyhow!("I don't know").into()); + } + + #[derive(Serialize)] + struct Product { + name: String, + uuid: Uuid, + vendor: Option, + versions: Vec, + } + tools::paginated_to_json(results, |item| Product { + name: item.head.name.clone(), + uuid: item.head.id, + vendor: item.vendor.clone().map(|v| v.head.name), + versions: item.versions.iter().map(|v| v.version.clone()).collect(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::service::test::ingest_fixtures; + use crate::ai::service::tools::tests::assert_tool_contains; + use std::rc::Rc; + use test_context::test_context; + use test_log::test; + use trustify_test_context::TrustifyContext; + + #[test_context(TrustifyContext)] + #[test(actix_web::test)] + async fn product_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { + ingest_fixtures(ctx).await?; + let tool = Rc::new(ProductInfo(ProductService::new(ctx.db.clone()))); + assert_tool_contains( + tool.clone(), + "Trusted Profile Analyzer", + r#" +{ + "items": [ + { + "name": "Trusted Profile Analyzer", + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "vendor": "Red Hat", + "versions": [ + "37.17.9" + ] + } + ], + "total": 1 +} +"#, + ) + .await + } +} diff --git a/modules/fundamental/src/ai/service/tools/sbom_info.rs b/modules/fundamental/src/ai/service/tools/sbom_info.rs new file mode 100644 index 00000000..b2b79d78 --- /dev/null +++ b/modules/fundamental/src/ai/service/tools/sbom_info.rs @@ -0,0 +1,196 @@ +use crate::ai::service::tools; +use crate::sbom::service::SbomService; +use anyhow::anyhow; +use async_trait::async_trait; +use itertools::Itertools; +use langchain_rust::tools::Tool; +use serde::Serialize; +use serde_json::Value; +use std::error::Error; +use std::str::FromStr; +use time::OffsetDateTime; +use trustify_common::db::query::Query; +use trustify_common::id::Id; +use uuid::Uuid; + +pub struct SbomInfo(pub SbomService); + +#[async_trait] +impl Tool for SbomInfo { + fn name(&self) -> String { + String::from("sbom-info") + } + + fn description(&self) -> String { + String::from( + r##" +This tool can be used to get information about an SBOM. +The input should be the SBOM Identifier. +"## + .trim(), + ) + } + + async fn run(&self, input: Value) -> Result> { + let service = &self.0; + + let input = input + .as_str() + .ok_or("Input should be a string")? + .to_string(); + + // Try lookup as a UUID + let mut sbom_details = match Id::from_str(input.as_str()) { + Err(_) => None, + Ok(id) => service.fetch_sbom_details(id, ()).await?, + }; + + // Fallback to search + if sbom_details.is_none() { + // try to search for possible matches + let results = service + .fetch_sboms( + Query { + q: input, + ..Default::default() + }, + Default::default(), + (), + (), + ) + .await?; + + sbom_details = match results.items.len() { + 0 => None, + 1 => { + service + .fetch_sbom_details(Id::Uuid(results.items[0].head.id), ()) + .await? + } + _ => { + #[derive(Serialize)] + struct Item { + uuid: Uuid, + source_document_sha256: String, + name: String, + #[serde(with = "time::serde::rfc3339::option")] + published: Option, + } + + let json = tools::paginated_to_json(results, |item| Item { + uuid: item.head.id, + source_document_sha256: item + .source_document + .as_ref() + .map(|v| v.sha256.clone()) + .unwrap_or_default(), + name: item.head.name.clone(), + published: item.head.published, + })?; + return Ok(format!("There are multiple that match:\n\n{}", json)); + } + }; + } + + let item = match sbom_details { + Some(v) => v, + None => return Err(anyhow!("I don't know").into()), + }; + + #[derive(Serialize)] + struct Item { + uuid: Uuid, + source_document_sha256: String, + name: String, + #[serde(with = "time::serde::rfc3339::option")] + published: Option, + authors: Vec, + labels: Vec<(String, String)>, + advisories: Vec, + } + + #[derive(Serialize)] + struct Advisory { + uuid: Uuid, + identifier: String, + issuer: Option, + } + + let mut labels = item.summary.head.labels.iter().collect_vec(); + labels.sort_by(|a, b| a.0.cmp(b.0)); + + tools::to_json(&Item { + uuid: item.summary.head.id, + source_document_sha256: item + .summary + .source_document + .as_ref() + .map(|v| v.sha256.clone()) + .unwrap_or_default(), + name: item.summary.head.name.clone(), + published: item.summary.head.published, + authors: item.summary.head.authors.clone(), + labels: labels + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(), + advisories: item + .advisories + .iter() + .map(|advisory| Advisory { + uuid: advisory.head.uuid, + identifier: advisory.head.identifier.clone(), + issuer: advisory.head.issuer.clone().map(|v| v.head.name.clone()), + }) + .collect(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ai::service::tools::tests::assert_tool_contains; + use std::rc::Rc; + use test_context::test_context; + use test_log::test; + use trustify_test_context::TrustifyContext; + + #[test_context(TrustifyContext)] + #[test(actix_web::test)] + async fn sbom_info_tool(ctx: &TrustifyContext) -> Result<(), anyhow::Error> { + ctx.ingest_document("ubi9-9.2-755.1697625012.json").await?; + ctx.ingest_document("quarkus-bom-2.13.8.Final-redhat-00004.json") + .await?; + + let tool = Rc::new(SbomInfo(SbomService::new(ctx.db.clone()))); + + assert_tool_contains( + tool.clone(), + "quarkus", + r#" +{ + "uuid": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "source_document_sha256": "sha256:5a370574a991aa42f7ecc5b7d88754b258f81c230a73bea247c0a6fcc6f608ab", + "name": "quarkus-bom", + "published": "2023-11-13T00:10:00Z", + "authors": [ + "Organization: Red Hat Product Security (secalert@redhat.com)" + ], + "labels": [ + [ + "source", + "TrustifyContext" + ], + [ + "type", + "spdx" + ] + ], + "advisories": [] +} +"#, + ) + .await + } +} diff --git a/modules/ingestor/Cargo.toml b/modules/ingestor/Cargo.toml index 491c53f4..164dbb9c 100644 --- a/modules/ingestor/Cargo.toml +++ b/modules/ingestor/Cargo.toml @@ -12,6 +12,7 @@ trustify-entity = { workspace = true } trustify-module-storage = { workspace = true } trustify-module-analysis = { workspace = true } + actix-web = { workspace = true } anyhow = { workspace = true } bytes = { workspace = true } @@ -30,6 +31,7 @@ packageurl = { workspace = true } parking_lot = { workspace = true } quick-xml = { workspace = true } roxmltree = { workspace = true } +sbomsleuth = { workspace = true } sbom-walker = { workspace = true } sea-orm = { workspace = true } sea-query = { workspace = true } diff --git a/modules/ingestor/src/graph/advisory/mod.rs b/modules/ingestor/src/graph/advisory/mod.rs index e3e19695..bc7857d1 100644 --- a/modules/ingestor/src/graph/advisory/mod.rs +++ b/modules/ingestor/src/graph/advisory/mod.rs @@ -11,6 +11,7 @@ use sea_orm::{ }; use sea_query::{Condition, JoinType, OnConflict}; use semver::Version; +use serde_json::json; use std::fmt::{Debug, Formatter}; use time::OffsetDateTime; use tracing::instrument; @@ -141,6 +142,7 @@ impl Graph { sha256: Set(sha256), sha384: Set(digests.sha384.encode_hex()), sha512: Set(digests.sha512.encode_hex()), + meta: Set(json!({})), // Set to an empty JSON object }; let doc = doc_model.insert(&self.connection(&tx)).await?; diff --git a/modules/ingestor/src/graph/sbom/mod.rs b/modules/ingestor/src/graph/sbom/mod.rs index cc5744d9..1d43085e 100644 --- a/modules/ingestor/src/graph/sbom/mod.rs +++ b/modules/ingestor/src/graph/sbom/mod.rs @@ -21,6 +21,7 @@ use crate::{ use cpe::uri::OwnedUri; use entity::{product, product_version}; use hex::ToHex; +use sbomsleuth::report::Report; use sea_orm::{ prelude::Uuid, ActiveModelTrait, ColumnTrait, EntityTrait, ModelTrait, QueryFilter, QuerySelect, QueryTrait, RelationTrait, Select, SelectColumns, Set, @@ -28,6 +29,7 @@ use sea_orm::{ use sea_query::{ extension::postgres::PgExpr, Alias, Condition, Expr, Func, JoinType, Query, SimpleExpr, }; +use serde_json::json; use std::{ fmt::{Debug, Formatter}, iter, @@ -96,6 +98,68 @@ impl Graph { .map(|sbom| SbomContext::new(self, sbom))) } + #[instrument(skip(tx, info), err(level=tracing::Level::INFO))] + pub async fn ingest_sbom_with_report>( + &self, + report: &Report, + labels: impl Into + Debug, + digests: &Digests, + document_id: &str, + info: impl Into, + tx: TX, + ) -> Result { + let sha256 = digests.sha256.encode_hex::(); + + if let Some(found) = self.get_sbom_by_digest(&sha256, &tx).await? { + return Ok(found); + } + + let SbomInformation { + node_id, + name, + published, + authors, + } = info.into(); + + let connection = self.db.connection(&tx); + + let sbom_id = Uuid::now_v7(); + + let doc_model = source_document::ActiveModel { + id: Default::default(), + sha256: Set(sha256), + sha384: Set(digests.sha384.encode_hex()), + sha512: Set(digests.sha512.encode_hex()), + meta: Set(serde_json::to_string(report).unwrap().parse()?), + }; + + let doc = doc_model.insert(&connection).await?; + + let model = sbom::ActiveModel { + sbom_id: Set(sbom_id), + node_id: Set(node_id.clone()), + + document_id: Set(document_id.to_string()), + + published: Set(published), + authors: Set(authors), + + source_document_id: Set(Some(doc.id)), + labels: Set(labels.into()), + }; + + let node_model = sbom_node::ActiveModel { + sbom_id: Set(sbom_id), + node_id: Set(node_id), + name: Set(name), + }; + + let result = model.insert(&connection).await?; + node_model.insert(&connection).await?; + + Ok(SbomContext::new(self, result)) + } + #[instrument(skip(tx, info), err(level=tracing::Level::INFO))] pub async fn ingest_sbom>( &self, @@ -127,6 +191,7 @@ impl Graph { sha256: Set(sha256), sha384: Set(digests.sha384.encode_hex()), sha512: Set(digests.sha512.encode_hex()), + meta: Set(json!({})), // Set to an empty JSON object }; let doc = doc_model.insert(&connection).await?; diff --git a/modules/ingestor/src/service/sbom/spdx.rs b/modules/ingestor/src/service/sbom/spdx.rs index 7e52f953..754d3585 100644 --- a/modules/ingestor/src/service/sbom/spdx.rs +++ b/modules/ingestor/src/service/sbom/spdx.rs @@ -6,6 +6,7 @@ use crate::{ model::IngestResult, service::{Error, Warnings}, }; +use sbomsleuth::license::Licenses; use serde_json::Value; use tracing::instrument; use trustify_common::{hashing::Digests, id::Id}; @@ -31,6 +32,14 @@ impl<'g> SpdxLoader<'g> { let (spdx, _) = parse_spdx(&warnings, json)?; + let license_instance = Licenses::default(); + let report_instance = sbomsleuth::report::Report { + licenses: license_instance.run_with_spdx(spdx.clone()).await.unwrap(), + ..Default::default() + }; + let report = report_instance.run_with_spdx(spdx.clone()).unwrap(); + // println!("report: {:?}", report); + log::info!( "Storing: {}", spdx.document_creation_information.document_name @@ -47,7 +56,14 @@ impl<'g> SpdxLoader<'g> { let sbom = self .graph - .ingest_sbom(labels, digests, &document_id, spdx::Information(&spdx), &tx) + .ingest_sbom_with_report( + &report, + labels, + digests, + &document_id, + spdx::Information(&spdx), + &tx, + ) .await?; sbom.ingest_spdx(spdx, &warnings, &tx).await?;