diff --git a/modules/core/src/main/scala/sangria/execution/FieldCollector.scala b/modules/core/src/main/scala/sangria/execution/FieldCollector.scala index e9e261d5..b9a9dd00 100644 --- a/modules/core/src/main/scala/sangria/execution/FieldCollector.scala +++ b/modules/core/src/main/scala/sangria/execution/FieldCollector.scala @@ -176,13 +176,15 @@ class FieldCollector[Ctx, Val]( .map(dir -> _) }) - possibleDirs.collect { case Failure(error) => error }.headOption.map(Failure(_)).getOrElse { - val validDirs = possibleDirs.collect { case Success(v) => v } - val should = validDirs.forall { case (dir, args) => - dir.shouldInclude(DirectiveContext(selection, dir, args)) - } + possibleDirs.collectFirst { case Failure(error) => error } match { + case Some(f) => Failure(f) + case None => + val validDirs = possibleDirs.collect { case Success(v) => v } + val should = validDirs.forall { case (dir, args) => + dir.shouldInclude(DirectiveContext(selection, dir, args)) + } - Success(should) + Success(should) } } @@ -255,9 +257,9 @@ class CollectedFieldsBuilder { } def build = { - val builtFields = firstFields.toVector.zipWithIndex.map { case (f, idx) => + val builtFields = firstFields.iterator.zipWithIndex.map { case (f, idx) => CollectedField(names(idx), f, fields(idx).map(_.toVector)) - } + }.toVector CollectedFields(names.toVector, builtFields) } diff --git a/modules/core/src/main/scala/sangria/execution/Resolver.scala b/modules/core/src/main/scala/sangria/execution/Resolver.scala index acb15537..b8aa3520 100644 --- a/modules/core/src/main/scala/sangria/execution/Resolver.scala +++ b/modules/core/src/main/scala/sangria/execution/Resolver.scala @@ -493,8 +493,12 @@ class Resolver[Ctx]( Result(ErrorRegistry(fieldPath, resolveError(e), fields.head.location), None) } - val deferred = values.collect { case SeqRes(_, d, _) if d != null => d }.toVector - val deferredFut = values.collect { case SeqRes(_, _, d) if d != null => d }.toVector + val deferred = values.iterator.collect { + case SeqRes(_, d, _) if d != null => d + }.toVector + val deferredFut = values.iterator.collect { + case SeqRes(_, _, d) if d != null => d + }.toVector immediatelyResolveDeferred( uc, @@ -826,7 +830,7 @@ class Resolver[Ctx]( val resolved = future .flatMap { vs => - val errors = vs.flatMap(_.errors).toVector + val errors = vs.iterator.flatMap(_.errors).toVector val successfulValues = vs.collect { case SeqFutRes(v, _, _) if v != null => v } val dctx = vs.collect { case SeqFutRes(_, _, d) if d != null => d } @@ -869,8 +873,12 @@ class Resolver[Ctx]( None) } - val deferred = values.collect { case SeqRes(_, d, _) if d != null => d }.toVector - val deferredFut = values.collect { case SeqRes(_, _, d) if d != null => d }.toVector + val deferred = values.iterator.collect { + case SeqRes(_, d, _) if d != null => d + }.toVector + val deferredFut = values.iterator.collect { + case SeqRes(_, _, d) if d != null => d + }.toVector astFields.head -> DeferredResult(Future.successful(deferred) +: deferredFut, resolved) @@ -1101,19 +1109,19 @@ class Resolver[Ctx]( } - val simpleRes = resolvedValues.collect { case (af, r: Result) => af -> r } - val resSoFar = - simpleRes.foldLeft(Result(errors, Some(marshaller.emptyMapNode(fieldsNamesOrdered)))) { - case (res, (astField, other)) => - res.addToMap( - other, - astField.outputName, - isOptional(tpe, astField.name), - path.add(astField, tpe), - astField.location, - res.errors) - } + resolvedValues.iterator + .collect { case (af, r: Result) => af -> r } + .foldLeft(Result(errors, Some(marshaller.emptyMapNode(fieldsNamesOrdered)))) { + case (res, (astField, other)) => + res.addToMap( + other, + astField.outputName, + isOptional(tpe, astField.name), + path.add(astField, tpe), + astField.location, + res.errors) + } val complexRes = resolvedValues.collect { case (af, r: DeferredResult) => af -> r } @@ -1438,7 +1446,7 @@ class Resolver[Ctx]( val beforeAction = mBefore.collect { case (BeforeFieldResult(_, Some(action), _), _, _) => action }.lastOption - val beforeAttachments = mBefore.collect { + val beforeAttachments = mBefore.iterator.collect { case (BeforeFieldResult(_, _, Some(att)), _, _) => att }.toVector val updatedCtx = @@ -1463,7 +1471,7 @@ class Resolver[Ctx]( } def doErrorMiddleware(error: Throwable): Unit = - mError.collect { + mError.foreach { case (BeforeFieldResult(cv, _, _), mv, m: MiddlewareErrorField[Ctx]) => m.fieldError( mv.asInstanceOf[m.QueryVal], @@ -1471,6 +1479,7 @@ class Resolver[Ctx]( error, middlewareCtx, updatedCtx) + case _ => () } def doAfterMiddlewareWithMap[Val, NewVal](fn: Val => NewVal)(v: Val): NewVal = @@ -1646,10 +1655,12 @@ class Resolver[Ctx]( .contains(ProjectionExclude) => val astField = fields.head val field = objTpe.getField(schema, astField.name).head - val projectionNames = field.tags.collect { case ProjectionName(name) => name } + val projectionNames = field.tags.iterator.collect { case ProjectionName(name) => + name + }.toVector val projectedName = - if (projectionNames.nonEmpty) projectionNames.toVector + if (projectionNames.nonEmpty) projectionNames else Vector(field.name) projectedName.map(name => diff --git a/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala b/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala index 38930e03..2618c7a3 100644 --- a/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala +++ b/modules/core/src/main/scala/sangria/schema/SchemaValidationRule.scala @@ -57,11 +57,11 @@ object SchemaValidationRule { } object DefaultValuesValidationRule extends SchemaValidationRule { - def validate[Ctx, Val](schema: Schema[Ctx, Val]) = { + def validate[Ctx, Val](schema: Schema[Ctx, Val]): List[Violation] = { val coercionHelper = ValueCoercionHelper.default def validate(prefix: => String, path: List[String], tpe: InputType[_])( - defaultValue: (_, ToInput[_, _])) = { + defaultValue: (_, ToInput[_, _])): Vector[Violation] = { val (default, toInput) = defaultValue.asInstanceOf[(Any, ToInput[Any, Any])] val (inputValue, iu) = toInput.toInput(default) @@ -76,11 +76,11 @@ object DefaultValuesValidationRule extends SchemaValidationRule { false, prefix)(iu) match { case Left(violations) => violations - case Right(violations) => Nil + case Right(violations) => Vector.empty } } - val inputTypeViolations = schema.inputTypes.values.toList.flatMap { + val inputTypeViolations = schema.inputTypes.values.flatMap { case it: InputObjectType[_] => it.fields.flatMap(f => f.defaultValue @@ -93,7 +93,7 @@ object DefaultValuesValidationRule extends SchemaValidationRule { case _ => Nil } - val outputTypeViolations = schema.outputTypes.values.toList.flatMap { + val outputTypeViolations = schema.outputTypes.values.flatMap { case ot: ObjectLikeType[_, _] => ot.fields.flatMap(f => f.arguments.flatMap(a => @@ -107,7 +107,7 @@ object DefaultValuesValidationRule extends SchemaValidationRule { case _ => Nil } - inputTypeViolations ++ outputTypeViolations + inputTypeViolations.toList ++ outputTypeViolations } } @@ -118,17 +118,17 @@ object InterfaceImplementationValidationRule extends SchemaValidationRule { intTpe: InterfaceType[_, _]): Vector[Violation] = { val objFields: Map[String, Vector[Field[_, _]]] = objTpe.ownFields.groupBy(_.name) - intTpe.ownFields.flatMap { intField => - objFields.get(intField.name) match { - case None => - // we allow object type to inherit fields from the interfaces - // without explicitly defining them, but only when it is not - // defined though SDL. - Vector.empty - - case Some(objField) - if !TypeComparators.isSubType(schema, objField.head.fieldType, intField.fieldType) => - Vector( + val violations: List[Violation] = intTpe.ownFields.foldLeft(List.empty[Violation]) { + case (acc, intField) => + objFields.get(intField.name) match { + case None => + // we allow object type to inherit fields from the interfaces + // without explicitly defining them, but only when it is not + // defined though SDL. + acc + + case Some(objField) + if !TypeComparators.isSubType(schema, objField.head.fieldType, intField.fieldType) => InvalidImplementationFieldTypeViolation( intTpe.name, objTpe.name, @@ -138,13 +138,12 @@ object InterfaceImplementationValidationRule extends SchemaValidationRule { SchemaElementValidator.sourceMapper(schema), SchemaElementValidator.location(objField.head) ++ SchemaElementValidator.location( intField) - )) + ) :: acc - case Some(objField) => - val intArgViolations = intField.arguments.flatMap { iarg => - objField.head.arguments.find(_.name == iarg.name) match { - case None => - Vector( + case Some(objField) => + val violationsWithIntArg = intField.arguments.foldLeft(acc) { case (acc, iarg) => + objField.head.arguments.find(_.name == iarg.name) match { + case None => MissingImplementationFieldArgumentViolation( intTpe.name, objTpe.name, @@ -153,11 +152,10 @@ object InterfaceImplementationValidationRule extends SchemaValidationRule { SchemaElementValidator.sourceMapper(schema), SchemaElementValidator.location(iarg) ++ SchemaElementValidator.location( objField.head) - )) + ) :: acc - case Some(oarg) - if !TypeComparators.isEqualType(iarg.argumentType, oarg.argumentType) => - Vector( + case Some(oarg) + if !TypeComparators.isEqualType(iarg.argumentType, oarg.argumentType) => InvalidImplementationFieldArgumentTypeViolation( intTpe.name, objTpe.name, @@ -167,17 +165,18 @@ object InterfaceImplementationValidationRule extends SchemaValidationRule { SchemaRenderer.renderTypeName(oarg.argumentType), SchemaElementValidator.sourceMapper(schema), SchemaElementValidator.location(iarg) ++ SchemaElementValidator.location(oarg) - )) + ) :: acc - case _ => Nil + case _ => acc + } } - } - val objArgViolations = objField.head.arguments - .filterNot(oa => intField.arguments.exists(_.name == oa.name)) - .flatMap { - case oarg if !oarg.argumentType.isInstanceOf[OptionInputType[_]] => - Vector( + objField.head.arguments.iterator + .filterNot(oa => intField.arguments.exists(_.name == oa.name)) + .foldLeft(violationsWithIntArg) { case (acc, oarg) => + if (oarg.argumentType.isInstanceOf[OptionInputType[_]]) + acc + else ImplementationExtraFieldArgumentNotOptionalViolation( intTpe.name, objTpe.name, @@ -187,16 +186,14 @@ object InterfaceImplementationValidationRule extends SchemaValidationRule { SchemaElementValidator.sourceMapper(schema), SchemaElementValidator.location(oarg) ++ SchemaElementValidator.location( intField) - )) - case _ => Nil - } - - intArgViolations ++ objArgViolations - } + ) :: acc + } + } } + violations.toVector } - def validate[Ctx, Val](schema: Schema[Ctx, Val]) = + def validate[Ctx, Val](schema: Schema[Ctx, Val]): List[Violation] = schema.possibleTypes.toList.flatMap { case (intName, objTypes) => schema.outputTypes(intName) match { case intTpe: InterfaceType[_, _] => objTypes.flatMap(validateObjectType(schema, _, intTpe)) @@ -206,7 +203,7 @@ object InterfaceImplementationValidationRule extends SchemaValidationRule { } object SubscriptionFieldsValidationRule extends SchemaValidationRule { - def validate[Ctx, Val](schema: Schema[Ctx, Val]) = { + def validate[Ctx, Val](schema: Schema[Ctx, Val]): List[Violation] = { val subsName = schema.subscription.map(_.name) def subscriptionTag(tag: FieldTag) = tag match { @@ -216,7 +213,7 @@ object SubscriptionFieldsValidationRule extends SchemaValidationRule { val otherViolations = schema.typeList.flatMap { case obj: ObjectLikeType[_, _] if subsName.isDefined && subsName.get != obj.name => - obj.uniqueFields + obj.uniqueFields.iterator .filter(_.tags.exists(subscriptionTag)) .map(f => InvalidSubscriptionFieldViolation(obj.name, f.name)) @@ -462,14 +459,14 @@ object ValidNamesValidator extends SchemaElementValidator { } object ContainerMembersValidator extends SchemaElementValidator { - override def validateUnionType(schema: Schema[_, _], tpe: UnionType[_]) = { + override def validateUnionType(schema: Schema[_, _], tpe: UnionType[_]): Vector[Violation] = { val emptyErrors = if (tpe.types.isEmpty) Vector(EmptyUnionMembersViolation(tpe.name, sourceMapper(schema), location(tpe))) else Vector.empty val nonUnique = - tpe.types.groupBy(_.name).toVector.collect { + tpe.types.groupBy(_.name).iterator.collect { case (memberName, dup) if dup.size > 1 => val astMembers = tpe.astNodes.collect { case astUnion: UnionTypeDefinition => astUnion.types @@ -483,14 +480,14 @@ object ContainerMembersValidator extends SchemaElementValidator { emptyErrors ++ nonUnique } - override def validateEnumType(schema: Schema[_, _], tpe: EnumType[_]) = { + override def validateEnumType(schema: Schema[_, _], tpe: EnumType[_]): Vector[Violation] = { val emptyErrors = if (tpe.values.isEmpty) Vector(EmptyEnumValuesMembersViolation(tpe.name, sourceMapper(schema), location(tpe))) else Vector.empty val nonUnique = - tpe.values.groupBy(_.name).toVector.collect { + tpe.values.groupBy(_.name).iterator.collect { case (valueName, dup) if dup.size > 1 => NonUniqueEnumValuesViolation( tpe.name, @@ -502,14 +499,16 @@ object ContainerMembersValidator extends SchemaElementValidator { emptyErrors ++ nonUnique } - override def validateInputObjectType(schema: Schema[_, _], tpe: InputObjectType[_]) = { + override def validateInputObjectType( + schema: Schema[_, _], + tpe: InputObjectType[_]): Vector[Violation] = { val emptyErrors = if (tpe.fields.isEmpty) Vector(EmptyInputFieldsViolation(tpe.name, sourceMapper(schema), location(tpe))) else Vector.empty val nonUnique = - tpe.fields.groupBy(_.name).toVector.collect { + tpe.fields.groupBy(_.name).iterator.collect { case (fieldName, dup) if dup.size > 1 => NonUniqueInputFieldsViolation( tpe.name, @@ -521,11 +520,13 @@ object ContainerMembersValidator extends SchemaElementValidator { emptyErrors ++ nonUnique } - override def validateObjectType(schema: Schema[_, _], tpe: ObjectType[_, _]) = { + override def validateObjectType( + schema: Schema[_, _], + tpe: ObjectType[_, _]): Vector[Violation] = { val generalErrors = validateObjectLikeType(schema, tpe, "Object") val nonUnique = - tpe.interfaces.groupBy(_.name).toVector.collect { + tpe.interfaces.groupBy(_.name).iterator.collect { case (intName, dup) if dup.size > 1 => val astMembers = tpe.astNodes.collect { case astUnion: ObjectTypeDefinition => astUnion.interfaces @@ -539,7 +540,9 @@ object ContainerMembersValidator extends SchemaElementValidator { generalErrors ++ nonUnique } - override def validateInterfaceType(schema: Schema[_, _], tpe: InterfaceType[_, _]) = + override def validateInterfaceType( + schema: Schema[_, _], + tpe: InterfaceType[_, _]): Vector[Violation] = validateObjectLikeType(schema, tpe, "Interface") def validateObjectLikeType( @@ -552,7 +555,7 @@ object ContainerMembersValidator extends SchemaElementValidator { else Vector.empty val nonUnique = - tpe.ownFields.groupBy(_.name).toVector.collect { + tpe.ownFields.groupBy(_.name).iterator.collect { case (fieldName, dup) if dup.size > 1 => NonUniqueFieldsViolation( kind, @@ -565,32 +568,46 @@ object ContainerMembersValidator extends SchemaElementValidator { emptyErrors ++ nonUnique } - override def validateField(schema: Schema[_, _], tpe: ObjectLikeType[_, _], field: Field[_, _]) = - field.arguments.groupBy(_.name).toVector.collect { - case (argName, dup) if dup.size > 1 => - NonUniqueFieldArgumentsViolation( - tpe.name, - field.name, - argName, - sourceMapper(schema), - dup.flatMap(location)) - } - - override def validateDirective(schema: Schema[_, _], tpe: Directive) = - tpe.arguments.groupBy(_.name).toVector.collect { - case (argName, dup) if dup.size > 1 => - NonUniqueDirectiveArgumentsViolation( - tpe.name, - argName, - sourceMapper(schema), - dup.flatMap(location)) - } + override def validateField( + schema: Schema[_, _], + tpe: ObjectLikeType[_, _], + field: Field[_, _]): Vector[Violation] = + field.arguments + .groupBy(_.name) + .iterator + .collect { + case (argName, dup) if dup.size > 1 => + NonUniqueFieldArgumentsViolation( + tpe.name, + field.name, + argName, + sourceMapper(schema), + dup.flatMap(location)) + } + .toVector + + override def validateDirective(schema: Schema[_, _], tpe: Directive): Vector[Violation] = + tpe.arguments + .groupBy(_.name) + .iterator + .collect { + case (argName, dup) if dup.size > 1 => + NonUniqueDirectiveArgumentsViolation( + tpe.name, + argName, + sourceMapper(schema), + dup.flatMap(location)) + } + .toVector } object EnumValueReservedNameValidator extends SchemaElementValidator { private val reservedNames = Set("true", "false", "null") - override def validateEnumValue(schema: Schema[_, _], tpe: EnumType[_], value: EnumValue[_]) = + override def validateEnumValue( + schema: Schema[_, _], + tpe: EnumType[_], + value: EnumValue[_]): Vector[Violation] = if (reservedNames.contains(value.name)) Vector( ReservedEnumValueNameViolation(tpe.name, value.name, sourceMapper(schema), location(value))) @@ -611,8 +628,8 @@ object InputObjectTypeRecursionValidator extends SchemaElementValidator { val recursiveFields = tpe.fields.filter(childField => childField.fieldType.namedType.name == rootTypeName && !childField.fieldType.isOptional && !childField.fieldType.isList) if (recursiveFields.nonEmpty) { - recursiveFields - .flatMap(field => Vector(InputObjectTypeRecursion(tpe.name, field.name, path, None, Nil))) + recursiveFields.iterator + .map(field => InputObjectTypeRecursion(tpe.name, field.name, path, None, Nil)) .toVector } else { val childTypesToCheck = tpe.fields.filter(field => @@ -685,7 +702,7 @@ class FullSchemaTraversalValidationRule(validators: SchemaElementValidator*) extends SchemaValidationRule { private val reservedNames = Set("true", "false", "null") - def validate[Ctx, Val](schema: Schema[Ctx, Val]) = { + def validate[Ctx, Val](schema: Schema[Ctx, Val]): List[Violation] = { val violations = new VectorBuilder[Violation] def add(vs: Vector[Violation]): Unit = diff --git a/modules/core/src/main/scala/sangria/validation/rules/ProvidedRequiredArguments.scala b/modules/core/src/main/scala/sangria/validation/rules/ProvidedRequiredArguments.scala index 0404cf92..593ed133 100644 --- a/modules/core/src/main/scala/sangria/validation/rules/ProvidedRequiredArguments.scala +++ b/modules/core/src/main/scala/sangria/validation/rules/ProvidedRequiredArguments.scala @@ -17,9 +17,9 @@ class ProvidedRequiredArguments extends ValidationRule { ctx.typeInfo.fieldDef match { case None => AstVisitorCommand.RightContinue case Some(fieldDef) => - val astArgs = args.map(_.name).toSet + val astArgs = args.iterator.map(_.name).toSet - val errors = fieldDef.arguments.toVector.collect { + val errors = fieldDef.arguments.iterator.collect { case argDef if !astArgs.contains( argDef.name) && !argDef.argumentType.isOptional && argDef.defaultValue.isEmpty => @@ -29,7 +29,7 @@ class ProvidedRequiredArguments extends ValidationRule { SchemaRenderer.renderTypeName(argDef.argumentType), ctx.sourceMapper, pos.toList) - } + }.toVector if (errors.nonEmpty) Left(errors) else AstVisitorCommand.RightContinue } @@ -38,9 +38,9 @@ class ProvidedRequiredArguments extends ValidationRule { ctx.typeInfo.directive match { case None => AstVisitorCommand.RightContinue case Some(dirDef) => - val astArgs = args.map(_.name).toSet + val astArgs = args.iterator.map(_.name).toSet - val errors = dirDef.arguments.toVector.collect { + val errors = dirDef.arguments.iterator.collect { case argDef if !astArgs.contains( argDef.name) && !argDef.argumentType.isOptional && argDef.defaultValue.isEmpty => @@ -50,7 +50,7 @@ class ProvidedRequiredArguments extends ValidationRule { SchemaRenderer.renderTypeName(argDef.argumentType), ctx.sourceMapper, pos.toList) - } + }.toVector if (errors.nonEmpty) Left(errors) else AstVisitorCommand.RightContinue } diff --git a/modules/core/src/test/scala/sangria/schema/TypeFieldConstraintsSpec.scala b/modules/core/src/test/scala/sangria/schema/TypeFieldConstraintsSpec.scala index 7825e356..c6b6643a 100644 --- a/modules/core/src/test/scala/sangria/schema/TypeFieldConstraintsSpec.scala +++ b/modules/core/src/test/scala/sangria/schema/TypeFieldConstraintsSpec.scala @@ -310,7 +310,7 @@ class TypeFieldConstraintsSpec extends AnyWordSpec with Matchers { val error = intercept[SchemaValidationException](Schema(QueryType, additionalTypes = AppleType :: Nil)) - error.violations.head.errorMessage should include( + error.violations.last.errorMessage should include( "Fruit.slice expects argument 'parts', but Apple.slice does not provide it.") }