From 19dabaf90439eba5325287d095f328ceb76a9abe Mon Sep 17 00:00:00 2001 From: Harsh Mahajan <127186841+HarshMN2345@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:17:24 +0530 Subject: [PATCH] Update grpc.rs --- src/core/blueprint/operators/grpc.rs | 224 ++++++++++++++------------- 1 file changed, 113 insertions(+), 111 deletions(-) diff --git a/src/core/blueprint/operators/grpc.rs b/src/core/blueprint/operators/grpc.rs index cc4b85fb1a..dfb094f14a 100644 --- a/src/core/blueprint/operators/grpc.rs +++ b/src/core/blueprint/operators/grpc.rs @@ -64,81 +64,69 @@ fn validate_schema( field_schema: FieldSchema, operation: &ProtobufOperation, name: &str, -) -> Valid<(), BlueprintError> { +) -> Valid<(), String> { let input_type = &operation.input_type; let output_type = &operation.output_type; - let input_type = match JsonSchema::try_from(input_type) { - Ok(input_schema) => Valid::succeed(input_schema), - Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)), - }; + Valid::from(JsonSchema::try_from(input_type)) + .zip(Valid::from(JsonSchema::try_from(output_type))) + .and_then(|(input_schema, output_schema)| { + let fields = &field_schema.field; + let args = &field_schema.args; - let output_type = match JsonSchema::try_from(output_type) { - Ok(output_type) => Valid::succeed(output_type), - Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)), - }; + // Treat repeated message types as optional in input schema + let normalized_input_schema = normalize_repeated_types(&input_schema); - input_type - .zip(output_type) - .and_then(|(_input_schema, sub_type)| { - // TODO: add validation for input schema - should compare result grpc.body to - // schema - let super_type = field_schema.field; - // TODO: all of the fields in protobuf are optional actually - // and if we want to mark some fields as required in GraphQL - // JsonSchema won't match and the validation will fail - match sub_type.is_a(&super_type, name).to_result() { - Ok(res) => Valid::succeed(res), - Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)), - } + // Validate input schema against args + args.compare(&normalized_input_schema, &format!("Input validation failed for {}", name))?; + + // Validate output schema against fields + fields.compare(&output_schema, &format!("Output validation failed for {}", name)) }) } - +fn normalize_repeated_types(schema: &JsonSchema) -> JsonSchema { + match schema { + JsonSchema::Arr(inner_schema) => { + // Treat repeated types (arrays) as optional + JsonSchema::Optional(Box::new(inner_schema.clone())) + } + JsonSchema::Object(fields) => { + let normalized_fields = fields + .iter() + .map(|(key, value)| (key.clone(), normalize_repeated_types(value))) + .collect(); + JsonSchema::Object(normalized_fields) + } + _ => schema.clone(), + } +} fn validate_group_by( field_schema: &FieldSchema, operation: &ProtobufOperation, group_by: Vec, -) -> Valid<(), BlueprintError> { +) -> Valid<(), String> { let input_type = &operation.input_type; let output_type = &operation.output_type; - let mut field_descriptor: Result> = None - .ok_or(ValidationError::new(BlueprintError::FieldNotFound( - group_by[0].clone(), - ))); - for item in group_by.iter().take(&group_by.len() - 1) { - field_descriptor = - output_type - .get_field_by_json_name(item.as_str()) - .ok_or(ValidationError::new(BlueprintError::FieldNotFound( - item.clone(), - ))); - } - let output_type = field_descriptor - .and_then(|f| JsonSchema::try_from(&f).map_err(BlueprintError::from_validation_string)); - let json_schema = match JsonSchema::try_from(input_type) { - Ok(schema) => Valid::succeed(schema), - Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)), - }; + let input_schema = JsonSchema::try_from(input_type)?; + let output_schema = JsonSchema::try_from(output_type)?; - json_schema - .zip(Valid::from(output_type)) - .and_then(|(_input_schema, output_schema)| { - // TODO: add validation for input schema - should compare result grpc.body to - // schema considering repeated message type - let fields = &field_schema.field; - // we're treating List types for gRPC as optional. - let fields = JsonSchema::Opt(Box::new(JsonSchema::Arr(Box::new(fields.to_owned())))); - match fields - .is_a(&output_schema, group_by[0].as_str()) - .to_result() - { - Ok(res) => Valid::succeed(res), - Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)), - } - }) + let normalized_input_schema = normalize_repeated_types(&input_schema); + + let fields = JsonSchema::Arr(Box::new(field_schema.field.to_owned())); + let args = JsonSchema::Arr(Box::new(field_schema.args.to_owned())); + + args.compare( + &normalized_input_schema, + &format!("Input validation failed for group_by {:?}", group_by), + )?; + fields.compare( + &output_schema, + &format!("Output validation failed for group_by {:?}", group_by), + ) } + pub struct CompileGrpc<'a> { pub config_module: &'a ConfigModule, pub operation_type: &'a GraphQLOperationType, @@ -187,63 +175,35 @@ pub fn compile_grpc(inputs: CompileGrpc) -> Valid { let validate_with_schema = inputs.validate_with_schema; let dedupe = grpc.dedupe.unwrap_or_default(); - Valid::from(GrpcMethod::try_from(grpc.method.as_str())) - .and_then(|method| { - let file_descriptor_set = config_module.extensions().get_file_descriptor_set(); + Valid::from(GrpcMethod::try_from(grpc.method.as_str())) + .and_then(|method| { + let file_descriptor_set = config_module.extensions().get_file_descriptor_set(); - if file_descriptor_set.file.is_empty() { - return Valid::fail(BlueprintError::ProtobufFilesNotSpecifiedInConfig); - } + if file_descriptor_set.file.is_empty() { + return Valid::fail("Protobuf files were not specified in the config".to_string()); + } - match to_operation(&method, file_descriptor_set) - .fuse(to_url(grpc, &method)) - .fuse(helpers::headers::to_mustache_headers(&grpc.headers)) - .fuse(helpers::body::to_body(grpc.body.as_ref())) - .to_result() - { - Ok(data) => Valid::succeed(data), - Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)), - } - }) - .and_then(|(operation, url, headers, body)| { - let validation = if validate_with_schema { - let field_schema = json_schema_from_field(config_module, field); - if grpc.batch_key.is_empty() { - validate_schema(field_schema, &operation, field.type_of.name()).unit() - } else { - validate_group_by(&field_schema, &operation, grpc.batch_key.clone()).unit() - } - } else { - Valid::succeed(()) - }; - validation.map(|_| (url, headers, operation, body)) - }) - .map(|(url, headers, operation, body)| { - let req_template = RequestTemplate { - url, - headers, - operation, - body, - operation_type: operation_type.clone(), - }; - let on_response = grpc.on_response_body.clone(); - let hook = WorkerHooks::try_new(None, on_response).ok(); - - let io = if !grpc.batch_key.is_empty() { - IR::IO(IO::Grpc { - req_template, - group_by: Some(GroupBy::new(grpc.batch_key.clone(), None)), - dl_id: None, - dedupe, - hook, - }) + to_operation(&method, file_descriptor_set) + .fuse(to_url(grpc, &method, config_module)) + .fuse(helpers::headers::to_mustache_headers(&grpc.headers)) + .fuse(helpers::body::to_body(grpc.body.as_ref())) + .into() + }) + .and_then(|(operation, url, headers, body)| { + let validation = if validate_with_schema { + let field_schema = json_schema_from_field(config_module, field); + if grpc.batch_key.is_empty() { + // Add input validation with repeated type normalization + validate_schema(field_schema, &operation, field.name()).unit() } else { - IR::IO(IO::Grpc { req_template, group_by: None, dl_id: None, dedupe, hook }) - }; + validate_group_by(&field_schema, &operation, grpc.batch_key.clone()).unit() + } + } else { + Valid::succeed(()) + }; + validation.map(|_| (url, headers, operation, body)) + }) - (io, &grpc.select) - }) - .and_then(apply_select) } #[cfg(test)] @@ -254,6 +214,22 @@ mod tests { use super::GrpcMethod; use crate::core::blueprint::BlueprintError; + #[test] +fn validate_repeated_types_as_optional() { + let operation = ProtobufOperation { + input_type: "RepeatedInputType".to_string(), + output_type: "ValidOutputType".to_string(), + }; + + let field_schema = FieldSchema { + args: JsonSchema::Arr(Box::new(JsonSchema::String)), + field: JsonSchema::Object(HashMap::new()), + }; + + let result = validate_schema(field_schema, &operation, "test_operation"); + assert!(result.is_ok()); +} + #[test] fn try_from_grpc_method() { @@ -268,6 +244,32 @@ mod tests { assert_eq!(method1.service, "ServiceName"); assert_eq!(method1.name, "MethodName"); } + #[test] +fn grpc_repeated_types_validation_integration() { + let config_module = MockConfigModule::new(); + let operation_type = GraphQLOperationType::Query; + let field = Field::new("test_field", "RepeatedInputType"); + + let grpc = Grpc { + method: "package.Service.Method".to_string(), + base_url: Some("http://localhost:5000".to_string()), + headers: None, + body: Some(vec!["repeated_field"]), + batch_key: vec![], + }; + + let compile_inputs = CompileGrpc { + config_module: &config_module, + operation_type: &operation_type, + field: &field, + grpc: &grpc, + validate_with_schema: true, + }; + + let result = compile_grpc(compile_inputs); + assert!(result.is_ok()); +} + #[test] fn try_from_grpc_method_invalid() {