diff --git a/src/generator/from_proto.rs b/src/generator/from_proto.rs index 486ea281e3..8846a095f0 100644 --- a/src/generator/from_proto.rs +++ b/src/generator/from_proto.rs @@ -50,7 +50,8 @@ impl Context { .value .iter() .map(|v| { - GraphQLType::parse_enum_variant(v.name()) + GraphQLType::new(v.name()) + .as_enum_variant() .unwrap() .to_string() }) @@ -58,7 +59,7 @@ impl Context { ty.variants = Some(variants); - let type_name = GraphQLType::parse_enum(enum_name).unwrap().to_string(); + let type_name = GraphQLType::new(enum_name).as_enum().unwrap().to_string(); self = self.insert_type(type_name, ty); } self @@ -77,16 +78,18 @@ impl Context { self = self.append_enums(&message.enum_type); self = self.append_msg_type(&message.nested_type); - let msg_type = GraphQLType::parse_object_type(&msg_name).unwrap(); - let msg_type = msg_type.clone().package(&self.package).unwrap_or(msg_type); + let msg_type = GraphQLType::new(&msg_name) + .package(&self.package) + .as_object_type() + .unwrap(); let mut ty = Type::default(); for field in message.field.iter() { - let field_name = GraphQLType::parse_field(field.name()).unwrap(); - let field_name = field_name - .clone() + let field_name = GraphQLType::new(field.name()) .package(&self.package) - .unwrap_or(field_name.clone()); + .as_field() + .unwrap(); + let mut cfg_field = Field::default(); let label = field.label().as_str_name().to_lowercase(); @@ -103,12 +106,12 @@ impl Context { } else { // for non-primitive types let type_of = convert_ty(field.type_name()); - let type_of = GraphQLType::parse_object_type(&type_of).unwrap(); - let type_of = type_of - .clone() - .package(&self.package) - .unwrap_or(type_of) + let type_of = GraphQLType::new(&type_of) + .package(self.package.as_str()) + .as_object_type() + .unwrap() .to_string(); + cfg_field.type_of = type_of; } @@ -134,22 +137,21 @@ impl Context { for service in services { let service_name = service.name().to_string(); for method in &service.method { - let field_name = GraphQLType::parse_method(method.name()).unwrap(); - let field_name = field_name - .clone() + let field_name = GraphQLType::new(method.name()) .package(&self.package) - .unwrap_or(field_name); + .as_method() + .unwrap(); let mut cfg_field = Field::default(); if let Some(arg_type) = get_input_ty(method.input_type()) { - let key = GraphQLType::parse_field(&arg_type) - .unwrap() + let key = GraphQLType::new(&arg_type) .package(&self.package) + .as_field() .unwrap() .to_string(); - let type_of = GraphQLType::parse_object_type(&arg_type) - .unwrap() + let type_of = GraphQLType::new(&arg_type) .package(&self.package) + .as_object_type() .unwrap() .to_string(); let val = Arg { @@ -167,11 +169,10 @@ impl Context { } let output_ty = get_output_ty(method.output_type()); - let output_ty = GraphQLType::parse_object_type(&output_ty).unwrap(); - let output_ty = output_ty - .clone() + let output_ty = GraphQLType::new(&output_ty) .package(&self.package) - .unwrap_or(output_ty) + .as_object_type() + .unwrap() .to_string(); cfg_field.type_of = output_ty; cfg_field.required = true; diff --git a/src/generator/graphql_type.rs b/src/generator/graphql_type.rs index 2033bba1db..d9d13e8187 100644 --- a/src/generator/graphql_type.rs +++ b/src/generator/graphql_type.rs @@ -6,12 +6,21 @@ static PACKAGE_SEPARATOR: &str = "."; /// A struct to represent the name of a GraphQL type. #[derive(Debug, Clone)] -pub struct GraphQLType { +pub struct GraphQLType(A); + +#[derive(Debug, Clone)] +pub struct Parsed { package: Option, name: String, entity: Entity, } +#[derive(Debug, Clone)] +pub struct Unparsed { + package: Option, + name: String, +} + #[derive(Debug, Clone)] struct Package { path: Vec, @@ -46,55 +55,76 @@ impl Display for Package { } } -impl GraphQLType { - // FIXME: separator should be taken as an input - fn parse(input: &str, convertor: Entity) -> Option { - if input.contains(PACKAGE_SEPARATOR) { - if let Some((package, name)) = input.rsplit_once(PACKAGE_SEPARATOR) { +impl GraphQLType { + pub fn new(input: &str) -> Self { + Self(Unparsed { package: None, name: input.to_string() }) + } + + // TODO: separator should be taken as an input + fn parse(&self, convertor: Entity) -> Option> { + let unparsed = &self.0; + let name = &unparsed.name; + let package = &unparsed.package; + if name.contains(PACKAGE_SEPARATOR) { + if let Some((package, name)) = name.rsplit_once(PACKAGE_SEPARATOR) { let package = Package::parse(package, PACKAGE_SEPARATOR); - Some(Self { package, name: name.to_string(), entity: convertor }) + Some(GraphQLType(Parsed { + package, + name: name.to_string(), + entity: convertor, + })) } else { - println!("Input: {input}"); None } + } else if let Some(package) = package { + Some(GraphQLType(Parsed { + package: Package::parse(package, PACKAGE_SEPARATOR), + name: name.to_string(), + entity: convertor, + })) } else { - Some(Self { package: None, name: input.to_string(), entity: convertor }) + Some(GraphQLType(Parsed { + package: None, + name: name.to_string(), + entity: convertor, + })) } } - pub fn parse_enum(name: &str) -> Option { - Self::parse(name, Entity::Enum) + pub fn as_enum(&self) -> Option> { + self.parse(Entity::Enum) + } + + pub fn as_enum_variant(&self) -> Option> { + self.parse(Entity::EnumVariant) } - pub fn parse_enum_variant(name: &str) -> Option { - Self::parse(name, Entity::EnumVariant) + pub fn as_object_type(&self) -> Option> { + self.parse(Entity::ObjectType) } - pub fn parse_object_type(name: &str) -> Option { - Self::parse(name, Entity::ObjectType) + pub fn as_method(&self) -> Option> { + self.parse(Entity::Method) } - pub fn parse_method(name: &str) -> Option { - Self::parse(name, Entity::Method) + pub fn as_field(&self) -> Option> { + self.parse(Entity::Field) } - pub fn parse_field(name: &str) -> Option { - Self::parse(name, Entity::Field) + pub fn package(mut self, package: &str) -> Self { + self.0.package = Some(package.to_string()); + self } +} +impl GraphQLType { pub fn id(&self) -> String { - if let Some(ref package) = self.package { - format!("{}.{}", package.source(), self.name) + if let Some(ref package) = self.0.package { + format!("{}.{}", package.source(), self.0.name) } else { - self.name.clone() + self.0.name.clone() } } - - pub fn package(mut self, package: &str) -> Option { - let package = Package::parse(package, ".")?; - self.package = Some(package); - Some(self) - } } // FIXME: make it private @@ -109,24 +139,27 @@ enum Entity { Field, } -impl Display for GraphQLType { +impl Display for GraphQLType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.entity { - Entity::EnumVariant => f.write_str(self.name.to_case(Case::ScreamingSnake).as_str())?, - Entity::Field => f.write_str(self.name.to_case(Case::Snake).as_str())?, + let parsed = &self.0; + match parsed.entity { + Entity::EnumVariant => { + f.write_str(parsed.name.to_case(Case::ScreamingSnake).as_str())? + } + Entity::Field => f.write_str(parsed.name.to_case(Case::Snake).as_str())?, Entity::Method => { - if let Some(package) = &self.package { + if let Some(package) = &parsed.package { f.write_str(package.to_string().to_case(Case::Snake).as_str())?; f.write_str(DEFAULT_SEPARATOR)?; }; - f.write_str(self.name.to_case(Case::Snake).as_str())? + f.write_str(parsed.name.to_case(Case::Snake).as_str())? } Entity::Enum | Entity::ObjectType => { - if let Some(package) = &self.package { + if let Some(package) = &parsed.package { f.write_str(package.to_string().to_case(Case::ScreamingSnake).as_str())?; f.write_str(DEFAULT_SEPARATOR)?; }; - f.write_str(self.name.to_case(Case::ScreamingSnake).as_str())? + f.write_str(parsed.name.to_case(Case::ScreamingSnake).as_str())? } }; Ok(()) @@ -219,12 +252,12 @@ mod tests { fn assert_type_names(input: Vec) { for ((entity, package, name), expected) in input { - let mut g = GraphQLType::parse(name, entity).unwrap(); + let mut g = GraphQLType::new(name); if let Some(package) = package { - g = g.clone().package(package).unwrap_or(g); + g = g.clone().package(package); } - let actual = g.to_string(); + let actual = g.parse(entity).unwrap().to_string(); assert_eq!(actual, expected, "Given: {:?}", g); } }