Skip to content

Commit

Permalink
merge from dev (#1110)
Browse files Browse the repository at this point in the history
Co-authored-by: ErikMCM <[email protected]>
Co-authored-by: Jason Kleinberg <[email protected]>
Co-authored-by: Jonathan S <[email protected]>
Co-authored-by: Jiasheng <[email protected]>
  • Loading branch information
5 people authored Mar 9, 2024
1 parent 9f9d277 commit df07830
Show file tree
Hide file tree
Showing 14 changed files with 829 additions and 150 deletions.
34 changes: 23 additions & 11 deletions packages/runtime/src/cross/nested-write-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import type { FieldInfo, ModelMeta } from './model-meta';
import { resolveField } from './model-meta';
import { MaybePromise, PrismaWriteActionType, PrismaWriteActions } from './types';
import { enumerate, getModelFields } from './utils';
import { getModelFields } from './utils';

type NestingPathItem = { field?: FieldInfo; model: string; where: any; unique: boolean };

Expand Down Expand Up @@ -155,7 +155,7 @@ export class NestedWriteVisitor {
// visit payload
switch (action) {
case 'create':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, {});
let callbackResult: any;
if (this.callback.create) {
Expand Down Expand Up @@ -183,7 +183,7 @@ export class NestedWriteVisitor {
break;

case 'connectOrCreate':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.connectOrCreate) {
Expand All @@ -198,7 +198,7 @@ export class NestedWriteVisitor {

case 'connect':
if (this.callback.connect) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item, true);
await this.callback.connect(model, item, newContext);
}
Expand All @@ -210,7 +210,7 @@ export class NestedWriteVisitor {
// if relation is to-many, the payload is a unique filter object
// if relation is to-one, the payload can only be boolean `true`
if (this.callback.disconnect) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item, typeof item === 'object');
await this.callback.disconnect(model, item, newContext);
}
Expand All @@ -219,15 +219,15 @@ export class NestedWriteVisitor {

case 'set':
if (this.callback.set) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item, true);
await this.callback.set(model, item, newContext);
}
}
break;

case 'update':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.update) {
Expand All @@ -246,7 +246,7 @@ export class NestedWriteVisitor {
break;

case 'updateMany':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.updateMany) {
Expand All @@ -260,7 +260,7 @@ export class NestedWriteVisitor {
break;

case 'upsert': {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.upsert) {
Expand All @@ -280,7 +280,7 @@ export class NestedWriteVisitor {

case 'delete': {
if (this.callback.delete) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
await this.callback.delete(model, item, newContext);
}
Expand All @@ -290,7 +290,7 @@ export class NestedWriteVisitor {

case 'deleteMany':
if (this.callback.deleteMany) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
await this.callback.deleteMany(model, item, newContext);
}
Expand Down Expand Up @@ -338,4 +338,16 @@ export class NestedWriteVisitor {
}
}
}

// enumerate a (possible) array in reverse order, so that the enumeration
// callback can safely delete the current item
private *enumerateReverse(data: any) {
if (Array.isArray(data)) {
for (let i = data.length - 1; i >= 0; i--) {
yield data[i];
}
} else {
yield data;
}
}
}
53 changes: 35 additions & 18 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -357,29 +357,19 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
}

if (context.parent.connect) {
// if the payload parent already has a "connect" clause, merge it
if (Array.isArray(context.parent.connect)) {
context.parent.connect.push(args.where);
} else {
context.parent.connect = [context.parent.connect, args.where];
}
} else {
// otherwise, create a new "connect" clause
context.parent.connect = args.where;
}
this.mergeToParent(context.parent, 'connect', args.where);
// record the key of connected entities so we can avoid validating them later
connectedEntities.add(getEntityKey(model, existing));
} else {
// create case
pushIdFields(model, context);

// create a new "create" clause at the parent level
context.parent.create = args.create;
this.mergeToParent(context.parent, 'create', args.create);
}

// remove the connectOrCreate clause
delete context.parent['connectOrCreate'];
this.removeFromParent(context.parent, 'connectOrCreate', args);

// return false to prevent visiting the nested payload
return false;
Expand Down Expand Up @@ -917,7 +907,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await _create(model, args, context);

// remove it from the update payload
delete context.parent.create;
this.removeFromParent(context.parent, 'create', args);

// don't visit payload
return false;
Expand Down Expand Up @@ -950,22 +940,23 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await _registerPostUpdateCheck(model, uniqueFilter);

// convert upsert to update
context.parent.update = {
const convertedUpdate = {
where: args.where,
data: this.validateUpdateInputSchema(model, args.update),
};
delete context.parent.upsert;
this.mergeToParent(context.parent, 'update', convertedUpdate);
this.removeFromParent(context.parent, 'upsert', args);

// continue visiting the new payload
return context.parent.update;
return convertedUpdate;
} else {
// create case

// process the entire create subtree separately
await _create(model, args.create, context);

// remove it from the update payload
delete context.parent.upsert;
this.removeFromParent(context.parent, 'upsert', args);

// don't visit payload
return false;
Expand Down Expand Up @@ -1388,5 +1379,31 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
return requireField(this.modelMeta, fieldInfo.type, fieldInfo.backLink);
}

private mergeToParent(parent: any, key: string, value: any) {
if (parent[key]) {
if (Array.isArray(parent[key])) {
parent[key].push(value);
} else {
parent[key] = [parent[key], value];
}
} else {
parent[key] = value;
}
}

private removeFromParent(parent: any, key: string, data: any) {
if (parent[key] === data) {
delete parent[key];
} else if (Array.isArray(parent[key])) {
const idx = parent[key].indexOf(data);
if (idx >= 0) {
parent[key].splice(idx, 1);
if (parent[key].length === 0) {
delete parent[key];
}
}
}
}

//#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
isEnum,
isReferenceExpr,
} from '@zenstackhq/language/ast';
import { isFutureExpr, isRelationshipField, resolved } from '@zenstackhq/sdk';
import { isDataModelFieldReference, isFutureExpr, isRelationshipField, resolved } from '@zenstackhq/sdk';
import { ValidationAcceptor, streamAst } from 'langium';
import pluralize from 'pluralize';
import { AstValidator } from '../types';
Expand Down Expand Up @@ -151,6 +151,19 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
}
}

@check('@@validate')
private _checkValidate(attr: AttributeApplication, accept: ValidationAcceptor) {
const condition = attr.args[0]?.value;
if (
condition &&
streamAst(condition).some(
(node) => isDataModelFieldReference(node) && isDataModel(node.$resolvedType?.decl)
)
) {
accept('error', `\`@@validate\` condition cannot use relation fields`, { node: condition });
}
}

private validatePolicyKinds(
kind: string,
candidates: string[],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import {
AstNode,
BinaryExpr,
Expression,
ExpressionType,
isDataModel,
isDataModelAttribute,
isDataModelField,
isEnum,
isLiteralExpr,
Expand All @@ -12,7 +14,7 @@ import {
} from '@zenstackhq/language/ast';
import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk';
import { ValidationAcceptor } from 'langium';
import { getContainingDataModel, isCollectionPredicate } from '../../utils/ast-utils';
import { findUpAst, getContainingDataModel, isCollectionPredicate } from '../../utils/ast-utils';
import { AstValidator } from '../types';
import { typeAssignable } from './utils';

Expand Down Expand Up @@ -123,6 +125,17 @@ export default class ExpressionValidator implements AstValidator<Expression> {

case '==':
case '!=': {
if (this.isInValidationContext(expr)) {
// in validation context, all fields are optional, so we should allow
// comparing any field against null
if (
(isDataModelFieldReference(expr.left) && isNullExpr(expr.right)) ||
(isDataModelFieldReference(expr.right) && isNullExpr(expr.left))
) {
return;
}
}

if (!!expr.left.$resolvedType?.array !== !!expr.right.$resolvedType?.array) {
accept('error', 'incompatible operand types', { node: expr });
break;
Expand Down Expand Up @@ -211,6 +224,10 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}
}

private isInValidationContext(node: AstNode) {
return findUpAst(node, (n) => isDataModelAttribute(n) && n.decl.$refText === '@@validate');
}

private isNotModelFieldExpr(expr: Expression) {
return (
isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr)
Expand Down
10 changes: 9 additions & 1 deletion packages/schema/src/plugins/zod/utils/schema-gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
getAttributeArg,
getAttributeArgLiteral,
getLiteral,
isDataModelFieldReference,
isFromStdlib,
} from '@zenstackhq/sdk';
import {
Expand Down Expand Up @@ -203,10 +204,17 @@ export function makeValidationRefinements(model: DataModel) {
const message = messageArg ? `, { message: ${JSON.stringify(messageArg)} }` : '';

try {
const expr = new TypeScriptExpressionTransformer({
let expr = new TypeScriptExpressionTransformer({
context: ExpressionContext.ValidationRule,
fieldReferenceContext: 'value',
}).transform(valueArg);

if (isDataModelFieldReference(valueArg)) {
// if the expression is a simple field reference, treat undefined
// as true since the all fields are optional in validation context
expr = `${expr} ?? true`;
}

return `.refine((value: any) => ${expr}${message})`;
} catch (err) {
if (err instanceof TypeScriptExpressionTransformerError) {
Expand Down
14 changes: 14 additions & 0 deletions packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,17 @@ export function getRecursiveBases(dataModel: DataModel, includeDelegate = true):
});
return result;
}

/**
* Walk upward from the current AST node to find the first node that satisfies the predicate.
*/
export function findUpAst(node: AstNode, predicate: (node: AstNode) => boolean): AstNode | undefined {
let curr: AstNode | undefined = node;
while (curr) {
if (predicate(curr)) {
return curr;
}
curr = curr.$container;
}
return undefined;
}
21 changes: 8 additions & 13 deletions packages/schema/src/utils/pkg-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ import { match } from 'ts-pattern';
export type PackageManagers = 'npm' | 'yarn' | 'pnpm';

/**
* A type named FindUp that takes a type parameter e which extends boolean.
* If e extends true, it returns a union type of string[] or undefined.
* A type named FindUp that takes a type parameter e which extends boolean.
* If e extends true, it returns a union type of string[] or undefined.
* If e does not extend true, it returns a union type of string or undefined.
*
* @export
* @template e A type parameter that extends boolean
*/
export type FindUp<e extends boolean> = e extends true ? string[] | undefined : string | undefined;
export type FindUp<e extends boolean> = e extends true ? string[] | undefined : string | undefined
/**
* Find and return file paths by searching parent directories based on the given names list and current working directory (cwd) path.
* Optionally return a single path or multiple paths.
* If multiple allowed, return all paths found.
* Find and return file paths by searching parent directories based on the given names list and current working directory (cwd) path.
* Optionally return a single path or multiple paths.
* If multiple allowed, return all paths found.
* If no paths are found, return undefined.
*
* @export
Expand All @@ -28,12 +28,7 @@ export type FindUp<e extends boolean> = e extends true ? string[] | undefined :
* @param [result=[]] An array of strings representing the accumulated results used in multiple results
* @returns Path(s) to a specific file or folder within the directory or parent directories
*/
export function findUp<e extends boolean = false>(
names: string[],
cwd: string = process.cwd(),
multiple: e = false as e,
result: string[] = []
): FindUp<e> {
export function findUp<e extends boolean = false>(names: string[], cwd: string = process.cwd(), multiple: e = false as e, result: string[] = []): FindUp<e> {
if (!names.some((name) => !!name)) return undefined;
const target = names.find((name) => fs.existsSync(path.join(cwd, name)));
if (multiple == false && target) return path.join(cwd, target) as FindUp<e>;
Expand Down Expand Up @@ -111,7 +106,7 @@ export function ensurePackage(
}

/**
* A function that searches for the nearest package.json file starting from the provided search path or the current working directory if no search path is provided.
* A function that searches for the nearest package.json file starting from the provided search path or the current working directory if no search path is provided.
* It iterates through the directory structure going one level up at a time until it finds a package.json file. If no package.json file is found, it returns undefined.
* @deprecated Use findUp instead @see findUp
*/
Expand Down
Loading

0 comments on commit df07830

Please sign in to comment.