Skip to content

Commit

Permalink
feat: generate strong typing for the user context of enhance API
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 committed Mar 15, 2024
1 parent 8099793 commit 769cdb0
Show file tree
Hide file tree
Showing 6 changed files with 339 additions and 12 deletions.
4 changes: 2 additions & 2 deletions packages/runtime/src/enhancements/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ export type InternalEnhancementOptions = EnhancementOptions & {
/**
* Context for creating enhanced `PrismaClient`
*/
export type EnhancementContext = {
user?: AuthUser;
export type EnhancementContext<User extends AuthUser = AuthUser> = {
user?: User;
};

let hasPassword: boolean | undefined = undefined;
Expand Down
4 changes: 3 additions & 1 deletion packages/runtime/src/enhancements/policy/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { getIdFields } from '../../cross';
import { DbClientContract } from '../../types';
import { hasAllFields } from '../../validation';
import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement';
import { Logger } from '../logger';
import { makeProxy } from '../proxy';
import { PolicyProxyHandler } from './handler';

Expand Down Expand Up @@ -44,7 +45,8 @@ export function withPolicy<DbClient extends object>(
if (authSelector) {
Object.keys(authSelector).forEach((f) => {
if (!(f in userContext)) {
console.warn(`User context does not have field "${f}" used in policy rules`);
const logger = new Logger(prisma);
logger.warn(`User context does not have field "${f}" used in policy rules`);
}
});
}
Expand Down
141 changes: 141 additions & 0 deletions packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import { getIdFields, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk';
import {
DataModel,
DataModelField,
Expression,
isDataModel,
isMemberAccessExpr,
type Model,
} from '@zenstackhq/sdk/ast';
import { streamAst, type AstNode } from 'langium';
import { isCollectionPredicate } from '../../../utils/ast-utils';

/**
* Generate types for typing the `user` context object passed to the `enhance` call, based
* on the fields (potentially deeply) access through `auth()`.
*/
export function generateAuthType(model: Model, authModel: DataModel) {
const types = new Map<
string,
{
// scalar fields to directly pick from Prisma-generated type
pickFields: string[];

// relation fields to include
addFields: { name: string; type: string }[];
}
>();

types.set(authModel.name, { pickFields: getIdFields(authModel).map((f) => f.name), addFields: [] });

const ensureType = (model: string) => {
if (!types.has(model)) {
types.set(model, { pickFields: [], addFields: [] });
}
};

const addPickField = (model: string, field: string) => {
let fields = types.get(model);
if (!fields) {
fields = { pickFields: [], addFields: [] };
types.set(model, fields);
}
if (!fields.pickFields.includes(field)) {
fields.pickFields.push(field);
}
};

const addAddField = (model: string, name: string, type: string, array: boolean) => {
let fields = types.get(model);
if (!fields) {
fields = { pickFields: [], addFields: [] };
types.set(model, fields);
}
if (!fields.addFields.find((f) => f.name === name)) {
fields.addFields.push({ name, type: array ? `${type}[]` : type });
}
};

// get all policy expressions involving `auth()`
const authInvolvedExprs = streamAst(model).filter(isAuthAccess);

// traverse the expressions and collect types and fields involved
authInvolvedExprs.forEach((expr) => {
streamAst(expr).forEach((node) => {
if (isMemberAccessExpr(node)) {
const exprType = node.operand.$resolvedType?.decl;
if (isDataModel(exprType)) {
const memberDecl = node.member.ref;
if (isDataModel(memberDecl?.type.reference?.ref)) {
// member is a relation
const fieldType = memberDecl.type.reference.ref.name;
ensureType(fieldType);
addAddField(exprType.name, memberDecl.name, fieldType, memberDecl.type.array);
} else {
// member is a scalar
addPickField(exprType.name, node.member.$refText);
}
}
}

if (isDataModelFieldReference(node)) {
// this can happen inside collection predicates
const fieldDecl = node.target.ref as DataModelField;
const fieldType = fieldDecl.type.reference?.ref;
if (isDataModel(fieldType)) {
// field is a relation
ensureType(fieldType.name);
addAddField(fieldDecl.$container.name, node.target.$refText, fieldType.name, fieldDecl.type.array);
} else {
// field is a scalar
addPickField(fieldDecl.$container.name, node.target.$refText);
}
}
});
});

// generate:
// `
// namespace auth {
// export type User = WithRequired<Partial<_P.User>, 'id'> & { profile: Profile; };
// export type Profile = WithRequired<Partial<_P.Profile>, 'age'>;
// }
// `

return `namespace auth {
type WithRequired<T, K extends keyof T> = T & { [P in K]-?: T[P] };
${Array.from(types.entries())
.map(([model, fields]) => {
let result = `Partial<_P.${model}>`;
if (fields.pickFields.length > 0) {
result = `WithRequired<${result}, ${fields.pickFields.map((f) => `'${f}'`).join('|')}>`;
}
if (fields.addFields.length > 0) {
result = `${result} & { ${fields.addFields.map(({ name, type }) => `${name}: ${type}`).join('; ')} }`;
}
return ` export type ${model} = ${result};`;
})
.join('\n')}
}`;
}

function isAuthAccess(node: AstNode): node is Expression {
if (isAuthInvocation(node)) {
return true;
}

if (isMemberAccessExpr(node) && isAuthAccess(node.operand)) {
return true;
}

if (isCollectionPredicate(node)) {
if (isAuthAccess(node.left)) {
return true;
}
}

return false;
}
20 changes: 16 additions & 4 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { DMMF } from '@prisma/generator-helper';
import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime';
import {
getAttribute,
getAuthModel,
getDataModels,
getDMMF,
getPrismaClientImportSpec,
Expand All @@ -28,6 +29,7 @@ import { execPackage } from '../../../utils/exec-utils';
import { trackPrismaSchemaError } from '../../prisma';
import { PrismaSchemaGenerator } from '../../prisma/schema-generator';
import { isDefaultWithAuth } from '../enhancer-utils';
import { generateAuthType } from './auth-type-generator';

// information of delegate models and their sub models
type DelegateInfo = [DataModel, DataModel[]][];
Expand Down Expand Up @@ -62,16 +64,26 @@ export async function generate(model: Model, options: PluginOptions, project: Pr
await prismaDts.save();
}

const authModel = getAuthModel(model.declarations.filter(isDataModel));
const authTypes = authModel ? generateAuthType(model, authModel) : '';
const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser';

const enhanceTs = project.createSourceFile(
path.join(outDir, 'enhance.ts'),
`import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas } from '@zenstackhq/runtime';
`import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas, type AuthUser } from '@zenstackhq/runtime';
import modelMeta from './model-meta';
import policy from './policy';
${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'}
import { Prisma } from '${getPrismaClientImportSpec(outDir, options)}';
${withLogicalClient ? `import { type PrismaClient } from '${logicalPrismaClientDir}/index-fixed';` : ``}
${
withLogicalClient
? `import type * as _P, { type PrismaClient } from '${logicalPrismaClientDir}/index-fixed';`
: `import type * as _P from '@prisma/client';`
}
${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'}
${authTypes}
export function enhance<DbClient extends object>(prisma: DbClient, context?: EnhancementContext, options?: EnhancementOptions)${
export function enhance<DbClient extends object>(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions)${
withLogicalClient ? ': PrismaClient' : ''
} {
return createEnhancement(prisma, {
Expand Down
12 changes: 8 additions & 4 deletions packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,17 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey }));
}

opt.extraSourceFiles?.forEach(({ name, content }) => {
fs.writeFileSync(path.join(projectRoot, name), content);
});

if (opt.extraSourceFiles && opt.extraSourceFiles.length > 0 && !opt.compile) {
console.warn('`extraSourceFiles` is true but `compile` is false.');
}

if (opt.compile) {
console.log('Compiling...');

opt.extraSourceFiles?.forEach(({ name, content }) => {
fs.writeFileSync(path.join(projectRoot, name), content);
});

run('npx tsc --init');

// add generated '.zenstack/zod' folder to typescript's search path,
Expand Down
Loading

0 comments on commit 769cdb0

Please sign in to comment.