From 54f5c4c2298eefb3210e58b0e06c95fcdef2243d Mon Sep 17 00:00:00 2001 From: Miles Sabin Date: Thu, 7 Dec 2023 12:42:43 +0000 Subject: [PATCH 1/3] Add limits to parser; improve variable and fragment validation --- build.sbt | 2 +- modules/core/src/main/scala-2/syntax2.scala | 18 +- modules/core/src/main/scala-3/syntax3.scala | 17 +- modules/core/src/main/scala/compiler.scala | 443 +++++++---- modules/core/src/main/scala/mapping.scala | 6 +- modules/core/src/main/scala/minimizer.scala | 239 +++--- modules/core/src/main/scala/parser.scala | 707 +++++++++--------- modules/core/src/main/scala/schema.scala | 485 ++++++------ .../test/scala/compiler/CompilerSuite.scala | 77 +- .../test/scala/compiler/DirectivesSuite.scala | 4 +- .../test/scala/compiler/FragmentSuite.scala | 252 +++++++ .../test/scala/compiler/VariablesSuite.scala | 133 ++++ .../directives/DirectiveValidationSuite.scala | 2 +- .../test/scala/minimizer/MinimizerSuite.scala | 12 +- .../src/test/scala/parser/ParserSuite.scala | 219 +++++- .../core/src/test/scala/sdl/SDLSuite.scala | 43 +- .../generic/src/test/scala/ScalarsSuite.scala | 32 + 17 files changed, 1757 insertions(+), 934 deletions(-) diff --git a/build.sbt b/build.sbt index bac97c23..97b8214f 100644 --- a/build.sbt +++ b/build.sbt @@ -30,7 +30,7 @@ ThisBuild / scalaVersion := Scala2 ThisBuild / crossScalaVersions := Seq(Scala2, Scala3) ThisBuild / tlJdkRelease := Some(11) -ThisBuild / tlBaseVersion := "0.17" +ThisBuild / tlBaseVersion := "0.18" ThisBuild / startYear := Some(2019) ThisBuild / licenses := Seq(License.Apache2) ThisBuild / developers := List( diff --git a/modules/core/src/main/scala-2/syntax2.scala b/modules/core/src/main/scala-2/syntax2.scala index 5729ef05..4157f3c1 100644 --- a/modules/core/src/main/scala-2/syntax2.scala +++ b/modules/core/src/main/scala-2/syntax2.scala @@ -19,7 +19,6 @@ import cats.data.NonEmptyChain import cats.syntax.all._ import org.typelevel.literally.Literally import grackle.Ast.Document -import grackle.GraphQLParser.Document.parseAll import grackle.Schema trait VersionSpecificSyntax { @@ -32,7 +31,7 @@ class StringContextOps(val sc: StringContext) extends AnyVal { def doc(args: Any*): Document = macro DocumentLiteral.make } -object SchemaLiteral extends Literally[Schema] { +private object SchemaLiteral extends Literally[Schema] { def validate(c: Context)(s: String): Either[String,c.Expr[Schema]] = { import c.universe._ def mkError(err: Either[Throwable, NonEmptyChain[Problem]]) = @@ -40,18 +39,23 @@ object SchemaLiteral extends Literally[Schema] { t => s"Internal error: ${t.getMessage}", ps => s"Invalid schema: ${ps.toList.distinct.mkString("\n šŸž ", "\n šŸž ", "\n")}", ) - Schema(s).toEither.bimap(mkError, _ => c.Expr(q"_root_.grackle.Schema($s).toOption.get")) + Schema(s, CompiletimeParsers.schemaParser).toEither.bimap(mkError, _ => c.Expr(q"_root_.grackle.Schema($s, _root_.grackle.CompiletimeParsers.schemaParser).toOption.get")) } def make(c: Context)(args: c.Expr[Any]*): c.Expr[Schema] = apply(c)(args: _*) } -object DocumentLiteral extends Literally[Document] { +private object DocumentLiteral extends Literally[Document] { def validate(c: Context)(s: String): Either[String,c.Expr[Document]] = { import c.universe._ - parseAll(s).bimap( - pf => show"Invalid document: $pf", - _ => c.Expr(q"_root_.grackle.GraphQLParser.Document.parseAll($s).toOption.get"), + CompiletimeParsers.parser.parseText(s).toEither.bimap( + _.fold(thr => show"Invalid document: ${thr.getMessage}", _.toList.mkString("\n šŸž ", "\n šŸž ", "\n")), + _ => c.Expr(q"_root_.grackle.CompiletimeParsers.parser.parseText($s).toOption.get"), ) } def make(c: Context)(args: c.Expr[Any]*): c.Expr[Document] = apply(c)(args: _*) } + +object CompiletimeParsers { + val parser: GraphQLParser = GraphQLParser(GraphQLParser.defaultConfig) + val schemaParser: SchemaParser = SchemaParser(parser) +} diff --git a/modules/core/src/main/scala-3/syntax3.scala b/modules/core/src/main/scala-3/syntax3.scala index 65d2cadc..5ce73903 100644 --- a/modules/core/src/main/scala-3/syntax3.scala +++ b/modules/core/src/main/scala-3/syntax3.scala @@ -18,24 +18,27 @@ package grackle import cats.syntax.all._ import org.typelevel.literally.Literally import grackle.Ast.Document -import grackle.GraphQLParser.Document.parseAll trait VersionSpecificSyntax: extension (inline ctx: StringContext) inline def schema(inline args: Any*): Schema = ${SchemaLiteral('ctx, 'args)} - inline def doc(inline args: Any*): Document = ${ DocumentLiteral('ctx, 'args) } + inline def doc(inline args: Any*): Document = ${DocumentLiteral('ctx, 'args) } object SchemaLiteral extends Literally[Schema]: def validate(s: String)(using Quotes) = - Schema(s).toEither.bimap( + Schema(s, CompiletimeParsers.schemaParser).toEither.bimap( nec => s"Invalid schema:${nec.toList.distinct.mkString("\n šŸž ", "\n šŸž ", "\n")}", - _ => '{Schema(${Expr(s)}).toOption.get} + _ => '{Schema(${Expr(s)}, CompiletimeParsers.schemaParser).toOption.get} ) object DocumentLiteral extends Literally[Document]: def validate(s: String)(using Quotes) = - parseAll(s).bimap( - pf => show"Invalid document: $pf", - _ => '{parseAll(${Expr(s)}).toOption.get} + CompiletimeParsers.parser.parseText(s).toEither.bimap( + _.fold(thr => show"Invalid document: ${thr.getMessage}", _.toList.mkString("\n šŸž ", "\n šŸž ", "\n")), + _ => '{CompiletimeParsers.parser.parseText(${Expr(s)}).toOption.get} ) + +object CompiletimeParsers: + val parser: GraphQLParser = GraphQLParser(GraphQLParser.defaultConfig) + val schemaParser: SchemaParser = SchemaParser(parser) diff --git a/modules/core/src/main/scala/compiler.scala b/modules/core/src/main/scala/compiler.scala index e95ee630..3bb5c3a8 100644 --- a/modules/core/src/main/scala/compiler.scala +++ b/modules/core/src/main/scala/compiler.scala @@ -31,151 +31,172 @@ import ScalarType._ /** * GraphQL query parser */ -object QueryParser { - import Ast.{ Directive => _, Type => _, Value => _, _ }, OperationDefinition._, Selection._ - +trait QueryParser { /** * Parse a String to query algebra operations and fragments. * * GraphQL errors and warnings are accumulated in the result. */ - def parseText(text: String): Result[(List[UntypedOperation], List[UntypedFragment])] = - for { - doc <- GraphQLParser.toResult(text, GraphQLParser.Document.parseAll(text)) - res <- parseDocument(doc) - _ <- Result.failure("At least one operation required").whenA(res._1.isEmpty) - } yield res + def parseText(text: String): Result[(List[UntypedOperation], List[UntypedFragment])] /** * Parse a document AST to query algebra operations and fragments. * * GraphQL errors and warnings are accumulated in the result. */ - def parseDocument(doc: Document): Result[(List[UntypedOperation], List[UntypedFragment])] = { - val ops0 = doc.collect { case op: OperationDefinition => op } - val fragments0 = doc.collect { case frag: FragmentDefinition => frag } + def parseDocument(doc: Ast.Document): Result[(List[UntypedOperation], List[UntypedFragment])] +} - for { - ops <- ops0.traverse { - case op: Operation => parseOperation(op) - case qs: QueryShorthand => parseQueryShorthand(qs) - } - frags <- fragments0.traverse { frag => - val tpnme = frag.typeCondition.name - for { - sels <- parseSelections(frag.selectionSet) - dirs <- parseDirectives(frag.directives) - } yield UntypedFragment(frag.name.value, tpnme, dirs, sels) - } - } yield (ops, frags) - } +object QueryParser { + def apply(parser: GraphQLParser): QueryParser = + new Impl(parser) - /** - * Parse an operation AST to a query algebra operation. - * - * GraphQL errors and warnings are accumulated in the result. - */ - def parseOperation(op: Operation): Result[UntypedOperation] = { - val Operation(opType, name, vds, dirs0, sels) = op - for { - vs <- parseVariableDefinitions(vds) - q <- parseSelections(sels) - dirs <- parseDirectives(dirs0) - } yield { - val name0 = name.map(_.value) - opType match { - case OperationType.Query => UntypedQuery(name0, q, vs, dirs) - case OperationType.Mutation => UntypedMutation(name0, q, vs, dirs) - case OperationType.Subscription => UntypedSubscription(name0, q, vs, dirs) - } - } - } + private final class Impl(parser: GraphQLParser) extends QueryParser { + import Ast.{ Directive => _, Type => _, Value => _, _ }, OperationDefinition._, Selection._ - /** - * Parse variable definition ASTs to query algebra variable definitions. + /** + * Parse a String to query algebra operations and fragments. * - * GraphQL errors and warnings are accumulated in the result. + * GraphQL errors and warnings are accumulated in the result. */ - def parseVariableDefinitions(vds: List[VariableDefinition]): Result[List[UntypedVarDef]] = - vds.traverse { - case VariableDefinition(Name(nme), tpe, dv0, dirs0) => - for { - dv <- dv0.traverse(SchemaParser.parseValue) - dirs <- parseDirectives(dirs0) - } yield UntypedVarDef(nme, tpe, dv, dirs) - } + def parseText(text: String): Result[(List[UntypedOperation], List[UntypedFragment])] = + for { + doc <- parser.parseText(text) + res <- parseDocument(doc) + _ <- Result.failure("At least one operation required").whenA(res._1.isEmpty) + } yield res - /** - * Parse a query shorthand AST to query algebra operation. + /** + * Parse a document AST to query algebra operations and fragments. * - * GraphQL errors and warnings are accumulated in the result. + * GraphQL errors and warnings are accumulated in the result. */ - def parseQueryShorthand(qs: QueryShorthand): Result[UntypedOperation] = - parseSelections(qs.selectionSet).map(q => UntypedQuery(None, q, Nil, Nil)) + def parseDocument(doc: Document): Result[(List[UntypedOperation], List[UntypedFragment])] = { + val ops0 = doc.collect { case op: OperationDefinition => op } + val fragments0 = doc.collect { case frag: FragmentDefinition => frag } - /** - * Parse selection ASTs to query algebra terms. - * - * GraphQL errors and warnings are accumulated in the result - */ - def parseSelections(sels: List[Selection]): Result[Query] = - sels.traverse(parseSelection).map { sels0 => - if (sels0.sizeCompare(1) == 0) sels0.head else Group(sels0) + for { + ops <- ops0.traverse { + case op: Operation => parseOperation(op) + case qs: QueryShorthand => parseQueryShorthand(qs) + } + frags <- fragments0.traverse { frag => + val tpnme = frag.typeCondition.name + for { + sels <- parseSelections(frag.selectionSet) + dirs <- parseDirectives(frag.directives) + } yield UntypedFragment(frag.name.value, tpnme, dirs, sels) + } + } yield (ops, frags) } - /** - * Parse a selection AST to a query algebra term. + /** + * Parse an operation AST to a query algebra operation. * - * GraphQL errors and warnings are accumulated in the result. + * GraphQL errors and warnings are accumulated in the result. */ - def parseSelection(sel: Selection): Result[Query] = sel match { - case Field(alias, name, args, directives, sels) => + def parseOperation(op: Operation): Result[UntypedOperation] = { + val Operation(opType, name, vds, dirs0, sels) = op for { - args0 <- parseArgs(args) - sels0 <- parseSelections(sels) - dirs <- parseDirectives(directives) + vs <- parseVariableDefinitions(vds) + q <- parseSelections(sels) + dirs <- parseDirectives(dirs0) } yield { - val nme = name.value - val alias0 = alias.map(_.value).flatMap(n => if (n == nme) None else Some(n)) - if (sels.isEmpty) UntypedSelect(nme, alias0, args0, dirs, Empty) - else UntypedSelect(nme, alias0, args0, dirs, sels0) + val name0 = name.map(_.value) + opType match { + case OperationType.Query => UntypedQuery(name0, q, vs, dirs) + case OperationType.Mutation => UntypedMutation(name0, q, vs, dirs) + case OperationType.Subscription => UntypedSubscription(name0, q, vs, dirs) + } } + } - case FragmentSpread(Name(name), directives) => - for { - dirs <- parseDirectives(directives) - } yield UntypedFragmentSpread(name, dirs) + /** + * Parse variable definition ASTs to query algebra variable definitions. + * + * GraphQL errors and warnings are accumulated in the result. + */ + def parseVariableDefinitions(vds: List[VariableDefinition]): Result[List[UntypedVarDef]] = + vds.traverse { + case VariableDefinition(Name(nme), tpe, dv0, dirs0) => + for { + dv <- dv0.traverse(Value.fromAst) + dirs <- parseDirectives(dirs0) + } yield UntypedVarDef(nme, tpe, dv, dirs) + } - case InlineFragment(typeCondition, directives, sels) => - for { - dirs <- parseDirectives(directives) - sels0 <- parseSelections(sels) - } yield UntypedInlineFragment(typeCondition.map(_.name), dirs, sels0) - } + /** + * Parse a query shorthand AST to query algebra operation. + * + * GraphQL errors and warnings are accumulated in the result. + */ + def parseQueryShorthand(qs: QueryShorthand): Result[UntypedOperation] = + parseSelections(qs.selectionSet).map(q => UntypedQuery(None, q, Nil, Nil)) - /** - * Parse directive ASTs to query algebra directives. - * - * GraphQL errors and warnings are accumulated in the result. - */ - def parseDirectives(directives: List[Ast.Directive]): Result[List[Directive]] = - directives.traverse(SchemaParser.mkDirective) + /** + * Parse selection ASTs to query algebra terms. + * + * GraphQL errors and warnings are accumulated in the result + */ + def parseSelections(sels: List[Selection]): Result[Query] = + sels.traverse(parseSelection).map { sels0 => + if (sels0.sizeCompare(1) == 0) sels0.head else Group(sels0) + } - /** - * Parse argument ASTs to query algebra bindings. - * - * GraphQL errors and warnings are accumulated in the result. - */ - def parseArgs(args: List[(Name, Ast.Value)]): Result[List[Binding]] = - args.traverse((parseArg _).tupled) + /** + * Parse a selection AST to a query algebra term. + * + * GraphQL errors and warnings are accumulated in the result. + */ + def parseSelection(sel: Selection): Result[Query] = sel match { + case Field(alias, name, args, directives, sels) => + for { + args0 <- parseArgs(args) + sels0 <- parseSelections(sels) + dirs <- parseDirectives(directives) + } yield { + val nme = name.value + val alias0 = alias.map(_.value).flatMap(n => if (n == nme) None else Some(n)) + if (sels.isEmpty) UntypedSelect(nme, alias0, args0, dirs, Empty) + else UntypedSelect(nme, alias0, args0, dirs, sels0) + } - /** - * Parse an argument AST to a query algebra binding. - * - * GraphQL errors and warnings are accumulated in the result. - */ - def parseArg(name: Name, value: Ast.Value): Result[Binding] = - SchemaParser.parseValue(value).map(v => Binding(name.value, v)) + case FragmentSpread(Name(name), directives) => + for { + dirs <- parseDirectives(directives) + } yield UntypedFragmentSpread(name, dirs) + + case InlineFragment(typeCondition, directives, sels) => + for { + dirs <- parseDirectives(directives) + sels0 <- parseSelections(sels) + } yield UntypedInlineFragment(typeCondition.map(_.name), dirs, sels0) + } + + /** + * Parse directive ASTs to query algebra directives. + * + * GraphQL errors and warnings are accumulated in the result. + */ + def parseDirectives(directives: List[Ast.Directive]): Result[List[Directive]] = + directives.traverse(Directive.fromAst) + + /** + * Parse argument ASTs to query algebra bindings. + * + * GraphQL errors and warnings are accumulated in the result. + */ + def parseArgs(args: List[(Name, Ast.Value)]): Result[List[Binding]] = + args.traverse((parseArg _).tupled) + + /** + * Parse an argument AST to a query algebra binding. + * + * GraphQL errors and warnings are accumulated in the result. + */ + def parseArg(name: Name, value: Ast.Value): Result[Binding] = + Value.fromAst(value).map(v => Binding(name.value, v)) + } } /** @@ -185,7 +206,7 @@ object QueryParser { * applies a collection of transformation phases in sequence, yielding a * query algebra term which can be directly interpreted. */ -class QueryCompiler(schema: Schema, phases: List[Phase]) { +class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) { import IntrospectionLevel._ /** @@ -195,26 +216,30 @@ class QueryCompiler(schema: Schema, phases: List[Phase]) { * GraphQL errors and warnings are accumulated in the result. */ def compile(text: String, name: Option[String] = None, untypedVars: Option[Json] = None, introspectionLevel: IntrospectionLevel = Full, env: Env = Env.empty): Result[Operation] = - QueryParser.parseText(text).flatMap { case (ops, frags) => - (ops, name) match { - case (Nil, _) => - Result.failure("At least one operation required") - case (List(op), None) => - compileOperation(op, untypedVars, frags, introspectionLevel, env) - case (_, None) => - Result.failure("Operation name required to select unique operation") - case (ops, _) if ops.exists(_.name.isEmpty) => - Result.failure("Query shorthand cannot be combined with multiple operations") - case (ops, name) => - ops.filter(_.name == name) match { - case List(op) => - compileOperation(op, untypedVars, frags, introspectionLevel, env) - case Nil => - Result.failure(s"No operation named '$name'") - case _ => - Result.failure(s"Multiple operations named '$name'") - } - } + parser.parseText(text).flatMap { case (ops, frags) => + for { + _ <- Result.fromProblems(validateVariablesAndFragments(ops, frags)) + ops0 <- ops.traverse(op => compileOperation(op, untypedVars, frags, introspectionLevel, env).map(op0 => (op.name, op0))) + res <- (ops0, name) match { + case (List((_, op)), None) => + op.success + case (Nil, _) => + Result.failure("At least one operation required") + case (_, None) => + Result.failure("Operation name required to select unique operation") + case (ops, _) if ops.lengthCompare(1) > 0 && ops.exists(_._1.isEmpty) => + Result.failure("Query shorthand cannot be combined with multiple operations") + case (ops, name) => + ops.filter(_._1 == name) match { + case List((_, op)) => + op.success + case Nil => + Result.failure(s"No operation named '$name'") + case _ => + Result.failure(s"Multiple operations named '$name'") + } + } + } yield res } /** @@ -297,6 +322,154 @@ class QueryCompiler(schema: Schema, phases: List[Phase]) { } loop(tpe, false) } + + def validateVariablesAndFragments(ops: List[UntypedOperation], frags: List[UntypedFragment]): List[Problem] = { + val (uniqueFrags, duplicateFrags) = frags.map(_.name).foldLeft((Set.empty[String], Set.empty[String])) { + case ((unique, duplicate), nme) => + if (unique.contains(nme)) (unique, duplicate + nme) + else (unique + nme, duplicate) + } + + if (duplicateFrags.nonEmpty) + duplicateFrags.toList.map(nme => Problem(s"Fragment '$nme' is defined more than once")) + else { + def collectQueryRefs(query: Query): (Set[String], Set[String]) = { + @tailrec + def loop(queries: Iterator[Query], vars: Set[String], frags: Set[String]): (Set[String], Set[String]) = + if (!queries.hasNext) (vars, frags) + else + queries.next() match { + case UntypedSelect(_, _, args, dirs, child) => + val v0 = args.iterator.flatMap(arg => collectValueRefs(arg.value)).toSet + val v1 = dirs.iterator.flatMap(dir => dir.args.iterator.flatMap(arg => collectValueRefs(arg.value))).toSet + loop(Iterator.single(child) ++ queries, vars ++ v0 ++ v1, frags) + case UntypedFragmentSpread(nme, dirs) => + val v0 = dirs.iterator.flatMap(dir => dir.args.iterator.flatMap(arg => collectValueRefs(arg.value))).toSet + loop(queries, vars ++ v0, frags + nme) + case UntypedInlineFragment(_, dirs, child) => + val v0 = dirs.iterator.flatMap(dir => dir.args.iterator.flatMap(arg => collectValueRefs(arg.value))).toSet + loop(Iterator.single(child) ++ queries, vars ++ v0, frags) + case Group(children) => + loop(children.iterator ++ queries, vars, frags) + case Select(_, _, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Narrow(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Unique(child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Filter(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Limit(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Offset(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case OrderBy(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Introspect(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Environment(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Component(_, _, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Effect(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case TransformCursor(_, child) => loop(Iterator.single(child) ++ queries, vars, frags) + case Count(_) => loop(queries, vars, frags) + case Empty => loop(queries, vars, frags) + } + + loop(Iterator.single(query), Set.empty[String], Set.empty[String]) + } + + def collectValueRefs(value: Value): Set[String] = { + @tailrec + def loop(values: Iterator[Value], vars: Set[String]): Set[String] = + if (!values.hasNext) vars + else + values.next() match { + case VariableRef(nme) => + loop(values, Set(nme)) + case ObjectValue(fields) => + loop(fields.iterator.map(_._2) ++ values, vars) + case ListValue(elems) => + loop(elems.iterator ++ values, vars) + case _ => loop(values, vars) + } + + loop(Iterator.single(value), Set.empty[String]) + } + + val fragRefs: Map[String, (Set[String], Set[String])] = + frags.map { frag => + (frag.name, collectQueryRefs(frag.child)) + }.toMap + + @tailrec + def checkCycle(pendingFrags: Set[String], seen: Set[String]): Option[Set[String]] = { + if (pendingFrags.isEmpty) Some(seen) + else { + val hd = pendingFrags.head + if (seen.contains(hd)) None + else checkCycle(fragRefs(hd)._2 ++ pendingFrags.tail, seen + hd) + } + } + + def findCycle: Option[String] = { + @tailrec + def loop(pendingFrags: Set[String]): Either[Set[String], String] = { + if(pendingFrags.isEmpty) Left(Set.empty[String]) + else { + val hd = pendingFrags.head + checkCycle(Set(hd), Set.empty[String]) match { + case None => Right(hd) + case Some(seen) => loop(pendingFrags.tail.diff(seen)) + } + } + } + + if (uniqueFrags.isEmpty) None + else loop(uniqueFrags).toOption + } + + findCycle match { + case Some(from) => List(Problem(s"Fragment cycle starting from '$from'")) + case _ => + def validateOp(op: UntypedOperation, pendingFrags: Set[String]): (List[Problem], Set[String]) = { + val pendingVars = op.variables.map(_.name).toSet + val (dqv, dqf) = collectQueryRefs(op.query) + + val (qv, qf) = { + dqf.foldLeft((dqv, dqf)) { + case ((v, f), nme) => fragRefs.get(nme) match { + case None => (v, f) + case Some((fv, ff)) => (v ++ fv, f ++ ff) + } + } + } + + val varProblems = + if (qv == pendingVars) Nil + else { + val undefined = qv.diff(pendingVars) + val unused = pendingVars.diff(qv) + val undefinedProblems = undefined.toList.map(nme => Problem(s"Variable '$nme' is undefined")) + val unusedProblems = unused.toList.map(nme => Problem(s"Variable '$nme' is unused")) + undefinedProblems ++ unusedProblems + } + + val fragProblems = + if (qf.subsetOf(uniqueFrags)) Nil + else { + val undefined = qf.diff(uniqueFrags) + val undefinedProblems = undefined.toList.map(nme => Problem(s"Fragment '$nme' is undefined")) + undefinedProblems + } + + (varProblems ++ fragProblems, pendingFrags.diff(qf)) + } + + val (opProblems, unreferencedFrags) = + ops.foldLeft((List.empty[Problem], uniqueFrags)) { + case ((acc, pendingFrags), op) => + val (problems, pendingFrags0) = validateOp(op, pendingFrags) + (acc ++ problems, pendingFrags0) + } + + val unreferencedFragProblems = unreferencedFrags.toList.map(nme => Problem(s"Fragment '$nme' is unused")) + + opProblems ++ unreferencedFragProblems + } + } + } } object QueryCompiler { @@ -710,7 +883,7 @@ object QueryCompiler { case VariableRef(varName) => for { v <- Elab.vars - tv <- Elab.liftR(Result.fromOption(v.get(varName), s"Undefined variable '$varName'")) + tv <- Elab.liftR(Result.fromOption(v.get(varName), s"Variable '$varName' is undefined")) b <- tv match { case (tpe, BooleanValue(value)) if tpe.nonNull =:= BooleanType => Elab.pure(value) case _ => Elab.failure(s"Argument of skip/include must be boolean") diff --git a/modules/core/src/main/scala/mapping.scala b/modules/core/src/main/scala/mapping.scala index 65e21cf7..b9d35239 100644 --- a/modules/core/src/main/scala/mapping.scala +++ b/modules/core/src/main/scala/mapping.scala @@ -480,7 +480,11 @@ abstract class Mapping[F[_]] { def compilerPhases: List[QueryCompiler.Phase] = List(selectElaborator, componentElaborator, effectElaborator) - lazy val compiler = new QueryCompiler(schema, compilerPhases) + def parserConfig: GraphQLParser.Config = GraphQLParser.defaultConfig + lazy val graphQLParser: GraphQLParser = GraphQLParser(parserConfig) + lazy val queryParser: QueryParser = QueryParser(graphQLParser) + + lazy val compiler: QueryCompiler = new QueryCompiler(queryParser, schema, compilerPhases) val interpreter: QueryInterpreter[F] = new QueryInterpreter(this) diff --git a/modules/core/src/main/scala/minimizer.scala b/modules/core/src/main/scala/minimizer.scala index e295966e..bb44b9d9 100644 --- a/modules/core/src/main/scala/minimizer.scala +++ b/modules/core/src/main/scala/minimizer.scala @@ -15,126 +15,133 @@ package grackle -import cats.implicits._ +trait QueryMinimizer { + def minimizeText(text: String): Result[String] + def minimizeDocument(doc: Ast.Document): String +} object QueryMinimizer { - import Ast._ - - def minimizeText(text: String): Either[String, String] = { - for { - doc <- GraphQLParser.Document.parseAll(text).leftMap(_.expected.toList.mkString(",")) - } yield minimizeDocument(doc) - } - - def minimizeDocument(doc: Document): String = { - import OperationDefinition._ - import OperationType._ - import Selection._ - import Value._ - - def renderDefinition(defn: Definition): String = - defn match { - case e: ExecutableDefinition => renderExecutableDefinition(e) - case _ => "" - } - - def renderExecutableDefinition(ex: ExecutableDefinition): String = - ex match { - case op: OperationDefinition => renderOperationDefinition(op) - case frag: FragmentDefinition => renderFragmentDefinition(frag) - } - - def renderOperationDefinition(op: OperationDefinition): String = - op match { - case qs: QueryShorthand => renderSelectionSet(qs.selectionSet) - case op: Operation => renderOperation(op) - } - - def renderOperation(op: Operation): String = - renderOperationType(op.operationType) + - op.name.map(nme => s" ${nme.value}").getOrElse("") + - renderVariableDefns(op.variables)+ - renderDirectives(op.directives)+ - renderSelectionSet(op.selectionSet) - - def renderOperationType(op: OperationType): String = - op match { - case Query => "query" - case Mutation => "mutation" - case Subscription => "subscription" - } - - def renderDirectives(dirs: List[Directive]): String = - dirs.map { case Directive(name, args) => s"@${name.value}${renderArguments(args)}" }.mkString("") - - def renderVariableDefns(vars: List[VariableDefinition]): String = - vars match { - case Nil => "" - case _ => - vars.map { - case VariableDefinition(name, tpe, default, dirs) => - s"$$${name.value}:${tpe.name}${default.map(v => s"=${renderValue(v)}").getOrElse("")}${renderDirectives(dirs)}" - }.mkString("(", ",", ")") - } - - def renderSelectionSet(sels: List[Selection]): String = - sels match { - case Nil => "" - case _ => sels.map(renderSelection).mkString("{", ",", "}") + def apply(parser: GraphQLParser): QueryMinimizer = + new Impl(parser) + + private final class Impl(parser: GraphQLParser) extends QueryMinimizer { + import Ast._ + + def minimizeText(text: String): Result[String] = + for { + doc <- parser.parseText(text) + } yield minimizeDocument(doc) + + def minimizeDocument(doc: Document): String = { + import OperationDefinition._ + import OperationType._ + import Selection._ + import Value._ + + def renderDefinition(defn: Definition): String = + defn match { + case e: ExecutableDefinition => renderExecutableDefinition(e) + case _ => "" + } + + def renderExecutableDefinition(ex: ExecutableDefinition): String = + ex match { + case op: OperationDefinition => renderOperationDefinition(op) + case frag: FragmentDefinition => renderFragmentDefinition(frag) + } + + def renderOperationDefinition(op: OperationDefinition): String = + op match { + case qs: QueryShorthand => renderSelectionSet(qs.selectionSet) + case op: Operation => renderOperation(op) + } + + def renderOperation(op: Operation): String = + renderOperationType(op.operationType) + + op.name.map(nme => s" ${nme.value}").getOrElse("") + + renderVariableDefns(op.variables)+ + renderDirectives(op.directives)+ + renderSelectionSet(op.selectionSet) + + def renderOperationType(op: OperationType): String = + op match { + case Query => "query" + case Mutation => "mutation" + case Subscription => "subscription" + } + + def renderDirectives(dirs: List[Directive]): String = + dirs.map { case Directive(name, args) => s"@${name.value}${renderArguments(args)}" }.mkString("") + + def renderVariableDefns(vars: List[VariableDefinition]): String = + vars match { + case Nil => "" + case _ => + vars.map { + case VariableDefinition(name, tpe, default, dirs) => + s"$$${name.value}:${tpe.name}${default.map(v => s"=${renderValue(v)}").getOrElse("")}${renderDirectives(dirs)}" + }.mkString("(", ",", ")") + } + + def renderSelectionSet(sels: List[Selection]): String = + sels match { + case Nil => "" + case _ => sels.map(renderSelection).mkString("{", ",", "}") + } + + def renderSelection(sel: Selection): String = + sel match { + case f: Field => renderField(f) + case s: FragmentSpread => renderFragmentSpread(s) + case i: InlineFragment => renderInlineFragment(i) + } + + def renderField(f: Field) = { + f.alias.map(a => s"${a.value}:").getOrElse("")+ + f.name.value+ + renderArguments(f.arguments)+ + renderDirectives(f.directives)+ + renderSelectionSet(f.selectionSet) } - def renderSelection(sel: Selection): String = - sel match { - case f: Field => renderField(f) - case s: FragmentSpread => renderFragmentSpread(s) - case i: InlineFragment => renderInlineFragment(i) - } - - def renderField(f: Field) = { - f.alias.map(a => s"${a.value}:").getOrElse("")+ - f.name.value+ - renderArguments(f.arguments)+ - renderDirectives(f.directives)+ - renderSelectionSet(f.selectionSet) + def renderArguments(args: List[(Name, Value)]): String = + args match { + case Nil => "" + case _ => args.map { case (n, v) => s"${n.value}:${renderValue(v)}" }.mkString("(", ",", ")") + } + + def renderInputObject(args: List[(Name, Value)]): String = + args match { + case Nil => "" + case _ => args.map { case (n, v) => s"${n.value}:${renderValue(v)}" }.mkString("{", ",", "}") + } + + def renderTypeCondition(tpe: Type): String = + s"on ${tpe.name}" + + def renderFragmentDefinition(frag: FragmentDefinition): String = + s"fragment ${frag.name.value} ${renderTypeCondition(frag.typeCondition)}${renderDirectives(frag.directives)}${renderSelectionSet(frag.selectionSet)}" + + def renderFragmentSpread(spread: FragmentSpread): String = + s"...${spread.name.value}${renderDirectives(spread.directives)}" + + def renderInlineFragment(frag: InlineFragment): String = + s"...${frag.typeCondition.map(renderTypeCondition).getOrElse("")}${renderDirectives(frag.directives)}${renderSelectionSet(frag.selectionSet)}" + + def renderValue(v: Value): String = + v match { + case Variable(name) => s"$$${name.value}" + case IntValue(value) => value.toString + case FloatValue(value) => value.toString + case StringValue(value) => s""""$value"""" + case BooleanValue(value) => value.toString + case NullValue => "null" + case EnumValue(name) => name.value + case ListValue(values) => values.map(renderValue).mkString("[", ",", "]") + case ObjectValue(fields) => renderInputObject(fields) + } + + doc.map(renderDefinition).mkString(",") } - - def renderArguments(args: List[(Name, Value)]): String = - args match { - case Nil => "" - case _ => args.map { case (n, v) => s"${n.value}:${renderValue(v)}" }.mkString("(", ",", ")") - } - - def renderInputObject(args: List[(Name, Value)]): String = - args match { - case Nil => "" - case _ => args.map { case (n, v) => s"${n.value}:${renderValue(v)}" }.mkString("{", ",", "}") - } - - def renderTypeCondition(tpe: Type): String = - s"on ${tpe.name}" - - def renderFragmentDefinition(frag: FragmentDefinition): String = - s"fragment ${frag.name.value} ${renderTypeCondition(frag.typeCondition)}${renderDirectives(frag.directives)}${renderSelectionSet(frag.selectionSet)}" - - def renderFragmentSpread(spread: FragmentSpread): String = - s"...${spread.name.value}${renderDirectives(spread.directives)}" - - def renderInlineFragment(frag: InlineFragment): String = - s"...${frag.typeCondition.map(renderTypeCondition).getOrElse("")}${renderDirectives(frag.directives)}${renderSelectionSet(frag.selectionSet)}" - - def renderValue(v: Value): String = - v match { - case Variable(name) => s"$$${name.value}" - case IntValue(value) => value.toString - case FloatValue(value) => value.toString - case StringValue(value) => s""""$value"""" - case BooleanValue(value) => value.toString - case NullValue => "null" - case EnumValue(name) => name.value - case ListValue(values) => values.map(renderValue).mkString("[", ",", "]") - case ObjectValue(fields) => renderInputObject(fields) - } - - doc.map(renderDefinition).mkString(",") } } diff --git a/modules/core/src/main/scala/parser.scala b/modules/core/src/main/scala/parser.scala index 0004f55f..bc9361bb 100644 --- a/modules/core/src/main/scala/parser.scala +++ b/modules/core/src/main/scala/parser.scala @@ -15,278 +15,308 @@ package grackle -import cats.parse.{LocationMap, Parser, Parser0} +import scala.util.matching.Regex + +import cats.implicits._ +import cats.parse.{Parser, Parser0} import cats.parse.Parser._ import cats.parse.Numbers._ import cats.parse.Rfc5234.{cr, crlf, digit, hexdig, lf} -import cats.implicits._ -import CommentedText._ -import Literals._ -import scala.util.matching.Regex + +trait GraphQLParser { + def parseText(text: String): Result[Ast.Document] +} object GraphQLParser { + case class Config( + maxSelectionDepth: Int, + maxSelectionWidth: Int, + maxInputValueDepth: Int, + maxListTypeDepth: Int + ) + + val defaultConfig: Config = + Config( + maxSelectionDepth = 100, + maxSelectionWidth = 1000, + maxInputValueDepth = 5, + maxListTypeDepth = 5 + ) - val nameInitial = ('A' to 'Z') ++ ('a' to 'z') ++ Seq('_') - val nameSubsequent = nameInitial ++ ('0' to '9') + def apply(config: Config): GraphQLParser = + new Impl(config) - def keyword(s: String) = token(string(s) <* not(charIn(nameSubsequent))) + def toResult[T](pr: Either[Parser.Error, T]): Result[T] = + Result.fromEither(pr.leftMap(_.show)) - def punctuation(s: String) = token(string(s)) + import CommentedText._ + import Literals._ - lazy val Document: Parser0[Ast.Document] = - (whitespace.void | comment).rep0 *> Definition.rep0 <* Parser.end + private final class Impl(config: Config) extends GraphQLParser { + import config._ - lazy val Definition: Parser[Ast.Definition] = - ExecutableDefinition | TypeSystemDefinition | TypeSystemExtension + def parseText(text: String): Result[Ast.Document] = + toResult(Document.parseAll(text)) - lazy val TypeSystemDefinition: Parser[Ast.TypeSystemDefinition] = { - val SchemaDefinition: Parser[Ast.SchemaDefinition] = - ((keyword("schema") *> Directives.?) ~ braces(RootOperationTypeDefinition.rep0)).map { - case (dirs, rootdefs) => Ast.SchemaDefinition(rootdefs, dirs.getOrElse(Nil)) - } + val nameInitial = ('A' to 'Z') ++ ('a' to 'z') ++ Seq('_') + val nameSubsequent = nameInitial ++ ('0' to '9') - def typeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.TypeDefinition] = { + def keyword(s: String) = token(string(s) <* not(charIn(nameSubsequent))) - def scalarTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.ScalarTypeDefinition] = - ((keyword("scalar") *> Name) ~ Directives.?).map { - case (name, dirs) => Ast.ScalarTypeDefinition(name, desc.map(_.value), dirs.getOrElse(Nil)) - } + def punctuation(s: String) = token(string(s)) - def objectTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.ObjectTypeDefinition] = - ((keyword("type") *> Name) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition).map { - case (((name, ifs), dirs), fields) => Ast.ObjectTypeDefinition(name, desc.map(_.value), fields, ifs.getOrElse(Nil), dirs.getOrElse(Nil)) - } + lazy val Document: Parser0[Ast.Document] = + (whitespace.void | comment).rep0 *> Definition.rep0 <* Parser.end - def interfaceTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.InterfaceTypeDefinition] = - ((keyword("interface") *> Name) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition).map { - case (((name, ifs), dirs), fields) => Ast.InterfaceTypeDefinition(name, desc.map(_.value), fields, ifs.getOrElse(Nil), dirs.getOrElse(Nil)) - } + lazy val Definition: Parser[Ast.Definition] = + ExecutableDefinition | TypeSystemDefinition | TypeSystemExtension - def unionTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.UnionTypeDefinition] = - ((keyword("union") *> Name) ~ Directives.? ~ UnionMemberTypes).map { - case ((name, dirs), members) => Ast.UnionTypeDefinition(name, desc.map(_.value), dirs.getOrElse(Nil), members) + lazy val TypeSystemDefinition: Parser[Ast.TypeSystemDefinition] = { + val SchemaDefinition: Parser[Ast.SchemaDefinition] = + ((keyword("schema") *> Directives.?) ~ braces(RootOperationTypeDefinition.rep0)).map { + case (dirs, rootdefs) => Ast.SchemaDefinition(rootdefs, dirs.getOrElse(Nil)) } - def enumTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.EnumTypeDefinition] = - ((keyword("enum") *> Name) ~ Directives.? ~ EnumValuesDefinition).map { - case ((name, dirs), values) => Ast.EnumTypeDefinition(name, desc.map(_.value), dirs.getOrElse(Nil), values) - } + def typeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.TypeDefinition] = { - def inputObjectTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.InputObjectTypeDefinition] = - ((keyword("input") *> Name) ~ Directives.? ~ InputFieldsDefinition).map { - case ((name, dirs), fields) => Ast.InputObjectTypeDefinition(name, desc.map(_.value), fields, dirs.getOrElse(Nil)) - } + def scalarTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.ScalarTypeDefinition] = + ((keyword("scalar") *> Name) ~ Directives.?).map { + case (name, dirs) => Ast.ScalarTypeDefinition(name, desc.map(_.value), dirs.getOrElse(Nil)) + } - scalarTypeDefinition(desc)| - objectTypeDefinition(desc) | - interfaceTypeDefinition(desc) | - unionTypeDefinition(desc) | - enumTypeDefinition(desc) | - inputObjectTypeDefinition(desc) - } + def objectTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.ObjectTypeDefinition] = + ((keyword("type") *> Name) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition).map { + case (((name, ifs), dirs), fields) => Ast.ObjectTypeDefinition(name, desc.map(_.value), fields, ifs.getOrElse(Nil), dirs.getOrElse(Nil)) + } - def directiveDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.DirectiveDefinition] = - ((keyword("directive") *> punctuation("@") *> Name) ~ - ArgumentsDefinition.? ~ (keyword("repeatable").? <* keyword("on")) ~ DirectiveLocations).map { - case (((name, args), rpt), locs) => Ast.DirectiveDefinition(name, desc.map(_.value), args.getOrElse(Nil), rpt.isDefined, locs) - } + def interfaceTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.InterfaceTypeDefinition] = + ((keyword("interface") *> Name) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition).map { + case (((name, ifs), dirs), fields) => Ast.InterfaceTypeDefinition(name, desc.map(_.value), fields, ifs.getOrElse(Nil), dirs.getOrElse(Nil)) + } - SchemaDefinition | - Description.?.with1.flatMap { desc => - typeDefinition(desc) | directiveDefinition(desc) - } - } + def unionTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.UnionTypeDefinition] = + ((keyword("union") *> Name) ~ Directives.? ~ UnionMemberTypes).map { + case ((name, dirs), members) => Ast.UnionTypeDefinition(name, desc.map(_.value), dirs.getOrElse(Nil), members) + } - lazy val TypeSystemExtension: Parser[Ast.TypeSystemExtension] = { + def enumTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.EnumTypeDefinition] = + ((keyword("enum") *> Name) ~ Directives.? ~ EnumValuesDefinition).map { + case ((name, dirs), values) => Ast.EnumTypeDefinition(name, desc.map(_.value), dirs.getOrElse(Nil), values) + } - val SchemaExtension: Parser[Ast.SchemaExtension] = - ((keyword("schema") *> Directives.?) ~ braces(RootOperationTypeDefinition.rep0).?).map { - case (dirs, rootdefs) => Ast.SchemaExtension(rootdefs.getOrElse(Nil), dirs.getOrElse(Nil)) - } + def inputObjectTypeDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.InputObjectTypeDefinition] = + ((keyword("input") *> Name) ~ Directives.? ~ InputFieldsDefinition).map { + case ((name, dirs), fields) => Ast.InputObjectTypeDefinition(name, desc.map(_.value), fields, dirs.getOrElse(Nil)) + } - val TypeExtension: Parser[Ast.TypeExtension] = { + scalarTypeDefinition(desc)| + objectTypeDefinition(desc) | + interfaceTypeDefinition(desc) | + unionTypeDefinition(desc) | + enumTypeDefinition(desc) | + inputObjectTypeDefinition(desc) + } - val ScalarTypeExtension: Parser[Ast.ScalarTypeExtension] = - ((keyword("scalar") *> NamedType) ~ Directives.?).map { - case (((name), dirs)) => Ast.ScalarTypeExtension(name, dirs.getOrElse(Nil)) + def directiveDefinition(desc: Option[Ast.Value.StringValue]): Parser[Ast.DirectiveDefinition] = + ((keyword("directive") *> punctuation("@") *> Name) ~ + ArgumentsDefinition.? ~ (keyword("repeatable").? <* keyword("on")) ~ DirectiveLocations).map { + case (((name, args), rpt), locs) => Ast.DirectiveDefinition(name, desc.map(_.value), args.getOrElse(Nil), rpt.isDefined, locs) } - val ObjectTypeExtension: Parser[Ast.ObjectTypeExtension] = - ((keyword("type") *> NamedType) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition.?).map { - case (((name, ifs), dirs), fields) => Ast.ObjectTypeExtension(name, fields.getOrElse(Nil), ifs.getOrElse(Nil), dirs.getOrElse(Nil)) + SchemaDefinition | + Description.?.with1.flatMap { desc => + typeDefinition(desc) | directiveDefinition(desc) } + } - val InterfaceTypeExtension: Parser[Ast.InterfaceTypeExtension] = - ((keyword("interface") *> NamedType) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition.?).map { - case (((name, ifs), dirs), fields) => Ast.InterfaceTypeExtension(name, fields.getOrElse(Nil), ifs.getOrElse(Nil), dirs.getOrElse(Nil)) - } + lazy val TypeSystemExtension: Parser[Ast.TypeSystemExtension] = { - val UnionTypeExtension: Parser[Ast.UnionTypeExtension] = - ((keyword("union") *> NamedType) ~ Directives.? ~ UnionMemberTypes.?).map { - case (((name), dirs), members) => Ast.UnionTypeExtension(name, dirs.getOrElse(Nil), members.getOrElse(Nil)) + val SchemaExtension: Parser[Ast.SchemaExtension] = + ((keyword("schema") *> Directives.?) ~ braces(RootOperationTypeDefinition.rep0).?).map { + case (dirs, rootdefs) => Ast.SchemaExtension(rootdefs.getOrElse(Nil), dirs.getOrElse(Nil)) } - val EnumTypeExtension: Parser[Ast.EnumTypeExtension] = - ((keyword("enum") *> NamedType) ~ Directives.? ~ EnumValuesDefinition.?).map { - case (((name), dirs), values) => Ast.EnumTypeExtension(name, dirs.getOrElse(Nil), values.getOrElse(Nil)) - } + val TypeExtension: Parser[Ast.TypeExtension] = { - val InputObjectTypeExtension: Parser[Ast.InputObjectTypeExtension] = - ((keyword("input") *> NamedType) ~ Directives.? ~ InputFieldsDefinition.?).map { - case (((name), dirs), fields) => Ast.InputObjectTypeExtension(name, dirs.getOrElse(Nil), fields.getOrElse(Nil)) - } + val ScalarTypeExtension: Parser[Ast.ScalarTypeExtension] = + ((keyword("scalar") *> NamedType) ~ Directives.?).map { + case (((name), dirs)) => Ast.ScalarTypeExtension(name, dirs.getOrElse(Nil)) + } - ScalarTypeExtension| - ObjectTypeExtension| - InterfaceTypeExtension| - UnionTypeExtension| - EnumTypeExtension| - InputObjectTypeExtension - } + val ObjectTypeExtension: Parser[Ast.ObjectTypeExtension] = + ((keyword("type") *> NamedType) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition.?).map { + case (((name, ifs), dirs), fields) => Ast.ObjectTypeExtension(name, fields.getOrElse(Nil), ifs.getOrElse(Nil), dirs.getOrElse(Nil)) + } - keyword("extend") *> (SchemaExtension | TypeExtension) - } + val InterfaceTypeExtension: Parser[Ast.InterfaceTypeExtension] = + ((keyword("interface") *> NamedType) ~ ImplementsInterfaces.? ~ Directives.? ~ FieldsDefinition.?).map { + case (((name, ifs), dirs), fields) => Ast.InterfaceTypeExtension(name, fields.getOrElse(Nil), ifs.getOrElse(Nil), dirs.getOrElse(Nil)) + } + + val UnionTypeExtension: Parser[Ast.UnionTypeExtension] = + ((keyword("union") *> NamedType) ~ Directives.? ~ UnionMemberTypes.?).map { + case (((name), dirs), members) => Ast.UnionTypeExtension(name, dirs.getOrElse(Nil), members.getOrElse(Nil)) + } - lazy val RootOperationTypeDefinition: Parser[Ast.RootOperationTypeDefinition] = - (OperationType ~ punctuation(":") ~ NamedType ~ Directives).map { - case (((optpe, _), tpe), dirs) => Ast.RootOperationTypeDefinition(optpe, tpe, dirs) + val EnumTypeExtension: Parser[Ast.EnumTypeExtension] = + ((keyword("enum") *> NamedType) ~ Directives.? ~ EnumValuesDefinition.?).map { + case (((name), dirs), values) => Ast.EnumTypeExtension(name, dirs.getOrElse(Nil), values.getOrElse(Nil)) + } + + val InputObjectTypeExtension: Parser[Ast.InputObjectTypeExtension] = + ((keyword("input") *> NamedType) ~ Directives.? ~ InputFieldsDefinition.?).map { + case (((name), dirs), fields) => Ast.InputObjectTypeExtension(name, dirs.getOrElse(Nil), fields.getOrElse(Nil)) + } + + ScalarTypeExtension| + ObjectTypeExtension| + InterfaceTypeExtension| + UnionTypeExtension| + EnumTypeExtension| + InputObjectTypeExtension + } + + keyword("extend") *> (SchemaExtension | TypeExtension) } + lazy val RootOperationTypeDefinition: Parser[Ast.RootOperationTypeDefinition] = + (OperationType ~ punctuation(":") ~ NamedType ~ Directives).map { + case (((optpe, _), tpe), dirs) => Ast.RootOperationTypeDefinition(optpe, tpe, dirs) + } - lazy val Description = StringValue - lazy val ImplementsInterfaces = - (keyword("implements") ~ punctuation("&").?) *> NamedType.repSep0(punctuation("&")) + lazy val Description = StringValue - lazy val FieldsDefinition: Parser[List[Ast.FieldDefinition]] = - braces(FieldDefinition.rep0) + lazy val ImplementsInterfaces = + (keyword("implements") ~ punctuation("&").?) *> NamedType.repSep0(punctuation("&")) - lazy val FieldDefinition: Parser[Ast.FieldDefinition] = - (Description.?.with1 ~ Name ~ ArgumentsDefinition.? ~ punctuation(":") ~ Type ~ Directives.?).map { - case (((((desc, name), args), _), tpe), dirs) => Ast.FieldDefinition(name, desc.map(_.value), args.getOrElse(Nil), tpe, dirs.getOrElse(Nil)) - } + lazy val FieldsDefinition: Parser[List[Ast.FieldDefinition]] = + braces(FieldDefinition.rep0) - lazy val ArgumentsDefinition: Parser[List[Ast.InputValueDefinition]] = - parens(InputValueDefinition.rep0) + lazy val FieldDefinition: Parser[Ast.FieldDefinition] = + (Description.?.with1 ~ Name ~ ArgumentsDefinition.? ~ punctuation(":") ~ Type ~ Directives.?).map { + case (((((desc, name), args), _), tpe), dirs) => Ast.FieldDefinition(name, desc.map(_.value), args.getOrElse(Nil), tpe, dirs.getOrElse(Nil)) + } - lazy val InputFieldsDefinition: Parser[List[Ast.InputValueDefinition]] = - braces(InputValueDefinition.rep0) + lazy val ArgumentsDefinition: Parser[List[Ast.InputValueDefinition]] = + parens(InputValueDefinition.rep0) - lazy val InputValueDefinition: Parser[Ast.InputValueDefinition] = - (Description.?.with1 ~ (Name <* punctuation(":")) ~ Type ~ DefaultValue.? ~ Directives.?).map { - case ((((desc, name), tpe), dv), dirs) => Ast.InputValueDefinition(name, desc.map(_.value), tpe, dv, dirs.getOrElse(Nil)) - } + lazy val InputFieldsDefinition: Parser[List[Ast.InputValueDefinition]] = + braces(InputValueDefinition.rep0) - lazy val UnionMemberTypes: Parser[List[Ast.Type.Named]] = - (punctuation("=") *> punctuation("|").?) *> NamedType.repSep0(punctuation("|")) + lazy val InputValueDefinition: Parser[Ast.InputValueDefinition] = + (Description.?.with1 ~ (Name <* punctuation(":")) ~ Type ~ DefaultValue.? ~ Directives.?).map { + case ((((desc, name), tpe), dv), dirs) => Ast.InputValueDefinition(name, desc.map(_.value), tpe, dv, dirs.getOrElse(Nil)) + } - lazy val EnumValuesDefinition: Parser[List[Ast.EnumValueDefinition]] = - braces(EnumValueDefinition.rep0) + lazy val UnionMemberTypes: Parser[List[Ast.Type.Named]] = + (punctuation("=") *> punctuation("|").?) *> NamedType.repSep0(punctuation("|")) - lazy val EnumValueDefinition: Parser[Ast.EnumValueDefinition] = - (Description.?.with1 ~ Name ~ Directives.?).map { - case ((desc, name), dirs) => Ast.EnumValueDefinition(name, desc.map(_.value), dirs.getOrElse(Nil)) - } + lazy val EnumValuesDefinition: Parser[List[Ast.EnumValueDefinition]] = + braces(EnumValueDefinition.rep0) - lazy val DirectiveLocations: Parser0[List[Ast.DirectiveLocation]] = - punctuation("|").? *> DirectiveLocation.repSep0(punctuation("|")) - - lazy val DirectiveLocation: Parser[Ast.DirectiveLocation] = - keyword("QUERY") .as(Ast.DirectiveLocation.QUERY) | - keyword("MUTATION") .as(Ast.DirectiveLocation.MUTATION) | - keyword("SUBSCRIPTION").as(Ast.DirectiveLocation.SUBSCRIPTION) | - keyword("FIELD_DEFINITION").as(Ast.DirectiveLocation.FIELD_DEFINITION) | - keyword("FIELD").as(Ast.DirectiveLocation.FIELD) | - keyword("FRAGMENT_DEFINITION").as(Ast.DirectiveLocation.FRAGMENT_DEFINITION) | - keyword("FRAGMENT_SPREAD").as(Ast.DirectiveLocation.FRAGMENT_SPREAD) | - keyword("INLINE_FRAGMENT").as(Ast.DirectiveLocation.INLINE_FRAGMENT) | - keyword("VARIABLE_DEFINITION").as(Ast.DirectiveLocation.VARIABLE_DEFINITION) | - keyword("SCHEMA").as(Ast.DirectiveLocation.SCHEMA) | - keyword("SCALAR").as(Ast.DirectiveLocation.SCALAR) | - keyword("OBJECT").as(Ast.DirectiveLocation.OBJECT) | - keyword("ARGUMENT_DEFINITION").as(Ast.DirectiveLocation.ARGUMENT_DEFINITION) | - keyword("INTERFACE").as(Ast.DirectiveLocation.INTERFACE) | - keyword("UNION").as(Ast.DirectiveLocation.UNION) | - keyword("ENUM_VALUE").as(Ast.DirectiveLocation.ENUM_VALUE) | - keyword("ENUM").as(Ast.DirectiveLocation.ENUM) | - keyword("INPUT_OBJECT").as(Ast.DirectiveLocation.INPUT_OBJECT) | - keyword("INPUT_FIELD_DEFINITION").as(Ast.DirectiveLocation.INPUT_FIELD_DEFINITION) - - lazy val ExecutableDefinition: Parser[Ast.ExecutableDefinition] = - OperationDefinition | FragmentDefinition - - lazy val OperationDefinition: Parser[Ast.OperationDefinition] = - QueryShorthand | Operation - - lazy val QueryShorthand: Parser[Ast.OperationDefinition.QueryShorthand] = - SelectionSet.map(Ast.OperationDefinition.QueryShorthand.apply) - - lazy val Operation: Parser[Ast.OperationDefinition.Operation] = - (OperationType ~ Name.? ~ VariableDefinitions.? ~ Directives ~ SelectionSet).map { - case ((((op, name), vars), dirs), sels) => Ast.OperationDefinition.Operation(op, name, vars.orEmpty, dirs, sels) - } + lazy val EnumValueDefinition: Parser[Ast.EnumValueDefinition] = + (Description.?.with1 ~ Name ~ Directives.?).map { + case ((desc, name), dirs) => Ast.EnumValueDefinition(name, desc.map(_.value), dirs.getOrElse(Nil)) + } - lazy val OperationType: Parser[Ast.OperationType] = - keyword("query") .as(Ast.OperationType.Query) | - keyword("mutation") .as(Ast.OperationType.Mutation) | - keyword("subscription").as(Ast.OperationType.Subscription) + lazy val DirectiveLocations: Parser0[List[Ast.DirectiveLocation]] = + punctuation("|").? *> DirectiveLocation.repSep0(punctuation("|")) + + lazy val DirectiveLocation: Parser[Ast.DirectiveLocation] = + keyword("QUERY") .as(Ast.DirectiveLocation.QUERY) | + keyword("MUTATION") .as(Ast.DirectiveLocation.MUTATION) | + keyword("SUBSCRIPTION").as(Ast.DirectiveLocation.SUBSCRIPTION) | + keyword("FIELD_DEFINITION").as(Ast.DirectiveLocation.FIELD_DEFINITION) | + keyword("FIELD").as(Ast.DirectiveLocation.FIELD) | + keyword("FRAGMENT_DEFINITION").as(Ast.DirectiveLocation.FRAGMENT_DEFINITION) | + keyword("FRAGMENT_SPREAD").as(Ast.DirectiveLocation.FRAGMENT_SPREAD) | + keyword("INLINE_FRAGMENT").as(Ast.DirectiveLocation.INLINE_FRAGMENT) | + keyword("VARIABLE_DEFINITION").as(Ast.DirectiveLocation.VARIABLE_DEFINITION) | + keyword("SCHEMA").as(Ast.DirectiveLocation.SCHEMA) | + keyword("SCALAR").as(Ast.DirectiveLocation.SCALAR) | + keyword("OBJECT").as(Ast.DirectiveLocation.OBJECT) | + keyword("ARGUMENT_DEFINITION").as(Ast.DirectiveLocation.ARGUMENT_DEFINITION) | + keyword("INTERFACE").as(Ast.DirectiveLocation.INTERFACE) | + keyword("UNION").as(Ast.DirectiveLocation.UNION) | + keyword("ENUM_VALUE").as(Ast.DirectiveLocation.ENUM_VALUE) | + keyword("ENUM").as(Ast.DirectiveLocation.ENUM) | + keyword("INPUT_OBJECT").as(Ast.DirectiveLocation.INPUT_OBJECT) | + keyword("INPUT_FIELD_DEFINITION").as(Ast.DirectiveLocation.INPUT_FIELD_DEFINITION) + + lazy val ExecutableDefinition: Parser[Ast.ExecutableDefinition] = + OperationDefinition | FragmentDefinition + + lazy val OperationDefinition: Parser[Ast.OperationDefinition] = + QueryShorthand | Operation + + lazy val QueryShorthand: Parser[Ast.OperationDefinition.QueryShorthand] = + SelectionSet.map(Ast.OperationDefinition.QueryShorthand.apply) + + lazy val Operation: Parser[Ast.OperationDefinition.Operation] = + (OperationType ~ Name.? ~ VariableDefinitions.? ~ Directives ~ SelectionSet).map { + case ((((op, name), vars), dirs), sels) => Ast.OperationDefinition.Operation(op, name, vars.orEmpty, dirs, sels) + } - lazy val SelectionSet: Parser[List[Ast.Selection]] = recursive[List[Ast.Selection]] { rec => + lazy val OperationType: Parser[Ast.OperationType] = + keyword("query") .as(Ast.OperationType.Query) | + keyword("mutation") .as(Ast.OperationType.Mutation) | + keyword("subscription").as(Ast.OperationType.Subscription) - val Alias: Parser[Ast.Name] = + lazy val Alias: Parser[Ast.Name] = Name <* punctuation(":") - val Field: Parser[Ast.Selection.Field] = - (Alias.backtrack.?.with1 ~ Name ~ Arguments.? ~ Directives ~ rec.?).map { + lazy val FragmentSpread: Parser[Ast.Selection.FragmentSpread] = + (FragmentName ~ Directives).map{ case (name, dirs) => Ast.Selection.FragmentSpread.apply(name, dirs)} + + def Field(n: Int): Parser[Ast.Selection.Field] = + (Alias.backtrack.?.with1 ~ Name ~ Arguments.? ~ Directives ~ SelectionSetN(n).?).map { case ((((alias, name), args), dirs), sel) => Ast.Selection.Field(alias, name, args.orEmpty, dirs, sel.orEmpty) } - val FragmentSpread: Parser[Ast.Selection.FragmentSpread] = - (FragmentName ~ Directives).map{ case (name, dirs) => Ast.Selection.FragmentSpread.apply(name, dirs)} - - val InlineFragment: Parser[Ast.Selection.InlineFragment] = - ((TypeCondition.? ~ Directives).with1 ~ rec).map { + def InlineFragment(n: Int): Parser[Ast.Selection.InlineFragment] = + ((TypeCondition.? ~ Directives).with1 ~ SelectionSetN(n)).map { case ((cond, dirs), sel) => Ast.Selection.InlineFragment(cond, dirs, sel) } - val Selection: Parser[Ast.Selection] = - Field | - (punctuation("...") *> (InlineFragment | FragmentSpread)) + def Selection(n: Int): Parser[Ast.Selection] = + Field(n) | + (punctuation("...") *> (InlineFragment(n) | FragmentSpread)) - braces(Selection.rep0) - } + lazy val SelectionSet: Parser[List[Ast.Selection]] = + SelectionSetN(maxSelectionDepth) - lazy val Arguments: Parser[List[(Ast.Name, Ast.Value)]] = - parens(Argument.rep0) + def SelectionSetN(n: Int): Parser[List[Ast.Selection]] = + braces(guard0(n, "exceeded maximum selection depth")(Selection(_).repAs0(max = maxSelectionWidth))) - lazy val Argument: Parser[(Ast.Name, Ast.Value)] = - (Name <* punctuation(":")) ~ Value + lazy val Arguments: Parser[List[(Ast.Name, Ast.Value)]] = + parens(Argument.rep0) - lazy val FragmentName: Parser[Ast.Name] = - not(string("on")).with1 *> Name + lazy val Argument: Parser[(Ast.Name, Ast.Value)] = + (Name <* punctuation(":")) ~ Value - lazy val FragmentDefinition: Parser[Ast.FragmentDefinition] = - ((keyword("fragment") *> FragmentName) ~ TypeCondition ~ Directives ~ SelectionSet).map { - case (((name, cond), dirs), sel) => Ast.FragmentDefinition(name, cond, dirs, sel) - } + lazy val FragmentName: Parser[Ast.Name] = + not(string("on")).with1 *> Name - lazy val TypeCondition: Parser[Ast.Type.Named] = - keyword("on") *> NamedType + lazy val FragmentDefinition: Parser[Ast.FragmentDefinition] = + ((keyword("fragment") *> FragmentName) ~ TypeCondition ~ Directives ~ SelectionSet).map { + case (((name, cond), dirs), sel) => Ast.FragmentDefinition(name, cond, dirs, sel) + } - lazy val Value: Parser[Ast.Value] = recursive[Ast.Value] { rec => + lazy val TypeCondition: Parser[Ast.Type.Named] = + keyword("on") *> NamedType - val NullValue: Parser[Ast.Value.NullValue.type] = + lazy val NullValue: Parser[Ast.Value.NullValue.type] = keyword("null").as(Ast.Value.NullValue) lazy val EnumValue: Parser[Ast.Value.EnumValue] = (not(string("true") | string("false") | string("null")).with1 *> Name) .map(Ast.Value.EnumValue.apply) - val ListValue: Parser[Ast.Value.ListValue] = - token(squareBrackets(rec.rep0).map(Ast.Value.ListValue.apply)) - - val NumericLiteral: Parser[Ast.Value] = { + def ListValue(n: Int): Parser[Ast.Value.ListValue] = + token(squareBrackets(guard0(n, "exceeded maximum input value depth")(ValueN(_).rep0)).map(Ast.Value.ListValue.apply)) + lazy val NumericLiteral: Parser[Ast.Value] = { def narrow(d: BigDecimal): Ast.Value.FloatValue = Ast.Value.FloatValue(d.toDouble) @@ -301,204 +331,199 @@ object GraphQLParser { ) } - val BooleanValue: Parser[Ast.Value.BooleanValue] = + lazy val BooleanValue: Parser[Ast.Value.BooleanValue] = token(booleanLiteral).map(Ast.Value.BooleanValue.apply) - val ObjectField: Parser[(Ast.Name, Ast.Value)] = - (Name <* punctuation(":")) ~ rec + def ObjectField(n: Int): Parser[(Ast.Name, Ast.Value)] = + (Name <* punctuation(":")) ~ ValueN(n) - val ObjectValue: Parser[Ast.Value.ObjectValue] = - braces(ObjectField.rep0).map(Ast.Value.ObjectValue.apply) + def ObjectValue(n: Int): Parser[Ast.Value.ObjectValue] = + braces(guard0(n, "exceeded maximum input value depth")(ObjectField(_).rep0)).map(Ast.Value.ObjectValue.apply) - Variable | - NumericLiteral | - StringValue | - BooleanValue | - NullValue | - EnumValue | - ListValue | - ObjectValue - } + lazy val StringValue: Parser[Ast.Value.StringValue] = + token(stringLiteral).map(Ast.Value.StringValue.apply) - lazy val StringValue: Parser[Ast.Value.StringValue] = - token(stringLiteral).map(Ast.Value.StringValue.apply) + def ValueN(n: Int): Parser[Ast.Value] = + Variable | + NumericLiteral | + StringValue | + BooleanValue | + NullValue | + EnumValue | + ListValue(n) | + ObjectValue(n) - lazy val VariableDefinitions: Parser[List[Ast.VariableDefinition]] = - parens(VariableDefinition.rep0) + lazy val Value: Parser[Ast.Value] = + ValueN(maxInputValueDepth) - lazy val VariableDefinition: Parser[Ast.VariableDefinition] = - ((Variable <* punctuation(":")) ~ Type ~ DefaultValue.? ~ Directives.?).map { - case (((v, tpe), dv), dirs) => Ast.VariableDefinition(v.name, tpe, dv, dirs.getOrElse(Nil)) - } + lazy val VariableDefinitions: Parser[List[Ast.VariableDefinition]] = + parens(VariableDefinition.rep0) - lazy val Variable: Parser[Ast.Value.Variable] = - punctuation("$") *> Name.map(Ast.Value.Variable.apply) + lazy val VariableDefinition: Parser[Ast.VariableDefinition] = + ((Variable <* punctuation(":")) ~ Type ~ DefaultValue.? ~ Directives.?).map { + case (((v, tpe), dv), dirs) => Ast.VariableDefinition(v.name, tpe, dv, dirs.getOrElse(Nil)) + } - lazy val DefaultValue: Parser[Ast.Value] = - punctuation("=") *> Value + lazy val Variable: Parser[Ast.Value.Variable] = + punctuation("$") *> Name.map(Ast.Value.Variable.apply) - lazy val Type: Parser[Ast.Type] = recursive[Ast.Type] { rec => + lazy val DefaultValue: Parser[Ast.Value] = + punctuation("=") *> Value - lazy val ListType: Parser[Ast.Type.List] = - squareBrackets(rec).map(Ast.Type.List.apply) + def ListType(n: Int): Parser[Ast.Type.List] = + squareBrackets(guard(n, "exceeded maximum list type depth")(TypeN)).map(Ast.Type.List.apply) - val namedMaybeNull: Parser[Ast.Type] = (NamedType ~ punctuation("!").?).map { + lazy val namedMaybeNull: Parser[Ast.Type] = (NamedType ~ punctuation("!").?).map { case (t, None) => t case (t, _) => Ast.Type.NonNull(Left(t)) } - val listMaybeNull: Parser[Ast.Type] = (ListType ~ punctuation("!").?).map { + def listMaybeNull(n: Int): Parser[Ast.Type] = (ListType(n) ~ punctuation("!").?).map { case (t, None) => t case (t, _) => Ast.Type.NonNull(Right(t)) } - namedMaybeNull | listMaybeNull - } + def TypeN(n: Int): Parser[Ast.Type] = + namedMaybeNull | listMaybeNull(n) - lazy val NamedType: Parser[Ast.Type.Named] = - Name.map(Ast.Type.Named.apply) + lazy val Type: Parser[Ast.Type] = + TypeN(maxListTypeDepth) - lazy val Directives: Parser0[List[Ast.Directive]] = - Directive.rep0 + lazy val NamedType: Parser[Ast.Type.Named] = + Name.map(Ast.Type.Named.apply) - lazy val Directive: Parser[Ast.Directive] = - punctuation("@") *> (Name ~ Arguments.?).map { case (n, ods) => Ast.Directive(n, ods.orEmpty)} + lazy val Directives: Parser0[List[Ast.Directive]] = + Directive.rep0 - lazy val Name: Parser[Ast.Name] = - token(charIn(nameInitial) ~ charIn(nameSubsequent).rep0).map { - case (h, t) => Ast.Name((h :: t).mkString) - } + lazy val Directive: Parser[Ast.Directive] = + punctuation("@") *> (Name ~ Arguments.?).map { case (n, ods) => Ast.Directive(n, ods.orEmpty)} - def toResult[T](text: String, pr: Either[Parser.Error, T]): Result[T] = - Result.fromEither(pr.leftMap { e => - val lm = LocationMap(text) - lm.toLineCol(e.failedAtOffset) match { - case Some((row, col)) => - lm.getLine(row) match { - case Some(line) => - s"""Parse error at line $row column $col - |$line - |${List.fill(col)(" ").mkString}^""".stripMargin - case None => "Malformed query" //This is probably a bug in Cats Parse as it has given us the (row, col) index - } - case None => "Truncated query" + lazy val Name: Parser[Ast.Name] = + token(charIn(nameInitial) ~ charIn(nameSubsequent).rep0).map { + case (h, t) => Ast.Name((h :: t).mkString) } - }) -} -object CommentedText { + def guard0[T](n: Int, msg: String)(p: Int => Parser0[T]): Parser0[T] = + if (n <= 0) Parser.failWith(msg) else defer0(p(n-1)) + + def guard[T](n: Int, msg: String)(p: Int => Parser[T]): Parser[T] = + if (n <= 0) Parser.failWith(msg) else defer(p(n-1)) + } - val whitespace: Parser[Char] = charWhere(_.isWhitespace) + private object CommentedText { - val skipWhitespace: Parser0[Unit] = - charsWhile0(c => c.isWhitespace || c == ',').void.withContext("whitespace") + val whitespace: Parser[Char] = charWhere(_.isWhitespace) - /** Parser that consumes a comment */ - val comment: Parser[Unit] = - (char('#') *> (charWhere(c => c != '\n' && c != '\r')).rep0 <* charIn('\n', '\r') <* skipWhitespace).void.withContext("comment") + val skipWhitespace: Parser0[Unit] = + charsWhile0(c => c.isWhitespace || c == ',').void - /** Turns a parser into one that skips trailing whitespace and comments */ - def token[A](p: Parser[A]): Parser[A] = - p <* skipWhitespace <* comment.rep0 + /** Parser that consumes a comment */ + val comment: Parser[Unit] = + (char('#') *> (charWhere(c => c != '\n' && c != '\r')).rep0 <* charIn('\n', '\r') <* skipWhitespace).void - def token0[A](p: Parser0[A]): Parser0[A] = - p <* skipWhitespace <* comment.rep0 + /** Turns a parser into one that skips trailing whitespace and comments */ + def token[A](p: Parser[A]): Parser[A] = + p <* skipWhitespace <* comment.rep0 - /** - * Consumes `left` and `right`, including the trailing and preceding whitespace, - * respectively, and returns the value of `p`. - */ - private def _bracket[A,B,C](left: Parser[B], p: Parser0[A], right: Parser[C]): Parser[A] = - token(left) *> token0(p) <* token(right) + def token0[A](p: Parser0[A]): Parser0[A] = + p <* skipWhitespace <* comment.rep0 - /** Turns a parser into one that consumes surrounding parentheses `()` */ - def parens[A](p: Parser0[A]): Parser[A] = - _bracket(char('('), p, char(')')).withContext(s"parens(${p.toString})") + /** + * Consumes `left` and `right`, including the trailing and preceding whitespace, + * respectively, and returns the value of `p`. + */ + private def _bracket[A,B,C](left: Parser[B], p: Parser0[A], right: Parser[C]): Parser[A] = + token(left) *> token0(p) <* token(right) - /** Turns a parser into one that consumes surrounding curly braces `{}` */ - def braces[A](p: Parser0[A]): Parser[A] = - _bracket(char('{'), p, char('}')).withContext(s"braces(${p.toString})") + /** Turns a parser into one that consumes surrounding parentheses `()` */ + def parens[A](p: Parser0[A]): Parser[A] = + _bracket(char('('), p, char(')')) - /** Turns a parser into one that consumes surrounding square brackets `[]` */ - def squareBrackets[A](p: Parser0[A]): Parser[A] = - _bracket(char('['), p, char(']')).withContext(s"squareBrackets(${p.toString})") -} + /** Turns a parser into one that consumes surrounding curly braces `{}` */ + def braces[A](p: Parser0[A]): Parser[A] = + _bracket(char('{'), p, char('}')) -object Literals { + /** Turns a parser into one that consumes surrounding square brackets `[]` */ + def squareBrackets[A](p: Parser0[A]): Parser[A] = + _bracket(char('['), p, char(']')) + } - val stringLiteral: Parser[String] = { + private object Literals { - val lineTerminator: Parser[String] = (lf | cr | crlf).string + val stringLiteral: Parser[String] = { - val sourceCharacter: Parser[String] = (charIn(0x0009.toChar, 0x000A.toChar, 0x000D.toChar) | charIn(0x0020.toChar to 0xFFFF.toChar)).string + val lineTerminator: Parser[String] = (lf | cr | crlf).string - val escapedUnicode: Parser[String] = string("\\u") *> - hexdig - .repExactlyAs[String](4) - .map(hex => Integer.parseInt(hex, 16).toChar.toString) + val sourceCharacter: Parser[String] = (charIn(0x0009.toChar, 0x000A.toChar, 0x000D.toChar) | charIn(0x0020.toChar to 0xFFFF.toChar)).string - val escapedCharacter: Parser[String] = char('\\') *> - ( - char('"').as("\"") | - char('\\').as("\\") | - char('/').as("/") | - char('b').as("\b") | - char('f').as("\f") | - char('n').as("\n") | - char('r').as("\r") | - char('t').as("\t") - ) + val escapedUnicode: Parser[String] = string("\\u") *> + hexdig + .repExactlyAs[String](4) + .map(hex => Integer.parseInt(hex, 16).toChar.toString) - val stringCharacter: Parser[String] = ( - (not(charIn('"', '\\') | lineTerminator).with1 *> sourceCharacter) | - escapedUnicode | - escapedCharacter - ) + val escapedCharacter: Parser[String] = char('\\') *> + ( + char('"').as("\"") | + char('\\').as("\\") | + char('/').as("/") | + char('b').as("\b") | + char('f').as("\f") | + char('n').as("\n") | + char('r').as("\r") | + char('t').as("\t") + ) - val blockStringCharacter: Parser[String] = string("\\\"\"\"").as("\"\"\"") | - (not(string("\"\"\"")).with1 *> sourceCharacter) - - //https://spec.graphql.org/June2018/#BlockStringValue() - //TODO this traverses over lines a hideous number of times(but matching the - //algorithm in the spec). Can it be optimized? - val blockQuotesInner: Parser0[String] = blockStringCharacter.repAs0[String].map { str => - val isWhitespace: Regex = "[ \t]*".r - var commonIndent: Int = -1 - var lineNum: Int = 0 - for (line <- str.linesIterator) { - if (lineNum != 0) { - val len = line.length() - val indent = line.takeWhile(c => c == ' ' || c == '\t').length() - if (indent < len) { - if (commonIndent < 0 || indent < commonIndent) { - commonIndent = indent + val stringCharacter: Parser[String] = ( + (not(charIn('"', '\\') | lineTerminator).with1 *> sourceCharacter) | + escapedUnicode | + escapedCharacter + ) + + val blockStringCharacter: Parser[String] = string("\\\"\"\"").as("\"\"\"") | + (not(string("\"\"\"")).with1 *> sourceCharacter) + + //https://spec.graphql.org/June2018/#BlockStringValue() + //TODO this traverses over lines a hideous number of times(but matching the + //algorithm in the spec). Can it be optimized? + val blockQuotesInner: Parser0[String] = blockStringCharacter.repAs0[String].map { str => + val isWhitespace: Regex = "[ \t]*".r + var commonIndent: Int = -1 + var lineNum: Int = 0 + for (line <- str.linesIterator) { + if (lineNum != 0) { + val len = line.length() + val indent = line.takeWhile(c => c == ' ' || c == '\t').length() + if (indent < len) { + if (commonIndent < 0 || indent < commonIndent) { + commonIndent = indent + } } } + lineNum = lineNum + 1 } - lineNum = lineNum + 1 - } - val formattedReversed: List[String] = if ( commonIndent >= 0) { - str.linesIterator.foldLeft[List[String]](Nil) { - (acc, l) => if (acc == Nil) l :: acc else l.drop(commonIndent) :: acc + val formattedReversed: List[String] = if ( commonIndent >= 0) { + str.linesIterator.foldLeft[List[String]](Nil) { + (acc, l) => if (acc == Nil) l :: acc else l.drop(commonIndent) :: acc + } + } else { + str.linesIterator.toList } - } else { - str.linesIterator.toList + val noTrailingEmpty = formattedReversed.dropWhile(isWhitespace.matches(_)).reverse + noTrailingEmpty.dropWhile(isWhitespace.matches(_)).mkString("\n") } - val noTrailingEmpty = formattedReversed.dropWhile(isWhitespace.matches(_)).reverse - noTrailingEmpty.dropWhile(isWhitespace.matches(_)).mkString("\n") - } - (not(string("\"\"\"")).with1 *> stringCharacter.repAs0[String].with1.surroundedBy(char('"'))) | blockQuotesInner.with1.surroundedBy(string("\"\"\"")) + (not(string("\"\"\"")).with1 *> stringCharacter.repAs0[String].with1.surroundedBy(char('"'))) | blockQuotesInner.with1.surroundedBy(string("\"\"\"")) - } - - val intLiteral: Parser[Int] = - bigInt.flatMap { - case v if v.isValidInt => pure(v.toInt) - case v => failWith(s"$v is larger than max int") } - val booleanLiteral: Parser[Boolean] = string("true").as(true) | string("false").as(false) + val intLiteral: Parser[Int] = + bigInt.flatMap { + case v if v.isValidInt => pure(v.toInt) + case v => failWith(s"$v is larger than max int") + } + val booleanLiteral: Parser[Boolean] = string("true").as(true) | string("false").as(false) + + } } diff --git a/modules/core/src/main/scala/schema.scala b/modules/core/src/main/scala/schema.scala index 17ad69c5..8d7439f7 100644 --- a/modules/core/src/main/scala/schema.scala +++ b/modules/core/src/main/scala/schema.scala @@ -228,7 +228,10 @@ trait Schema { object Schema { def apply(schemaText: String)(implicit pos: SourcePos): Result[Schema] = - SchemaParser.parseText(schemaText) + apply(schemaText, SchemaParser(GraphQLParser(GraphQLParser.defaultConfig))) + + def apply(schemaText: String, parser: SchemaParser)(implicit pos: SourcePos): Result[Schema] = + parser.parseText(schemaText) } case class SchemaExtension( @@ -940,6 +943,23 @@ object Value { case object AbsentValue extends Value + def fromAst(value: Ast.Value): Result[Value] = { + value match { + case Ast.Value.IntValue(i) => IntValue(i).success + case Ast.Value.FloatValue(d) => FloatValue(d).success + case Ast.Value.StringValue(s) => StringValue(s).success + case Ast.Value.BooleanValue(b) => BooleanValue(b).success + case Ast.Value.EnumValue(e) => EnumValue(e.value).success + case Ast.Value.Variable(v) => VariableRef(v.value).success + case Ast.Value.NullValue => NullValue.success + case Ast.Value.ListValue(vs) => vs.traverse(fromAst).map(ListValue(_)) + case Ast.Value.ObjectValue(fs) => + fs.traverse { case (name, value) => + fromAst(value).map(v => (name.value, v)) + }.map(ObjectValue(_)) + } + } + object StringListValue { def apply(ss: List[String]): Value = ListValue(ss.map(StringValue(_))) @@ -961,7 +981,7 @@ object Value { def loop(value: Value): Result[Value] = value match { case VariableRef(varName) => - Result.fromOption(vars.get(varName).map(_._2), s"Undefined variable '$varName'") + Result.fromOption(vars.get(varName).map(_._2), s"Variable '$varName' is undefined") case ObjectValue(fields) => val (keys, values) = fields.unzip values.traverse(loop).map(evs => ObjectValue(keys.zip(evs))) @@ -1166,6 +1186,13 @@ case class Directive( ) object Directive { + def fromAst(d: Ast.Directive): Result[Directive] = { + val Ast.Directive(Ast.Name(nme), args) = d + args.traverse { + case (Ast.Name(nme), value) => Value.fromAst(value).map(Binding(nme, _)) + }.map(Directive(nme, _)) + } + def validateDirectivesForSchema(schema: Schema): List[Problem] = { def validateTypeDirectives(tpe: NamedType): List[Problem] = tpe match { @@ -1325,268 +1352,254 @@ object Directive { /** * GraphQL schema parser */ +trait SchemaParser { + def parseText(text: String)(implicit pos: SourcePos): Result[Schema] + def parseDocument(doc: Ast.Document)(implicit sourcePos: SourcePos): Result[Schema] +} + object SchemaParser { + def apply(parser: GraphQLParser): SchemaParser = + new Impl(parser) - import Ast.{Directive => _, EnumValueDefinition => _, SchemaExtension => _, Type => _, TypeExtension => _, Value => _, _} + private final class Impl(parser: GraphQLParser) extends SchemaParser { - /** - * Parse a query String to a query algebra term. - * - * Yields a Query value on the right and accumulates errors on the left. - */ - def parseText(text: String)(implicit pos: SourcePos): Result[Schema] = - for { - doc <- GraphQLParser.toResult(text, GraphQLParser.Document.parseAll(text)) - query <- parseDocument(doc) - } yield query - - def parseDocument(doc: Document)(implicit sourcePos: SourcePos): Result[Schema] = { - object schema extends Schema { - var baseTypes: List[NamedType] = Nil - var baseSchemaType1: Option[NamedType] = null - var pos: SourcePos = sourcePos - - override def baseSchemaType: NamedType = baseSchemaType1.getOrElse(super.baseSchemaType) - - var directives: List[DirectiveDef] = Nil - var schemaExtensions: List[SchemaExtension] = Nil - var typeExtensions: List[TypeExtension] = Nil - - def complete(types0: List[NamedType], baseSchemaType0: Option[NamedType], directives0: List[DirectiveDef], schemaExtensions0: List[SchemaExtension], typeExtensions0: List[TypeExtension]): Unit = { - baseTypes = types0 - baseSchemaType1 = baseSchemaType0 - directives = directives0 ++ DirectiveDef.builtIns - schemaExtensions = schemaExtensions0 - typeExtensions = typeExtensions0 - } - } + import Ast.{Directive => _, EnumValueDefinition => _, SchemaExtension => _, Type => _, TypeExtension => _, Value => _, _} - val schemaExtnDefns: List[Ast.SchemaExtension] = doc.collect { case tpe: Ast.SchemaExtension => tpe } - val typeDefns: List[TypeDefinition] = doc.collect { case tpe: TypeDefinition => tpe } - val dirDefns: List[DirectiveDefinition] = doc.collect { case dir: DirectiveDefinition => dir } - val extnDefns: List[Ast.TypeExtension] = doc.collect { case tpe: Ast.TypeExtension => tpe } + /** + * Parse a query String to a query algebra term. + * + * Yields a Query value on the right and accumulates errors on the left. + */ + def parseText(text: String)(implicit pos: SourcePos): Result[Schema] = + for { + doc <- parser.parseText(text) + query <- parseDocument(doc) + } yield query + + def parseDocument(doc: Document)(implicit sourcePos: SourcePos): Result[Schema] = { + object schema extends Schema { + var baseTypes: List[NamedType] = Nil + var baseSchemaType1: Option[NamedType] = null + var pos: SourcePos = sourcePos + + override def baseSchemaType: NamedType = baseSchemaType1.getOrElse(super.baseSchemaType) + + var directives: List[DirectiveDef] = Nil + var schemaExtensions: List[SchemaExtension] = Nil + var typeExtensions: List[TypeExtension] = Nil + + def complete(types0: List[NamedType], baseSchemaType0: Option[NamedType], directives0: List[DirectiveDef], schemaExtensions0: List[SchemaExtension], typeExtensions0: List[TypeExtension]): Unit = { + baseTypes = types0 + baseSchemaType1 = baseSchemaType0 + directives = directives0 ++ DirectiveDef.builtIns + schemaExtensions = schemaExtensions0 + typeExtensions = typeExtensions0 + } + } - for { - baseTypes <- mkTypeDefs(schema, typeDefns) - schemaExtns <- mkSchemaExtensions(schema, schemaExtnDefns) - typeExtns <- mkExtensions(schema, extnDefns) - directives <- mkDirectiveDefs(schema, dirDefns) - schemaType <- mkSchemaType(schema, doc) - _ = schema.complete(baseTypes, schemaType, directives, schemaExtns, typeExtns) - _ <- Result.fromProblems(SchemaValidator.validateSchema(schema, typeDefns, extnDefns)) - } yield schema - } + val schemaExtnDefns: List[Ast.SchemaExtension] = doc.collect { case tpe: Ast.SchemaExtension => tpe } + val typeDefns: List[TypeDefinition] = doc.collect { case tpe: TypeDefinition => tpe } + val dirDefns: List[DirectiveDefinition] = doc.collect { case dir: DirectiveDefinition => dir } + val extnDefns: List[Ast.TypeExtension] = doc.collect { case tpe: Ast.TypeExtension => tpe } - // explicit Schema type, if any - def mkSchemaType(schema: Schema, doc: Document): Result[Option[NamedType]] = { - def build(dirs: List[Directive], ops: List[Field]): NamedType = { - val query = ops.find(_.name == "query").getOrElse(Field("query", None, Nil, defaultQueryType, Nil)) - ObjectType( - name = "Schema", - description = None, - fields = query :: List(ops.find(_.name == "mutation"), ops.find(_.name == "subscription")).flatten, - interfaces = Nil, - directives = dirs - ) + for { + baseTypes <- mkTypeDefs(schema, typeDefns) + schemaExtns <- mkSchemaExtensions(schema, schemaExtnDefns) + typeExtns <- mkExtensions(schema, extnDefns) + directives <- mkDirectiveDefs(schema, dirDefns) + schemaType <- mkSchemaType(schema, doc) + _ = schema.complete(baseTypes, schemaType, directives, schemaExtns, typeExtns) + _ <- Result.fromProblems(SchemaValidator.validateSchema(schema, typeDefns, extnDefns)) + } yield schema } - def defaultQueryType = schema.ref("Query") + // explicit Schema type, if any + def mkSchemaType(schema: Schema, doc: Document): Result[Option[NamedType]] = { + def build(dirs: List[Directive], ops: List[Field]): NamedType = { + val query = ops.find(_.name == "query").getOrElse(Field("query", None, Nil, defaultQueryType, Nil)) + ObjectType( + name = "Schema", + description = None, + fields = query :: List(ops.find(_.name == "mutation"), ops.find(_.name == "subscription")).flatten, + interfaces = Nil, + directives = dirs + ) + } + + def defaultQueryType = schema.ref("Query") - val defns = doc.collect { case schema: SchemaDefinition => schema } - defns match { - case Nil => None.success - case SchemaDefinition(rootOpTpes, dirs0) :: Nil => - for { - ops <- rootOpTpes.traverse(mkRootOperation(schema)) - dirs <- dirs0.traverse(mkDirective) - } yield Some(build(dirs, ops)) + val defns = doc.collect { case schema: SchemaDefinition => schema } + defns match { + case Nil => None.success + case SchemaDefinition(rootOpTpes, dirs0) :: Nil => + for { + ops <- rootOpTpes.traverse(mkRootOperation(schema)) + dirs <- dirs0.traverse(Directive.fromAst) + } yield Some(build(dirs, ops)) - case _ => Result.failure("At most one schema definition permitted") + case _ => Result.failure("At most one schema definition permitted") + } } - } - def mkSchemaExtensions(schema: Schema, extnDefns: List[Ast.SchemaExtension]): Result[List[SchemaExtension]] = - extnDefns.traverse(mkSchemaExtension(schema)) + def mkSchemaExtensions(schema: Schema, extnDefns: List[Ast.SchemaExtension]): Result[List[SchemaExtension]] = + extnDefns.traverse(mkSchemaExtension(schema)) - def mkSchemaExtension(schema: Schema)(se: Ast.SchemaExtension): Result[SchemaExtension] = { - val Ast.SchemaExtension(rootOpTpes, dirs0) = se - for { - ops <- rootOpTpes.traverse(mkRootOperation(schema)) - dirs <- dirs0.traverse(mkDirective) - } yield SchemaExtension(ops, dirs) - } + def mkSchemaExtension(schema: Schema)(se: Ast.SchemaExtension): Result[SchemaExtension] = { + val Ast.SchemaExtension(rootOpTpes, dirs0) = se + for { + ops <- rootOpTpes.traverse(mkRootOperation(schema)) + dirs <- dirs0.traverse(Directive.fromAst) + } yield SchemaExtension(ops, dirs) + } - def mkRootOperation(schema: Schema)(rootTpe: RootOperationTypeDefinition): Result[Field] = { - val RootOperationTypeDefinition(optype, tpe, dirs0) = rootTpe - for { - dirs <- dirs0.traverse(mkDirective) - tpe <- mkType(schema)(tpe) - _ <- Result.failure(s"Root operation types must be named types, found '$tpe'").whenA(!tpe.nonNull.isNamed) - } yield Field(optype.name, None, Nil, tpe, dirs) - } + def mkRootOperation(schema: Schema)(rootTpe: RootOperationTypeDefinition): Result[Field] = { + val RootOperationTypeDefinition(optype, tpe, dirs0) = rootTpe + for { + dirs <- dirs0.traverse(Directive.fromAst) + tpe <- mkType(schema)(tpe) + _ <- Result.failure(s"Root operation types must be named types, found '$tpe'").whenA(!tpe.nonNull.isNamed) + } yield Field(optype.name, None, Nil, tpe, dirs) + } - def mkExtensions(schema: Schema, extnDefns: List[Ast.TypeExtension]): Result[List[TypeExtension]] = - extnDefns.traverse(mkExtension(schema)) + def mkExtensions(schema: Schema, extnDefns: List[Ast.TypeExtension]): Result[List[TypeExtension]] = + extnDefns.traverse(mkExtension(schema)) + + def mkExtension(schema: Schema)(ed: Ast.TypeExtension): Result[TypeExtension] = + ed match { + case ScalarTypeExtension(Ast.Type.Named(Name(name)), dirs0) => + for { + dirs <- dirs0.traverse(Directive.fromAst) + } yield ScalarExtension(name, dirs) + case InterfaceTypeExtension(Ast.Type.Named(Name(name)), fields0, ifs0, dirs0) => + for { + fields <- fields0.traverse(mkField(schema)) + ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + dirs <- dirs0.traverse(Directive.fromAst) + } yield InterfaceExtension(name, fields, ifs, dirs) + case ObjectTypeExtension(Ast.Type.Named(Name(name)), fields0, ifs0, dirs0) => + for { + fields <- fields0.traverse(mkField(schema)) + ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + dirs <- dirs0.traverse(Directive.fromAst) + } yield ObjectExtension(name, fields, ifs, dirs) + case UnionTypeExtension(Ast.Type.Named(Name(name)), dirs0, members0) => + for { + dirs <- dirs0.traverse(Directive.fromAst) + members = members0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + } yield UnionExtension(name, members, dirs) + case EnumTypeExtension(Ast.Type.Named(Name(name)), dirs0, values0) => + for { + values <- values0.traverse(mkEnumValue) + dirs <- dirs0.traverse(Directive.fromAst) + } yield EnumExtension(name, values, dirs) + case InputObjectTypeExtension(Ast.Type.Named(Name(name)), dirs0, fields0) => + for { + fields <- fields0.traverse(mkInputValue(schema)) + dirs <- dirs0.traverse(Directive.fromAst) + } yield InputObjectExtension(name, fields, dirs) + } - def mkExtension(schema: Schema)(ed: Ast.TypeExtension): Result[TypeExtension] = - ed match { - case ScalarTypeExtension(Ast.Type.Named(Name(name)), dirs0) => - for { - dirs <- dirs0.traverse(mkDirective) - } yield ScalarExtension(name, dirs) - case InterfaceTypeExtension(Ast.Type.Named(Name(name)), fields0, ifs0, dirs0) => - for { - fields <- fields0.traverse(mkField(schema)) - ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } - dirs <- dirs0.traverse(mkDirective) - } yield InterfaceExtension(name, fields, ifs, dirs) - case ObjectTypeExtension(Ast.Type.Named(Name(name)), fields0, ifs0, dirs0) => - for { - fields <- fields0.traverse(mkField(schema)) - ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } - dirs <- dirs0.traverse(mkDirective) - } yield ObjectExtension(name, fields, ifs, dirs) - case UnionTypeExtension(Ast.Type.Named(Name(name)), dirs0, members0) => - for { - dirs <- dirs0.traverse(mkDirective) - members = members0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } - } yield UnionExtension(name, members, dirs) - case EnumTypeExtension(Ast.Type.Named(Name(name)), dirs0, values0) => - for { - values <- values0.traverse(mkEnumValue) - dirs <- dirs0.traverse(mkDirective) - } yield EnumExtension(name, values, dirs) - case InputObjectTypeExtension(Ast.Type.Named(Name(name)), dirs0, fields0) => + def mkTypeDefs(schema: Schema, defns: List[TypeDefinition]): Result[List[NamedType]] = + defns.traverse(mkTypeDef(schema)) + + def mkTypeDef(schema: Schema)(td: TypeDefinition): Result[NamedType] = td match { + case ScalarTypeDefinition(Name("Int"), _, _) => IntType.success + case ScalarTypeDefinition(Name("Float"), _, _) => FloatType.success + case ScalarTypeDefinition(Name("String"), _, _) => StringType.success + case ScalarTypeDefinition(Name("Boolean"), _, _) => BooleanType.success + case ScalarTypeDefinition(Name("ID"), _, _) => IDType.success + case ScalarTypeDefinition(Name(nme), desc, dirs0) => for { - fields <- fields0.traverse(mkInputValue(schema)) - dirs <- dirs0.traverse(mkDirective) - } yield InputObjectExtension(name, fields, dirs) + dirs <- dirs0.traverse(Directive.fromAst) + } yield ScalarType(nme, desc, dirs) + case ObjectTypeDefinition(Name(nme), desc, fields0, ifs0, dirs0) => + if (fields0.isEmpty) Result.failure(s"object type $nme must define at least one field") + else + for { + fields <- fields0.traverse(mkField(schema)) + ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + dirs <- dirs0.traverse(Directive.fromAst) + } yield ObjectType(nme, desc, fields, ifs, dirs) + case InterfaceTypeDefinition(Name(nme), desc, fields0, ifs0, dirs0) => + if (fields0.isEmpty) Result.failure(s"interface type $nme must define at least one field") + else + for { + fields <- fields0.traverse(mkField(schema)) + ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + dirs <- dirs0.traverse(Directive.fromAst) + } yield InterfaceType(nme, desc, fields, ifs, dirs) + case UnionTypeDefinition(Name(nme), desc, dirs0, members0) => + if (members0.isEmpty) Result.failure(s"union type $nme must define at least one member") + else { + for { + dirs <- dirs0.traverse(Directive.fromAst) + members = members0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } + } yield UnionType(nme, desc, members, dirs) + } + case EnumTypeDefinition(Name(nme), desc, dirs0, values0) => + if (values0.isEmpty) Result.failure(s"enum type $nme must define at least one enum value") + else + for { + values <- values0.traverse(mkEnumValue) + dirs <- dirs0.traverse(Directive.fromAst) + } yield EnumType(nme, desc, values, dirs) + case InputObjectTypeDefinition(Name(nme), desc, fields0, dirs0) => + if (fields0.isEmpty) Result.failure(s"input object type $nme must define at least one input field") + else + for { + fields <- fields0.traverse(mkInputValue(schema)) + dirs <- dirs0.traverse(Directive.fromAst) + } yield InputObjectType(nme, desc, fields, dirs) } - def mkTypeDefs(schema: Schema, defns: List[TypeDefinition]): Result[List[NamedType]] = - defns.traverse(mkTypeDef(schema)) - - def mkTypeDef(schema: Schema)(td: TypeDefinition): Result[NamedType] = td match { - case ScalarTypeDefinition(Name("Int"), _, _) => IntType.success - case ScalarTypeDefinition(Name("Float"), _, _) => FloatType.success - case ScalarTypeDefinition(Name("String"), _, _) => StringType.success - case ScalarTypeDefinition(Name("Boolean"), _, _) => BooleanType.success - case ScalarTypeDefinition(Name("ID"), _, _) => IDType.success - case ScalarTypeDefinition(Name(nme), desc, dirs0) => + def mkField(schema: Schema)(f: FieldDefinition): Result[Field] = { + val FieldDefinition(Name(nme), desc, args0, tpe0, dirs0) = f for { - dirs <- dirs0.traverse(mkDirective) - } yield ScalarType(nme, desc, dirs) - case ObjectTypeDefinition(Name(nme), desc, fields0, ifs0, dirs0) => - if (fields0.isEmpty) Result.failure(s"object type $nme must define at least one field") - else - for { - fields <- fields0.traverse(mkField(schema)) - ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } - dirs <- dirs0.traverse(mkDirective) - } yield ObjectType(nme, desc, fields, ifs, dirs) - case InterfaceTypeDefinition(Name(nme), desc, fields0, ifs0, dirs0) => - if (fields0.isEmpty) Result.failure(s"interface type $nme must define at least one field") - else - for { - fields <- fields0.traverse(mkField(schema)) - ifs = ifs0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } - dirs <- dirs0.traverse(mkDirective) - } yield InterfaceType(nme, desc, fields, ifs, dirs) - case UnionTypeDefinition(Name(nme), desc, dirs0, members0) => - if (members0.isEmpty) Result.failure(s"union type $nme must define at least one member") - else { - for { - dirs <- dirs0.traverse(mkDirective) - members = members0.map { case Ast.Type.Named(Name(nme)) => schema.ref(nme) } - } yield UnionType(nme, desc, members, dirs) - } - case EnumTypeDefinition(Name(nme), desc, dirs0, values0) => - if (values0.isEmpty) Result.failure(s"enum type $nme must define at least one enum value") - else - for { - values <- values0.traverse(mkEnumValue) - dirs <- dirs0.traverse(mkDirective) - } yield EnumType(nme, desc, values, dirs) - case InputObjectTypeDefinition(Name(nme), desc, fields0, dirs0) => - if (fields0.isEmpty) Result.failure(s"input object type $nme must define at least one input field") - else - for { - fields <- fields0.traverse(mkInputValue(schema)) - dirs <- dirs0.traverse(mkDirective) - } yield InputObjectType(nme, desc, fields, dirs) - } - - def mkDirective(d: Ast.Directive): Result[Directive] = { - val Ast.Directive(Name(nme), args) = d - args.traverse { - case (Name(nme), value) => parseValue(value).map(Binding(nme, _)) - }.map(Directive(nme, _)) - } - - def mkField(schema: Schema)(f: FieldDefinition): Result[Field] = { - val FieldDefinition(Name(nme), desc, args0, tpe0, dirs0) = f - for { - args <- args0.traverse(mkInputValue(schema)) - tpe <- mkType(schema)(tpe0) - dirs <- dirs0.traverse(mkDirective) - } yield Field(nme, desc, args, tpe, dirs) - } + args <- args0.traverse(mkInputValue(schema)) + tpe <- mkType(schema)(tpe0) + dirs <- dirs0.traverse(Directive.fromAst) + } yield Field(nme, desc, args, tpe, dirs) + } - def mkType(schema: Schema)(tpe: Ast.Type): Result[Type] = { - def loop(tpe: Ast.Type, nullable: Boolean): Result[Type] = { - def wrap(tpe: Type): Type = if (nullable) NullableType(tpe) else tpe + def mkType(schema: Schema)(tpe: Ast.Type): Result[Type] = { + def loop(tpe: Ast.Type, nullable: Boolean): Result[Type] = { + def wrap(tpe: Type): Type = if (nullable) NullableType(tpe) else tpe - tpe match { - case Ast.Type.List(tpe) => loop(tpe, true).map(tpe => wrap(ListType(tpe))) - case Ast.Type.NonNull(Left(tpe)) => loop(tpe, false) - case Ast.Type.NonNull(Right(tpe)) => loop(tpe, false) - case Ast.Type.Named(Name(nme)) => wrap(ScalarType.builtIn(nme).getOrElse(schema.ref(nme))).success + tpe match { + case Ast.Type.List(tpe) => loop(tpe, true).map(tpe => wrap(ListType(tpe))) + case Ast.Type.NonNull(Left(tpe)) => loop(tpe, false) + case Ast.Type.NonNull(Right(tpe)) => loop(tpe, false) + case Ast.Type.Named(Name(nme)) => wrap(ScalarType.builtIn(nme).getOrElse(schema.ref(nme))).success + } } - } - - loop(tpe, true) - } - def mkDirectiveDefs(schema: Schema, defns: List[DirectiveDefinition]): Result[List[DirectiveDef]] = - defns.traverse(mkDirectiveDef(schema)) + loop(tpe, true) + } - def mkDirectiveDef(schema: Schema)(dd: DirectiveDefinition): Result[DirectiveDef] = { - val DirectiveDefinition(Name(nme), desc, args0, repeatable, locations) = dd - for { - args <- args0.traverse(mkInputValue(schema)) - } yield DirectiveDef(nme, desc, args, repeatable, locations) - } + def mkDirectiveDefs(schema: Schema, defns: List[DirectiveDefinition]): Result[List[DirectiveDef]] = + defns.traverse(mkDirectiveDef(schema)) - def mkInputValue(schema: Schema)(f: InputValueDefinition): Result[InputValue] = { - val InputValueDefinition(Name(nme), desc, tpe0, default0, dirs0) = f - for { - tpe <- mkType(schema)(tpe0) - dflt <- default0.traverse(parseValue) - dirs <- dirs0.traverse(mkDirective) - } yield InputValue(nme, desc, tpe, dflt, dirs) - } + def mkDirectiveDef(schema: Schema)(dd: DirectiveDefinition): Result[DirectiveDef] = { + val DirectiveDefinition(Name(nme), desc, args0, repeatable, locations) = dd + for { + args <- args0.traverse(mkInputValue(schema)) + } yield DirectiveDef(nme, desc, args, repeatable, locations) + } - def mkEnumValue(e: Ast.EnumValueDefinition): Result[EnumValueDefinition] = { - val Ast.EnumValueDefinition(Name(nme), desc, dirs0) = e - for { - dirs <- dirs0.traverse(mkDirective) - } yield EnumValueDefinition(nme, desc, dirs) - } + def mkInputValue(schema: Schema)(f: InputValueDefinition): Result[InputValue] = { + val InputValueDefinition(Name(nme), desc, tpe0, default0, dirs0) = f + for { + tpe <- mkType(schema)(tpe0) + dflt <- default0.traverse(Value.fromAst) + dirs <- dirs0.traverse(Directive.fromAst) + } yield InputValue(nme, desc, tpe, dflt, dirs) + } - def parseValue(value: Ast.Value): Result[Value] = { - value match { - case Ast.Value.IntValue(i) => IntValue(i).success - case Ast.Value.FloatValue(d) => FloatValue(d).success - case Ast.Value.StringValue(s) => StringValue(s).success - case Ast.Value.BooleanValue(b) => BooleanValue(b).success - case Ast.Value.EnumValue(e) => EnumValue(e.value).success - case Ast.Value.Variable(v) => VariableRef(v.value).success - case Ast.Value.NullValue => NullValue.success - case Ast.Value.ListValue(vs) => vs.traverse(parseValue).map(ListValue(_)) - case Ast.Value.ObjectValue(fs) => - fs.traverse { case (name, value) => - parseValue(value).map(v => (name.value, v)) - }.map(ObjectValue(_)) + def mkEnumValue(e: Ast.EnumValueDefinition): Result[EnumValueDefinition] = { + val Ast.EnumValueDefinition(Name(nme), desc, dirs0) = e + for { + dirs <- dirs0.traverse(Directive.fromAst) + } yield EnumValueDefinition(nme, desc, dirs) } } } diff --git a/modules/core/src/test/scala/compiler/CompilerSuite.scala b/modules/core/src/test/scala/compiler/CompilerSuite.scala index a78e6608..25e6c1a7 100644 --- a/modules/core/src/test/scala/compiler/CompilerSuite.scala +++ b/modules/core/src/test/scala/compiler/CompilerSuite.scala @@ -26,6 +26,8 @@ import Predicate._, Value._, UntypedOperation._ import QueryCompiler._, ComponentElaborator.TrivialJoin final class CompilerSuite extends CatsEffectSuite { + val queryParser = QueryParser(GraphQLParser(GraphQLParser.defaultConfig)) + test("simple query") { val query = """ query { @@ -40,7 +42,7 @@ final class CompilerSuite extends CatsEffectSuite { UntypedSelect("name", None, Nil, Nil, Empty) ) - val res = QueryParser.parseText(query).map(_._1) + val res = queryParser.parseText(query).map(_._1) assertEquals(res, Result.Success(List(UntypedQuery(None, expected, Nil, Nil)))) } @@ -62,7 +64,7 @@ final class CompilerSuite extends CatsEffectSuite { ) ) - val res = QueryParser.parseText(query).map(_._1) + val res = queryParser.parseText(query).map(_._1) assertEquals(res, Result.Success(List(UntypedMutation(None, expected, Nil, Nil)))) } @@ -80,7 +82,7 @@ final class CompilerSuite extends CatsEffectSuite { UntypedSelect("name", None, Nil, Nil, Empty) ) - val res = QueryParser.parseText(query).map(_._1) + val res = queryParser.parseText(query).map(_._1) assertEquals(res, Result.Success(List(UntypedSubscription(None, expected, Nil, Nil)))) } @@ -106,7 +108,7 @@ final class CompilerSuite extends CatsEffectSuite { ) ) - val res = QueryParser.parseText(query).map(_._1) + val res = queryParser.parseText(query).map(_._1) assertEquals(res, Result.Success(List(UntypedQuery(None, expected, Nil, Nil)))) } @@ -137,7 +139,7 @@ final class CompilerSuite extends CatsEffectSuite { ) ) - val res = QueryParser.parseText(query).map(_._1) + val res = queryParser.parseText(query).map(_._1) assertEquals(res, Result.Success(List(UntypedQuery(None, expected, Nil, Nil)))) } @@ -163,7 +165,7 @@ final class CompilerSuite extends CatsEffectSuite { )) ) - val res = QueryParser.parseText(query).map(_._1) + val res = queryParser.parseText(query).map(_._1) assertEquals(res, Result.Success(List(UntypedQuery(None, expected, Nil, Nil)))) } @@ -192,7 +194,7 @@ final class CompilerSuite extends CatsEffectSuite { UntypedSelect("subscriptionType", None, Nil, Nil, UntypedSelect("name", None, Nil, Nil, Empty)) ) - val res = QueryParser.parseText(query).map(_._1) + val res = queryParser.parseText(query).map(_._1) assertEquals(res, Result.Success(List(UntypedQuery(Some("IntrospectionQuery"), expected, Nil, Nil)))) } @@ -355,46 +357,55 @@ final class CompilerSuite extends CatsEffectSuite { } test("malformed query (1)") { - val query = """ - query { - character(id: "1000" { - name - } - } - """ + val query = + """|query { + | character(id: "1000" { + | name + | } + |}""".stripMargin - val res = QueryParser.parseText(query) + val expected = + """|query { + | character(id: "1000" { + | ^ + |expectation: + |* must be char: ')' + | name + | }""".stripMargin - val error = - """Parse error at line 2 column 29 - | character(id: "1000" { - | ^""".stripMargin + val res = queryParser.parseText(query) - assertEquals(res, Result.failure(error)) + assertEquals(res, Result.failure(expected)) } test("malformed query (2)") { val query = "" - val res = QueryParser.parseText(query) + val res = queryParser.parseText(query) assertEquals(res, Result.failure("At least one operation required")) } test("malformed query (3)") { - val query = """ - query { - character(id: "1000") { - name - } - """ - - val res = QueryParser.parseText(query) + val query = + """|query { + | character(id: "1000") { + | name + | }""".stripMargin - val error = - "Parse error at line 5 column 4\n \n ^" - - assertEquals(res, Result.failure(error)) + val expected = + """|... + | character(id: "1000") { + | name + | } + | ^ + |expectation: + |* must be char: '}'""".stripMargin + + val res = queryParser.parseText(query) + //println(res.toProblems.toList.head.message) + + assertEquals(res, Result.failure(expected)) } } diff --git a/modules/core/src/test/scala/compiler/DirectivesSuite.scala b/modules/core/src/test/scala/compiler/DirectivesSuite.scala index fe25470a..955b9f72 100644 --- a/modules/core/src/test/scala/compiler/DirectivesSuite.scala +++ b/modules/core/src/test/scala/compiler/DirectivesSuite.scala @@ -23,6 +23,8 @@ import Ast.DirectiveLocation._ import Query._ final class DirectivesSuite extends CatsEffectSuite { + val schemaParser = SchemaParser(GraphQLParser(GraphQLParser.defaultConfig)) + def testDirectiveDefs(s: Schema): List[DirectiveDef] = s.directives.filter { case DirectiveDef("skip"|"include"|"deprecated", _, _, _, _) => false @@ -169,7 +171,7 @@ final class DirectivesSuite extends CatsEffectSuite { |directive @foo on SCHEMA|SCALAR|OBJECT|FIELD_DEFINITION|ARGUMENT_DEFINITION|INTERFACE|UNION|ENUM|ENUM_VALUE|INPUT_OBJECT|INPUT_FIELD_DEFINITION |""".stripMargin - val res = SchemaParser.parseText(schema) + val res = schemaParser.parseText(schema) val ser = res.map(_.toString) assertEquals(ser, schema.success) diff --git a/modules/core/src/test/scala/compiler/FragmentSuite.scala b/modules/core/src/test/scala/compiler/FragmentSuite.scala index 604d3832..698b2b8c 100644 --- a/modules/core/src/test/scala/compiler/FragmentSuite.scala +++ b/modules/core/src/test/scala/compiler/FragmentSuite.scala @@ -790,6 +790,258 @@ final class FragmentSuite extends CatsEffectSuite { assertIO(res, expected) } + + test("fragment defined") { + val query = """ + query withFragments { + user(id: 1) { + friends { + ...friendFields + } + } + } + """ + + val expected = json""" + { + "errors" : [ + { + "message" : "Fragment 'friendFields' is undefined" + } + ] + } + """ + + val res = FragmentMapping.compileAndRun(query) + + assertIO(res, expected) + } + + test("fragment used") { + val query = """ + query withFragments { + user(id: 1) { + friends { + id + name + profilePic + } + } + } + + fragment friendFields on User { + id + name + profilePic + } + """ + + val expected = json""" + { + "errors" : [ + { + "message" : "Fragment 'friendFields' is unused" + } + ] + } + """ + + val res = FragmentMapping.compileAndRun(query) + + assertIO(res, expected) + } + + test("fragment duplication") { + val query = """ + query withFragments { + user(id: 1) { + ...userFields + } + } + + fragment userFields on User { + name + } + + fragment userFields on User { + name + } + """ + + val expected = json""" + { + "errors" : [ + { + "message" : "Fragment 'userFields' is defined more than once" + } + ] + } + """ + + val res = FragmentMapping.compileAndRun(query) + + assertIO(res, expected) + } + + + test("fragment recursion (1)") { + val query = """ + query withFragments { + user(id: 1) { + ...userFields + } + } + + fragment userFields on User { + name + friends { + ...userFields + } + } + """ + + val expected = json""" + { + "errors" : [ + { + "message" : "Fragment cycle starting from 'userFields'" + } + ] + } + """ + + val res = FragmentMapping.compileAndRun(query) + + assertIO(res, expected) + } + + test("fragment recursion (2)") { + val query = """ + query withFragments { + user(id: 1) { + ...userFields + } + } + + fragment userFields on User { + name + favourite { + ...pageFields + } + } + + fragment pageFields on Page { + title + likers { + ...userFields + } + } + """ + + val expected = json""" + { + "errors" : [ + { + "message" : "Fragment cycle starting from 'userFields'" + } + ] + } + """ + + val res = FragmentMapping.compileAndRun(query) + + assertIO(res, expected) + } + + test("fragment recursion (3)") { + val query = """ + query withFragments { + user(id: 1) { + ...userFields + } + } + + fragment userFields on User { + name + favourite { + ...pageFields + } + } + + fragment pageFields on Page { + title + likers { + ...userFields2 + } + } + + fragment userFields2 on User { + profilePic + favourite { + ...userFields + } + } + """ + + val expected = json""" + { + "errors" : [ + { + "message" : "Fragment cycle starting from 'userFields'" + } + ] + } + """ + + val res = FragmentMapping.compileAndRun(query) + + assertIO(res, expected) + } + + test("fragment recursion (4)") { + val query = """ + query withFragments { + user(id: 1) { + ...userFields + } + } + + fragment pageFields on Page { + title + likers { + ...userFields2 + } + } + + fragment userFields2 on User { + profilePic + favourite { + ...pageFields + } + } + + fragment userFields on User { + name + favourite { + ...pageFields + } + } + """ + + val expected = json""" + { + "errors" : [ + { + "message" : "Fragment cycle starting from 'pageFields'" + } + ] + } + """ + + val res = FragmentMapping.compileAndRun(query) + + assertIO(res, expected) + } } object FragmentData { diff --git a/modules/core/src/test/scala/compiler/VariablesSuite.scala b/modules/core/src/test/scala/compiler/VariablesSuite.scala index 54bf8632..c8e09bbc 100644 --- a/modules/core/src/test/scala/compiler/VariablesSuite.scala +++ b/modules/core/src/test/scala/compiler/VariablesSuite.scala @@ -378,6 +378,139 @@ final class VariablesSuite extends CatsEffectSuite { assertEquals(compiled.map(_.query), Result.Success(expected)) } + + test("variables in directive argument") { + val query = """ + query getZuckProfile($skipName: Boolean) { + user(id: 4) { + id + name @skip(if: $skipName) + } + } + """ + + val variables = json""" + { + "skipName": true + } + """ + + val expected = + UntypedSelect("user", None, List(Binding("id", IDValue("4"))), Nil, + UntypedSelect("id", None, Nil, Nil, Empty) + ) + + val compiled = VariablesMapping.compiler.compile(query, untypedVars = Some(variables)) + + assertEquals(compiled.map(_.query), Result.Success(expected)) + } + + test("variable not defined (1)") { + val query = """ + query getZuckProfile { + user(id: 4) { + id + name + profilePic(size: $devicePicSize) + } + } + """ + + val compiled = VariablesMapping.compiler.compile(query) + + val expected = Result.failure("Variable 'devicePicSize' is undefined") + + assertEquals(compiled, expected) + } + + + test("variable not defined (2)") { + val query = """ + query getZuckProfile($devicePicSize: Int) { + user(id: 4) { + id + name + profilePic(size: $devicePicSize) + } + } + """ + + val compiled = VariablesMapping.compiler.compile(query) + + val expected = Result.failure("Variable 'devicePicSize' is undefined") + + assertEquals(compiled, expected) + } + + test("variable not defined (3)") { + val query = """ + query getZuckProfile($skipPic: Boolean) { + user(id: 4) { + id + name + profilePic(size: $devicePicSize) @skip(if: $skipPic) + } + } + """ + + val expected = Result.failure("Variable 'devicePicSize' is undefined") + + val compiled = VariablesMapping.compiler.compile(query) + + assertEquals(compiled.map(_.query), expected) + } + + test("variable not defined (4)") { + val query = """ + query getZuckProfile { + user(id: 4) { + id + name @skip(if: $skipName) + } + } + """ + + val expected = Result.failure("Variable 'skipName' is undefined") + + val compiled = VariablesMapping.compiler.compile(query) + + assertEquals(compiled.map(_.query), expected) + } + + + test("variable not defined (5)") { + val query = """ + query getZuckProfile($skipName: Boolean) { + user(id: 4) { + id + name @skip(if: $skipName) + } + } + """ + + val expected = Result.failure("Variable 'skipName' is undefined") + + val compiled = VariablesMapping.compiler.compile(query) + + assertEquals(compiled.map(_.query), expected) + } + + test("variable unused (1)") { + val query = """ + query getZuckProfile($devicePicSize: Int) { + user(id: 4) { + id + name + } + } + """ + + val compiled = VariablesMapping.compiler.compile(query) + + val expected = Result.failure("Variable 'devicePicSize' is unused") + + assertEquals(compiled, expected) + } } object VariablesMapping extends TestMapping { diff --git a/modules/core/src/test/scala/directives/DirectiveValidationSuite.scala b/modules/core/src/test/scala/directives/DirectiveValidationSuite.scala index 35c63285..2ab2b61b 100644 --- a/modules/core/src/test/scala/directives/DirectiveValidationSuite.scala +++ b/modules/core/src/test/scala/directives/DirectiveValidationSuite.scala @@ -375,7 +375,7 @@ object ExecutableDirectiveMapping extends Mapping[IO] { override val selectElaborator = PreserveArgsElaborator def compileAllOperations(text: String): Result[List[Operation]] = - QueryParser.parseText(text).flatMap { + queryParser.parseText(text).flatMap { case (ops, frags) => ops.parTraverse(compiler.compileOperation(_, None, frags)) } } diff --git a/modules/core/src/test/scala/minimizer/MinimizerSuite.scala b/modules/core/src/test/scala/minimizer/MinimizerSuite.scala index 1628accd..af3ffe4f 100644 --- a/modules/core/src/test/scala/minimizer/MinimizerSuite.scala +++ b/modules/core/src/test/scala/minimizer/MinimizerSuite.scala @@ -17,18 +17,22 @@ package minimizer import munit.CatsEffectSuite -import grackle.{ GraphQLParser, QueryMinimizer } +import grackle.{ GraphQLParser, QueryMinimizer, Result } final class MinimizerSuite extends CatsEffectSuite { + val parser = GraphQLParser(GraphQLParser.defaultConfig) + val minimizer = QueryMinimizer(parser) + def run(query: String, expected: String, echo: Boolean = false): Unit = { - val Right(minimized) = QueryMinimizer.minimizeText(query) : @unchecked + + val Result.Success(minimized) = minimizer.minimizeText(query) : @unchecked if (echo) println(minimized) assert(minimized == expected) - val Some(parsed0) = GraphQLParser.Document.parseAll(query).toOption : @unchecked - val Some(parsed1) = GraphQLParser.Document.parseAll(minimized).toOption : @unchecked + val Some(parsed0) = parser.parseText(query).toOption : @unchecked + val Some(parsed1) = parser.parseText(minimized).toOption : @unchecked assertEquals(parsed0, parsed1) } diff --git a/modules/core/src/test/scala/parser/ParserSuite.scala b/modules/core/src/test/scala/parser/ParserSuite.scala index 02d1b65e..bfe3f0b3 100644 --- a/modules/core/src/test/scala/parser/ParserSuite.scala +++ b/modules/core/src/test/scala/parser/ParserSuite.scala @@ -15,14 +15,15 @@ package parser -import cats.data.NonEmptyChain import munit.CatsEffectSuite -import grackle.{Ast, GraphQLParser, Problem, Result} +import grackle.{Ast, GraphQLParser, Result} import grackle.syntax._ import Ast._, OperationType._, OperationDefinition._, Selection._, Value._, Type.Named final class ParserSuite extends CatsEffectSuite { + val parser = GraphQLParser(GraphQLParser.defaultConfig) + test("simple query") { val query = doc""" query { @@ -72,7 +73,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -104,7 +105,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -152,7 +153,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -194,7 +195,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -226,7 +227,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -258,7 +259,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -290,7 +291,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -322,7 +323,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -354,7 +355,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -406,14 +407,14 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } } test("invalid document") { - GraphQLParser.Document.parseAll("scalar Foo woozle").toOption match { + parser.parseText("scalar Foo woozle").toOption match { case Some(_) => fail("should have failed") case None => () } @@ -463,7 +464,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(xs) => assertEquals(xs, expected) case _ => assert(false) } @@ -502,7 +503,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -540,7 +541,7 @@ final class ParserSuite extends CatsEffectSuite { ) ) - GraphQLParser.Document.parseAll(query).toOption match { + parser.parseText(query).toOption match { case Some(List(q)) => assertEquals(q, expected) case _ => assert(false) } @@ -549,8 +550,9 @@ final class ParserSuite extends CatsEffectSuite { test("value literals") { def assertParse(input: String, expected: Value) = - GraphQLParser.Value.parseAll(input).toOption match { - case Some(v) => assertEquals(v, expected) + parser.parseText(s"query { foo(bar: $input) }").toOption match { + case Some(List(Operation(_, _, _, _,List(Field(_, _, List((_, v)), _, _))))) => + assertEquals(v, expected) case _ => assert(false) } @@ -580,6 +582,32 @@ final class ParserSuite extends CatsEffectSuite { assertParse("\"\"\" \n\n first\n \tĪ»\n 123\n\n\n \t\n\n\"\"\"", StringValue(" first\n \tĪ»\n123")) } + test("outsized int") { + val query = + """|query { + | foo { + | bar { + | baz(id: 2147483648) + | } + | } + |}""".stripMargin + + val expected = + """|... + | foo { + | bar { + | baz(id: 2147483648) + | ^ + |expectation: + |* must fail: 2147483648 is larger than max int + | } + | }""".stripMargin + + val res = parser.parseText(query) + + assertEquals(res, Result.failure(expected)) + } + test("parse object type extension") { val schema = """ extend type Foo { @@ -592,7 +620,7 @@ final class ParserSuite extends CatsEffectSuite { ObjectTypeExtension(Named(Name("Foo")), List(FieldDefinition(Name("bar"),None,Nil,Named(Name("Int")),Nil)), Nil, Nil) ) - val res = GraphQLParser.Document.parseAll(schema).toOption + val res = parser.parseText(schema).toOption assert(res == Some(expected)) } @@ -608,7 +636,7 @@ final class ParserSuite extends CatsEffectSuite { SchemaExtension(List(RootOperationTypeDefinition(OperationType.Query, Named(Name("Query")), Nil)), Nil) ) - val res = GraphQLParser.Document.parseAll(schema).toOption + val res = parser.parseText(schema).toOption assert(res == Some(expected)) } @@ -630,26 +658,155 @@ final class ParserSuite extends CatsEffectSuite { ) ) - val res = GraphQLParser.Document.parseAll(schema).toOption + val res = parser.parseText(schema).toOption assertEquals(res, Some(expected)) } test("keywords parsed non-greedily (2)") { val schema = """|extendtypeName { - | value:String + | value: String |}""".stripMargin val expected = - NonEmptyChain( - Problem( - """|Parse error at line 0 column 6 - |extendtypeName { - | ^""".stripMargin - ) - ) + """|extendtypeName { + | ^ + |expectation: + |* must fail but matched with t + | value: String + |}""".stripMargin - val res = GraphQLParser.toResult(schema, GraphQLParser.Document.parseAll(schema)) - assertEquals(res, Result.Failure(expected)) + val res = parser.parseText(schema) + assertEquals(res, Result.failure(expected)) } + + test("deep query") { + def mkQuery(depth: Int): String = { + val depth0 = depth - 1 + "query{" + ("f{" *depth0) + "f" + ("}" * depth0) + "}" + } + + val limit = 5 + val limitedParser = mkParser(maxSelectionDepth = limit) + + val queryOk = mkQuery(limit) + val queryFail = mkQuery(limit + 1) + + val expectedFail = + """|query{f{f{f{f{f{f}}}}}} + | ^ + |expectation: + |* must fail: exceeded maximum selection depth""".stripMargin + + val resOk = limitedParser.parseText(queryOk) + assert(resOk.hasValue) + + val resFail = limitedParser.parseText(queryFail) + assertEquals(resFail, Result.failure(expectedFail)) + } + + test("wide query") { + def mkQuery(width: Int): String = + "query{r{" + ("f," * (width - 1) + "f") + "}}" + + val limit = 5 + val limitedParser = mkParser(maxSelectionWidth = limit) + + val queryOk = mkQuery(limit) + val queryFail = mkQuery(limit + 1) + + val expectedFail = + """|query{r{f,f,f,f,f,f}} + | ^ + |expectation: + |* must be char: '}'""".stripMargin + + val resOk = limitedParser.parseText(queryOk) + assert(resOk.hasValue) + + val resFail = limitedParser.parseText(queryFail) + assertEquals(resFail, Result.failure(expectedFail)) + } + + test("deep list value") { + def mkQuery(depth: Int): String = + "query{f(l: " + ("[" *depth) + "0" + ("]" * depth) + "){f}}" + + val limit = 5 + val limitedParser = mkParser(maxInputValueDepth = limit) + + val queryOk = mkQuery(limit) + val queryFail = mkQuery(limit + 1) + + val expectedFail = + """|query{f(l: [[[[[[0]]]]]]){f}} + | ^ + |expectation: + |* must fail: exceeded maximum input value depth""".stripMargin + + val resOk = limitedParser.parseText(queryOk) + assert(resOk.hasValue) + + val resFail = limitedParser.parseText(queryFail) + assertEquals(resFail, Result.failure(expectedFail)) + } + + test("deep input object value") { + def mkQuery(depth: Int): String = + "query{f(l: " + ("{m:" *depth) + "0" + ("}" * depth) + "){f}}" + + val limit = 5 + val limitedParser = mkParser(maxInputValueDepth = limit) + + val queryOk = mkQuery(limit) + val queryFail = mkQuery(limit + 1) + + val expectedFail = + """|query{f(l: {m:{m:{m:{m:{m:{m:0}}}}}}){f}} + | ^ + |expectation: + |* must fail: exceeded maximum input value depth""".stripMargin + + val resOk = limitedParser.parseText(queryOk) + assert(resOk.hasValue) + + val resFail = limitedParser.parseText(queryFail) + assertEquals(resFail, Result.failure(expectedFail)) + } + + test("deep variable type") { + def mkQuery(depth: Int): String = + "query($l: " + ("[" *depth) + "Int" + ("]" * depth) + "){f(a:$l)}" + + val limit = 5 + val limitedParser = mkParser(maxListTypeDepth = limit) + + val queryOk = mkQuery(limit) + val queryFail = mkQuery(limit + 1) + + val expectedFail = + """|query($l: [[[[[[Int]]]]]]){f(a:$l)} + | ^ + |expectation: + |* must fail: exceeded maximum list type depth""".stripMargin + + val resOk = limitedParser.parseText(queryOk) + assert(resOk.hasValue) + + val resFail = limitedParser.parseText(queryFail) + assertEquals(resFail, Result.failure(expectedFail)) + } + + def mkParser( + maxSelectionDepth: Int = GraphQLParser.defaultConfig.maxSelectionDepth, + maxSelectionWidth: Int = GraphQLParser.defaultConfig.maxSelectionWidth, + maxInputValueDepth: Int = GraphQLParser.defaultConfig.maxInputValueDepth, + maxListTypeDepth: Int = GraphQLParser.defaultConfig.maxListTypeDepth): GraphQLParser = + GraphQLParser( + GraphQLParser.Config( + maxSelectionDepth = maxSelectionDepth, + maxSelectionWidth = maxSelectionWidth, + maxInputValueDepth = maxInputValueDepth, + maxListTypeDepth = maxListTypeDepth) + ) } diff --git a/modules/core/src/test/scala/sdl/SDLSuite.scala b/modules/core/src/test/scala/sdl/SDLSuite.scala index eb888071..92e04531 100644 --- a/modules/core/src/test/scala/sdl/SDLSuite.scala +++ b/modules/core/src/test/scala/sdl/SDLSuite.scala @@ -22,6 +22,9 @@ import grackle.syntax._ import Ast._, OperationType._, Type.{ List => _, _ } final class SDLSuite extends CatsEffectSuite { + val parser = GraphQLParser(GraphQLParser.defaultConfig) + val schemaParser = SchemaParser(parser) + test("parse schema definition") { val schema = """ schema { @@ -43,9 +46,9 @@ final class SDLSuite extends CatsEffectSuite { ) ) - val res = GraphQLParser.Document.parseAll(schema) + val res = parser.parseText(schema) - assertEquals(res, Right(expected)) + assertEquals(res, expected.success) } test("parse scalar type definition") { @@ -63,9 +66,9 @@ final class SDLSuite extends CatsEffectSuite { ScalarTypeDefinition(Name("Time"), Some("A scalar type"), List(Directive(Name("deprecated"), Nil))) ) - val res = GraphQLParser.Document.parseAll(schema) + val res = parser.parseText(schema) - assertEquals(res, Right(expected)) + assertEquals(res, expected.success) } test("parse object type definition") { @@ -95,9 +98,9 @@ final class SDLSuite extends CatsEffectSuite { ) ) - val res = GraphQLParser.Document.parseAll(schema) + val res = parser.parseText(schema) - assertEquals(res, Right(expected)) + assertEquals(res, expected.success) } test("parse interface type definition") { @@ -127,9 +130,9 @@ final class SDLSuite extends CatsEffectSuite { ) ) - val res = GraphQLParser.Document.parseAll(schema) + val res = parser.parseText(schema) - assertEquals(res, Right(expected)) + assertEquals(res, expected.success) } test("parse union type definition") { @@ -148,9 +151,9 @@ final class SDLSuite extends CatsEffectSuite { ) ) - val res = GraphQLParser.Document.parseAll(schema) + val res = parser.parseText(schema) - assertEquals(res, Right(expected)) + assertEquals(res, expected.success) } test("parse enum type definition") { @@ -176,9 +179,9 @@ final class SDLSuite extends CatsEffectSuite { ) ) - val res = GraphQLParser.Document.parseAll(schema) + val res = parser.parseText(schema) - assertEquals(res, Right(expected)) + assertEquals(res, expected.success) } test("parse input object type definition") { @@ -201,9 +204,9 @@ final class SDLSuite extends CatsEffectSuite { ) ) - val res = GraphQLParser.Document.parseAll(schema) + val res = parser.parseText(schema) - assertEquals(res, Right(expected)) + assertEquals(res, expected.success) } test("parse directive definition") { @@ -215,7 +218,7 @@ final class SDLSuite extends CatsEffectSuite { |directive @delegateField(name: String!) repeatable on OBJECT|INTERFACE|FIELD|FIELD_DEFINITION|ENUM|ENUM_VALUE |""".stripMargin - val res = SchemaParser.parseText(schema) + val res = schemaParser.parseText(schema) val ser = res.map(_.toString) assertEquals(ser, schema.success) @@ -240,7 +243,7 @@ final class SDLSuite extends CatsEffectSuite { | author(id: Int! = 23): Author |}""".stripMargin - val res = SchemaParser.parseText(schema) + val res = schemaParser.parseText(schema) val ser = res.map(_.toString) assertEquals(ser, schema.success) @@ -284,7 +287,7 @@ final class SDLSuite extends CatsEffectSuite { | primaryFunction: String |}""".stripMargin - val res = SchemaParser.parseText(schema) + val res = schemaParser.parseText(schema) val ser = res.map(_.toString) assertEquals(ser, schema.success) @@ -307,7 +310,7 @@ final class SDLSuite extends CatsEffectSuite { | y: Int |}""".stripMargin - val res = SchemaParser.parseText(schema) + val res = schemaParser.parseText(schema) val ser = res.map(_.toString) assertEquals(ser, schema.success) @@ -369,7 +372,7 @@ final class SDLSuite extends CatsEffectSuite { |directive @Inp on INPUT_OBJECT |""".stripMargin - val res = SchemaParser.parseText(schema) + val res = schemaParser.parseText(schema) val ser = res.map(_.toString) assertEquals(ser, schema.success) @@ -414,7 +417,7 @@ final class SDLSuite extends CatsEffectSuite { |directive @Inp on INPUT_OBJECT |""".stripMargin - val res = SchemaParser.parseText(schema) + val res = schemaParser.parseText(schema) val ser = res.map(_.toString) assertEquals(ser, schema.success) diff --git a/modules/generic/src/test/scala/ScalarsSuite.scala b/modules/generic/src/test/scala/ScalarsSuite.scala index 0f1dbeda..7baf3193 100644 --- a/modules/generic/src/test/scala/ScalarsSuite.scala +++ b/modules/generic/src/test/scala/ScalarsSuite.scala @@ -373,6 +373,38 @@ final class ScalarsSuite extends CatsEffectSuite { assertIO(res, expected) } + test("query with scalar argument without apostrophes") { + val query = """ + query { + moviesLongerThan(duration: PT3H) { + title + duration + } + } + """ + + val expected = json""" + { + "data" : { + "moviesLongerThan" : [ + { + "title" : "Celine et Julie Vont en Bateau", + "duration" : "PT3H25M" + }, + { + "title" : "L'Amour fou", + "duration" : "PT4H12M" + } + ] + } + } + """ + + val res = MovieMapping.compileAndRun(query) + + assertIO(res, expected) + } + test("query with LocalTime argument") { val query = """ query { From 0b7ed52320e9336714bd4f587c5e88fd46008940 Mon Sep 17 00:00:00 2001 From: Miles Sabin Date: Wed, 13 Dec 2023 12:05:27 +0000 Subject: [PATCH 2/3] Default to terse parser error messages --- modules/core/src/main/scala/parser.scala | 15 +++++++++++---- .../src/test/scala/compiler/CompilerSuite.scala | 2 +- .../core/src/test/scala/parser/ParserSuite.scala | 9 ++++++--- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/modules/core/src/main/scala/parser.scala b/modules/core/src/main/scala/parser.scala index bc9361bb..21ce3449 100644 --- a/modules/core/src/main/scala/parser.scala +++ b/modules/core/src/main/scala/parser.scala @@ -32,7 +32,8 @@ object GraphQLParser { maxSelectionDepth: Int, maxSelectionWidth: Int, maxInputValueDepth: Int, - maxListTypeDepth: Int + maxListTypeDepth: Int, + terseError: Boolean ) val defaultConfig: Config = @@ -40,7 +41,8 @@ object GraphQLParser { maxSelectionDepth = 100, maxSelectionWidth = 1000, maxInputValueDepth = 5, - maxListTypeDepth = 5 + maxListTypeDepth = 5, + terseError = true ) def apply(config: Config): GraphQLParser = @@ -49,14 +51,19 @@ object GraphQLParser { def toResult[T](pr: Either[Parser.Error, T]): Result[T] = Result.fromEither(pr.leftMap(_.show)) + def toResultTerseError[T](pr: Either[Parser.Error, T]): Result[T] = + Result.fromEither(pr.leftMap(_.copy().show)) + import CommentedText._ import Literals._ private final class Impl(config: Config) extends GraphQLParser { import config._ - def parseText(text: String): Result[Ast.Document] = - toResult(Document.parseAll(text)) + def parseText(text: String): Result[Ast.Document] = { + val res = Document.parseAll(text) + if (config.terseError) toResultTerseError(res) else toResult(res) + } val nameInitial = ('A' to 'Z') ++ ('a' to 'z') ++ Seq('_') val nameSubsequent = nameInitial ++ ('0' to '9') diff --git a/modules/core/src/test/scala/compiler/CompilerSuite.scala b/modules/core/src/test/scala/compiler/CompilerSuite.scala index 25e6c1a7..c45f085c 100644 --- a/modules/core/src/test/scala/compiler/CompilerSuite.scala +++ b/modules/core/src/test/scala/compiler/CompilerSuite.scala @@ -26,7 +26,7 @@ import Predicate._, Value._, UntypedOperation._ import QueryCompiler._, ComponentElaborator.TrivialJoin final class CompilerSuite extends CatsEffectSuite { - val queryParser = QueryParser(GraphQLParser(GraphQLParser.defaultConfig)) + val queryParser = QueryParser(GraphQLParser(GraphQLParser.defaultConfig.copy(terseError = false))) test("simple query") { val query = """ diff --git a/modules/core/src/test/scala/parser/ParserSuite.scala b/modules/core/src/test/scala/parser/ParserSuite.scala index bfe3f0b3..91fd5364 100644 --- a/modules/core/src/test/scala/parser/ParserSuite.scala +++ b/modules/core/src/test/scala/parser/ParserSuite.scala @@ -22,7 +22,7 @@ import grackle.syntax._ import Ast._, OperationType._, OperationDefinition._, Selection._, Value._, Type.Named final class ParserSuite extends CatsEffectSuite { - val parser = GraphQLParser(GraphQLParser.defaultConfig) + val parser = mkParser() test("simple query") { val query = doc""" @@ -801,12 +801,15 @@ final class ParserSuite extends CatsEffectSuite { maxSelectionDepth: Int = GraphQLParser.defaultConfig.maxSelectionDepth, maxSelectionWidth: Int = GraphQLParser.defaultConfig.maxSelectionWidth, maxInputValueDepth: Int = GraphQLParser.defaultConfig.maxInputValueDepth, - maxListTypeDepth: Int = GraphQLParser.defaultConfig.maxListTypeDepth): GraphQLParser = + maxListTypeDepth: Int = GraphQLParser.defaultConfig.maxListTypeDepth, + ): GraphQLParser = GraphQLParser( GraphQLParser.Config( maxSelectionDepth = maxSelectionDepth, maxSelectionWidth = maxSelectionWidth, maxInputValueDepth = maxInputValueDepth, - maxListTypeDepth = maxListTypeDepth) + maxListTypeDepth = maxListTypeDepth, + terseError = false ) + ) } From 05dde88db22fec94b403d3e99526607a6ba7bd32 Mon Sep 17 00:00:00 2001 From: Miles Sabin Date: Wed, 13 Dec 2023 13:04:32 +0000 Subject: [PATCH 3/3] Added option to disable unused variable/fragment validation --- modules/core/src/main/scala/compiler.scala | 21 +++++---- modules/core/src/main/scala/mapping.scala | 8 ++-- .../test/scala/compiler/FragmentSuite.scala | 47 ++++++++++++++++++- .../test/scala/compiler/VariablesSuite.scala | 32 +++++++++++++ 4 files changed, 95 insertions(+), 13 deletions(-) diff --git a/modules/core/src/main/scala/compiler.scala b/modules/core/src/main/scala/compiler.scala index 3bb5c3a8..111378e4 100644 --- a/modules/core/src/main/scala/compiler.scala +++ b/modules/core/src/main/scala/compiler.scala @@ -215,10 +215,10 @@ class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) { * * GraphQL errors and warnings are accumulated in the result. */ - def compile(text: String, name: Option[String] = None, untypedVars: Option[Json] = None, introspectionLevel: IntrospectionLevel = Full, env: Env = Env.empty): Result[Operation] = + def compile(text: String, name: Option[String] = None, untypedVars: Option[Json] = None, introspectionLevel: IntrospectionLevel = Full, reportUnused: Boolean = true, env: Env = Env.empty): Result[Operation] = parser.parseText(text).flatMap { case (ops, frags) => for { - _ <- Result.fromProblems(validateVariablesAndFragments(ops, frags)) + _ <- Result.fromProblems(validateVariablesAndFragments(ops, frags, reportUnused)) ops0 <- ops.traverse(op => compileOperation(op, untypedVars, frags, introspectionLevel, env).map(op0 => (op.name, op0))) res <- (ops0, name) match { case (List((_, op)), None) => @@ -323,7 +323,7 @@ class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) { loop(tpe, false) } - def validateVariablesAndFragments(ops: List[UntypedOperation], frags: List[UntypedFragment]): List[Problem] = { + def validateVariablesAndFragments(ops: List[UntypedOperation], frags: List[UntypedFragment], reportUnused: Boolean): List[Problem] = { val (uniqueFrags, duplicateFrags) = frags.map(_.name).foldLeft((Set.empty[String], Set.empty[String])) { case ((unique, duplicate), nme) => if (unique.contains(nme)) (unique, duplicate + nme) @@ -439,10 +439,13 @@ class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) { val varProblems = if (qv == pendingVars) Nil else { - val undefined = qv.diff(pendingVars) - val unused = pendingVars.diff(qv) - val undefinedProblems = undefined.toList.map(nme => Problem(s"Variable '$nme' is undefined")) - val unusedProblems = unused.toList.map(nme => Problem(s"Variable '$nme' is unused")) + val undefinedProblems = + qv.diff(pendingVars).toList.map(nme => Problem(s"Variable '$nme' is undefined")) + + val unusedProblems = + if (!reportUnused) Nil + else pendingVars.diff(qv).toList.map(nme => Problem(s"Variable '$nme' is unused")) + undefinedProblems ++ unusedProblems } @@ -464,7 +467,9 @@ class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) { (acc ++ problems, pendingFrags0) } - val unreferencedFragProblems = unreferencedFrags.toList.map(nme => Problem(s"Fragment '$nme' is unused")) + val unreferencedFragProblems = + if (!reportUnused) Nil + else unreferencedFrags.toList.map(nme => Problem(s"Fragment '$nme' is unused")) opProblems ++ unreferencedFragProblems } diff --git a/modules/core/src/main/scala/mapping.scala b/modules/core/src/main/scala/mapping.scala index b9d35239..36ddc94e 100644 --- a/modules/core/src/main/scala/mapping.scala +++ b/modules/core/src/main/scala/mapping.scala @@ -47,10 +47,10 @@ abstract class Mapping[F[_]] { * * Yields a JSON response containing the result of the query or mutation. */ - def compileAndRun(text: String, name: Option[String] = None, untypedVars: Option[Json] = None, introspectionLevel: IntrospectionLevel = Full, env: Env = Env.empty)( + def compileAndRun(text: String, name: Option[String] = None, untypedVars: Option[Json] = None, introspectionLevel: IntrospectionLevel = Full, reportUnused: Boolean = true, env: Env = Env.empty)( implicit sc: Compiler[F,F] ): F[Json] = - compileAndRunSubscription(text, name, untypedVars, introspectionLevel, env).compile.toList.flatMap { + compileAndRunSubscription(text, name, untypedVars, introspectionLevel, reportUnused, env).compile.toList.flatMap { case List(j) => j.pure[F] case Nil => M.raiseError(new IllegalStateException("Result stream was empty.")) case js => M.raiseError(new IllegalStateException(s"Result stream contained ${js.length} results; expected exactly one.")) @@ -61,8 +61,8 @@ abstract class Mapping[F[_]] { * * Yields a stream of JSON responses containing the results of the subscription. */ - def compileAndRunSubscription(text: String, name: Option[String] = None, untypedVars: Option[Json] = None, introspectionLevel: IntrospectionLevel = Full, env: Env = Env.empty): Stream[F,Json] = { - val compiled = compiler.compile(text, name, untypedVars, introspectionLevel, env) + def compileAndRunSubscription(text: String, name: Option[String] = None, untypedVars: Option[Json] = None, introspectionLevel: IntrospectionLevel = Full, reportUnused: Boolean = true, env: Env = Env.empty): Stream[F,Json] = { + val compiled = compiler.compile(text, name, untypedVars, introspectionLevel, reportUnused, env) Stream.eval(compiled.pure[F]).flatMap(_.flatTraverse(op => interpreter.run(op.query, op.rootTpe, env))).evalMap(mkResponse) } diff --git a/modules/core/src/test/scala/compiler/FragmentSuite.scala b/modules/core/src/test/scala/compiler/FragmentSuite.scala index 698b2b8c..d781eb44 100644 --- a/modules/core/src/test/scala/compiler/FragmentSuite.scala +++ b/modules/core/src/test/scala/compiler/FragmentSuite.scala @@ -817,7 +817,7 @@ final class FragmentSuite extends CatsEffectSuite { assertIO(res, expected) } - test("fragment used") { + test("fragment unused (1)") { val query = """ query withFragments { user(id: 1) { @@ -851,6 +851,51 @@ final class FragmentSuite extends CatsEffectSuite { assertIO(res, expected) } + test("fragment unused (2)") { + val query = """ + query withFragments { + user(id: 1) { + friends { + id + name + profilePic + } + } + } + + fragment friendFields on User { + id + name + profilePic + } + """ + + val expected = json""" + { + "data" : { + "user" : { + "friends" : [ + { + "id" : "2", + "name" : "Bob", + "profilePic" : "B" + }, + { + "id" : "3", + "name" : "Carol", + "profilePic" : "C" + } + ] + } + } + } + """ + + val res = FragmentMapping.compileAndRun(query, reportUnused = false) + + assertIO(res, expected) + } + test("fragment duplication") { val query = """ query withFragments { diff --git a/modules/core/src/test/scala/compiler/VariablesSuite.scala b/modules/core/src/test/scala/compiler/VariablesSuite.scala index c8e09bbc..178b27cd 100644 --- a/modules/core/src/test/scala/compiler/VariablesSuite.scala +++ b/modules/core/src/test/scala/compiler/VariablesSuite.scala @@ -511,6 +511,36 @@ final class VariablesSuite extends CatsEffectSuite { assertEquals(compiled, expected) } + + test("variable unused (2)") { + val query = """ + query getZuckProfile($devicePicSize: Int) { + user(id: 4) { + id + name + } + } + """ + + val compiled = VariablesMapping.compiler.compile(query, reportUnused = false) + println(compiled) + + val expected = + Operation( + UntypedSelect("user", None, List(Binding("id", IDValue("4"))), Nil, + Group( + List( + UntypedSelect("id", None, Nil, Nil, Empty), + UntypedSelect("name", None, Nil, Nil, Empty) + ) + ) + ), + VariablesMapping.QueryType, + Nil + ) + + assertEquals(compiled, Result.success(expected)) + } } object VariablesMapping extends TestMapping { @@ -544,5 +574,7 @@ object VariablesMapping extends TestMapping { scalar BigDecimal """ + val QueryType = schema.ref("Query").dealias + override val selectElaborator = PreserveArgsElaborator }