diff --git a/src/helpers/types.ts b/src/helpers/types.ts index ba7608301..ea839e003 100644 --- a/src/helpers/types.ts +++ b/src/helpers/types.ts @@ -100,6 +100,10 @@ export function convertToType(Target: any, data?: object): object | undefined { if (data instanceof Target) { return data; } + // convert array to instances + if (Array.isArray(data)) { + return data.map(item => convertToType(Target, item)); + } return Object.assign(new Target(), data); } diff --git a/src/resolvers/convert-args.ts b/src/resolvers/convert-args.ts new file mode 100644 index 000000000..8754346b5 --- /dev/null +++ b/src/resolvers/convert-args.ts @@ -0,0 +1,134 @@ +import { ArgParamMetadata, ClassMetadata, ArgsParamMetadata } from "../metadata/definitions"; +import { convertToType } from "../helpers/types"; +import { ArgsDictionary, ClassType } from "../interfaces"; +import { getMetadataStorage } from "../metadata/getMetadataStorage"; +import { TypeValue } from "../decorators/types"; + +interface TransformationTreeField { + name: string; + target: TypeValue; + fields?: TransformationTree; +} + +interface TransformationTree { + target: TypeValue; + getFields: () => TransformationTreeField[]; +} + +const generatedTrees = new Map(); + +function getInputType(target: TypeValue): ClassMetadata | undefined { + return getMetadataStorage().inputTypes.find(t => t.target === target); +} + +function getArgsType(target: TypeValue): ClassMetadata | undefined { + return getMetadataStorage().argumentTypes.find(t => t.target === target); +} + +function generateInstanceTransformationTree(target: TypeValue): TransformationTree | null { + if (generatedTrees.has(target)) { + return generatedTrees.get(target)!; + } + + const inputType = getInputType(target); + if (!inputType) { + generatedTrees.set(target, null); + return null; + } + + function generateTransformationTree(metadata: ClassMetadata): TransformationTree { + let inputFields = metadata.fields!; + let superClass = Object.getPrototypeOf(metadata.target); + while (superClass.prototype !== undefined) { + const superInputType = getInputType(superClass); + if (superInputType) { + inputFields = [...inputFields, ...superInputType.fields!]; + } + superClass = Object.getPrototypeOf(superClass); + } + + const transformationTree: TransformationTree = { + target: metadata.target, + getFields: () => + inputFields.map(field => { + const fieldTarget = field.getType(); + const fieldInputType = getInputType(fieldTarget); + return { + name: field.name, + target: fieldTarget, + fields: + fieldTarget === metadata.target + ? transformationTree + : fieldInputType && generateTransformationTree(fieldInputType), + }; + }), + }; + + return transformationTree; + } + + const generatedTransformationTree = generateTransformationTree(inputType); + generatedTrees.set(target, generatedTransformationTree); + return generatedTransformationTree; +} + +function convertToInput(tree: TransformationTree, data: any) { + const inputFields = tree.getFields().reduce>((fields, field) => { + const siblings = field.fields; + const value = data[field.name]; + if (!siblings || !value) { + fields[field.name] = convertToType(field.target, value); + } else if (Array.isArray(value)) { + fields[field.name] = value.map(itemValue => convertToInput(siblings, itemValue)); + } else { + fields[field.name] = convertToInput(siblings, value); + } + return fields; + }, {}); + + return convertToType(tree.target, inputFields); +} + +function convertValueToInstance(target: TypeValue, value: any) { + const transformationTree = generateInstanceTransformationTree(target); + return transformationTree + ? convertToInput(transformationTree, value) + : convertToType(target, value); +} + +function convertValuesToInstances(target: TypeValue, value: any) { + if (Array.isArray(value)) { + return value.map(itemValue => convertValueToInstance(target, itemValue)); + } + return convertValueToInstance(target, value); +} + +export function convertArgsToInstance(argsMetadata: ArgsParamMetadata, args: ArgsDictionary) { + const ArgsClass = argsMetadata.getType() as ClassType; + const argsType = getArgsType(ArgsClass)!; + + let argsFields = argsType.fields!; + let superClass = Object.getPrototypeOf(argsType.target); + while (superClass.prototype !== undefined) { + const superArgumentType = getArgsType(superClass); + if (superArgumentType) { + argsFields = [...argsFields, ...superArgumentType.fields!]; + } + superClass = Object.getPrototypeOf(superClass); + } + + const transformedFields = argsFields.reduce>((fields, field) => { + const fieldValue = args[field.name]; + const fieldTarget = field.getType(); + fields[field.name] = convertValuesToInstances(fieldTarget, fieldValue); + return fields; + }, {}); + + return convertToType(ArgsClass, transformedFields); +} + +export function convertArgToInstance(argMetadata: ArgParamMetadata, args: ArgsDictionary) { + const argValue = args[argMetadata.name]; + const argTarget = argMetadata.getType(); + return convertValuesToInstances(argTarget, argValue); +} diff --git a/src/resolvers/helpers.ts b/src/resolvers/helpers.ts index c21eb9728..cf0f32b52 100644 --- a/src/resolvers/helpers.ts +++ b/src/resolvers/helpers.ts @@ -8,6 +8,7 @@ import { ResolverData, AuthChecker, AuthMode } from "../interfaces"; import { Middleware, MiddlewareFn, MiddlewareClass } from "../interfaces/Middleware"; import { IOCContainer } from "../utils/container"; import { AuthMiddleware } from "../helpers/auth-middleware"; +import { convertArgsToInstance, convertArgToInstance } from "./convert-args"; export async function getParams( params: ParamMetadata[], @@ -22,13 +23,13 @@ export async function getParams( switch (paramInfo.kind) { case "args": return await validateArg( - convertToType(paramInfo.getType(), resolverData.args), + convertArgsToInstance(paramInfo, resolverData.args), globalValidate, paramInfo.validate, ); case "arg": return await validateArg( - convertToType(paramInfo.getType(), resolverData.args[paramInfo.name]), + convertArgToInstance(paramInfo, resolverData.args), globalValidate, paramInfo.validate, ); diff --git a/src/resolvers/validate-arg.ts b/src/resolvers/validate-arg.ts index 6c5e532e2..f813d1bde 100644 --- a/src/resolvers/validate-arg.ts +++ b/src/resolvers/validate-arg.ts @@ -23,7 +23,11 @@ export async function validateArg( const { validateOrReject } = await import("class-validator"); try { - await validateOrReject(arg, validatorOptions); + if (Array.isArray(arg)) { + await Promise.all(arg.map(argItem => validateOrReject(argItem, validatorOptions))); + } else { + await validateOrReject(arg, validatorOptions); + } return arg; } catch (err) { throw new ArgumentValidationError(err); diff --git a/src/schema/schema-generator.ts b/src/schema/schema-generator.ts index 6c77b29db..d153c7890 100644 --- a/src/schema/schema-generator.ts +++ b/src/schema/schema-generator.ts @@ -554,7 +554,7 @@ export abstract class SchemaGenerator { while (superClass.prototype !== undefined) { const superArgumentType = getMetadataStorage().argumentTypes.find( it => it.target === superClass, - )!; + ); if (superArgumentType) { this.mapArgFields(superArgumentType, args); } diff --git a/tests/functional/validation.ts b/tests/functional/validation.ts index ebdf4ad63..6fee3c344 100644 --- a/tests/functional/validation.ts +++ b/tests/functional/validation.ts @@ -1,5 +1,5 @@ import "reflect-metadata"; -import { MaxLength, Max, Min } from "class-validator"; +import { MaxLength, Max, Min, ValidateNested } from "class-validator"; import { GraphQLSchema, graphql } from "graphql"; import { getMetadataStorage } from "../../src/metadata/getMetadataStorage"; @@ -51,6 +51,14 @@ describe("Validation", () => { @Field({ nullable: true }) @Min(5) optionalField?: number; + + @Field(type => SampleInput, { nullable: true }) + @ValidateNested() + nestedField?: SampleInput; + + @Field(type => [SampleInput], { nullable: true }) + @ValidateNested({ each: true }) + arrayField?: SampleInput[]; } @ArgsType() @@ -81,6 +89,14 @@ describe("Validation", () => { argInput = input; return {}; } + + @Mutation() + mutationWithInputsArray( + @Arg("inputs", type => [SampleInput]) inputs: SampleInput[], + ): SampleObject { + argInput = inputs; + return {}; + } } sampleResolver = SampleResolver; @@ -92,13 +108,13 @@ describe("Validation", () => { it("should pass input validation when data without optional field is correct", async () => { const mutation = `mutation { - sampleMutation(input: { - stringField: "12345", - numberField: 5, - }) { - field - } - }`; + sampleMutation(input: { + stringField: "12345", + numberField: 5, + }) { + field + } + }`; await graphql(schema, mutation); expect(argInput).toEqual({ stringField: "12345", numberField: 5 }); @@ -106,14 +122,14 @@ describe("Validation", () => { it("should pass input validation when data with optional field is correct", async () => { const mutation = `mutation { - sampleMutation(input: { - stringField: "12345", - numberField: 5, - optionalField: 5, - }) { - field - } - }`; + sampleMutation(input: { + stringField: "12345", + numberField: 5, + optionalField: 5, + }) { + field + } + }`; await graphql(schema, mutation); expect(argInput).toEqual({ stringField: "12345", numberField: 5, optionalField: 5 }); @@ -121,13 +137,87 @@ describe("Validation", () => { it("should throw validation error when input is incorrect", async () => { const mutation = `mutation { - sampleMutation(input: { - stringField: "12345", - numberField: 15, - }) { - field - } - }`; + sampleMutation(input: { + stringField: "12345", + numberField: 15, + }) { + field + } + }`; + + const result = await graphql(schema, mutation); + expect(result.data).toBeNull(); + expect(result.errors).toHaveLength(1); + + const validationError = result.errors![0].originalError! as ArgumentValidationError; + expect(validationError).toBeInstanceOf(ArgumentValidationError); + expect(validationError.validationErrors).toHaveLength(1); + expect(validationError.validationErrors[0].property).toEqual("numberField"); + }); + + it("should throw validation error when nested input field is incorrect", async () => { + const mutation = `mutation { + sampleMutation(input: { + stringField: "12345", + numberField: 5, + nestedField: { + stringField: "12345", + numberField: 15, + } + }) { + field + } + }`; + + const result = await graphql(schema, mutation); + expect(result.data).toBeNull(); + expect(result.errors).toHaveLength(1); + + const validationError = result.errors![0].originalError! as ArgumentValidationError; + expect(validationError).toBeInstanceOf(ArgumentValidationError); + expect(validationError.validationErrors).toHaveLength(1); + expect(validationError.validationErrors[0].property).toEqual("nestedField"); + }); + + it("should throw validation error when nested array input field is incorrect", async () => { + const mutation = `mutation { + sampleMutation(input: { + stringField: "12345", + numberField: 5, + arrayField: [{ + stringField: "12345", + numberField: 15, + }] + }) { + field + } + }`; + + const result = await graphql(schema, mutation); + expect(result.data).toBeNull(); + expect(result.errors).toHaveLength(1); + + const validationError = result.errors![0].originalError! as ArgumentValidationError; + expect(validationError).toBeInstanceOf(ArgumentValidationError); + expect(validationError.validationErrors).toHaveLength(1); + expect(validationError.validationErrors[0].property).toEqual("arrayField"); + }); + + it("should throw validation error when one of input array is incorrect", async () => { + const mutation = `mutation { + mutationWithInputsArray(inputs: [ + { + stringField: "12345", + numberField: 5, + }, + { + stringField: "12345", + numberField: 15, + }, + ]) { + field + } + }`; const result = await graphql(schema, mutation); expect(result.data).toBeNull(); @@ -141,14 +231,14 @@ describe("Validation", () => { it("should throw validation error when optional input field is incorrect", async () => { const mutation = `mutation { - sampleMutation(input: { - stringField: "12345", - numberField: 5, - optionalField: -5, - }) { - field - } - }`; + sampleMutation(input: { + stringField: "12345", + numberField: 5, + optionalField: -5, + }) { + field + } + }`; const result = await graphql(schema, mutation); expect(result.data).toBeNull(); @@ -191,13 +281,13 @@ describe("Validation", () => { it("should throw validation error when one of arguments is incorrect", async () => { const query = `query { - sampleQuery( - stringField: "12345", - numberField: 15, - ) { - field - } - }`; + sampleQuery( + stringField: "12345", + numberField: 15, + ) { + field + } + }`; const result = await graphql(schema, query); expect(result.data).toBeNull(); @@ -211,14 +301,14 @@ describe("Validation", () => { it("should throw validation error when optional argument is incorrect", async () => { const query = `query { - sampleQuery( - stringField: "12345", - numberField: 5, - optionalField: -5, - ) { - field - } - }`; + sampleQuery( + stringField: "12345", + numberField: 5, + optionalField: -5, + ) { + field + } + }`; const result = await graphql(schema, query); expect(result.data).toBeNull();