diff --git a/src/index.ts b/src/index.ts index d99c021..342dfbb 100644 --- a/src/index.ts +++ b/src/index.ts @@ -3,3 +3,4 @@ export * from './converter' export * from './middleware' export * from './time-collection' export * from './types' +export * from './zod-collection' diff --git a/src/zod-collection.ts b/src/zod-collection.ts new file mode 100644 index 0000000..db599eb --- /dev/null +++ b/src/zod-collection.ts @@ -0,0 +1,49 @@ +import { TsReadWriteCollection } from './collection' +import z from 'zod' +import { convertReadWriteCollection } from './converter' +import { OptionalUnlessRequiredId, WithId, Document } from 'mongodb' +import { WithTime } from './time-collection' +import { parseFieldsAsArrays, zodDeepPartial } from './zod-utils' + +type WithIdTime = WithId> + +/** + * Create a collection that uses zod to validate inserted data and + * Mongo operators at runtime + * @param collection + * @param schema + */ +export const convertToZodCollection = ( + collection: TsReadWriteCollection>, + schema: z.ZodType +): TsReadWriteCollection> => + convertReadWriteCollection(collection, { + preInsert: ( + obj: OptionalUnlessRequiredId + ): OptionalUnlessRequiredId => { + return schema.parse(obj) as OptionalUnlessRequiredId + }, + preUpdate: (obj) => { + const partialSchema = zodDeepPartial(schema) + return { + ...obj, + // TODO: validate other operators + ...(obj.$pull + ? { $pull: parseFieldsAsArrays(obj.$pull, partialSchema) } + : {}), + ...(obj.$push + ? { $push: parseFieldsAsArrays(obj.$push, partialSchema) } + : {}), + ...(obj.$addToSet + ? { $addToSet: parseFieldsAsArrays(obj.$addToSet, partialSchema) } + : {}), + ...(obj.$set ? { $set: partialSchema.parse(obj.$set) } : {}), + ...(obj.$setOnInsert + ? { $setOnInsert: partialSchema.parse(obj.$setOnInsert) } + : {}), + } + }, + preReplace: (obj) => schema.parse(obj), + postFind: (obj) => obj, + preFilter: (obj) => obj, + }) diff --git a/src/zod-utils.ts b/src/zod-utils.ts new file mode 100644 index 0000000..abd35a2 --- /dev/null +++ b/src/zod-utils.ts @@ -0,0 +1,100 @@ +import z, { + AnyZodObject, + ZodArray, + ZodDiscriminatedUnion, + ZodIntersection, + ZodNullable, + ZodObject, + ZodOptional, + ZodRawShape, + ZodTuple, + ZodTypeAny, + ZodUnion, +} from 'zod' + +export const parseFieldsAsArrays = >( + obj: T, + schema: z.ZodTypeAny +): T => { + // convert all fields to arrays and parse them + const parsed = schema.parse( + // eslint-disable-next-line custom-rules/prefer-map-to-object-from-entries + Object.fromEntries( + Object.entries(obj).map(([key, value]) => { + // Handle adding values to an array using { $each: [value], ...otherModifiers } + if (typeof value === 'object' && value !== null && '$each' in value) + return [key, value.$each] + return [key, [value]] + }) + ) + ) + // returns only fields parsed in their initial flattened shape + // eslint-disable-next-line custom-rules/prefer-map-to-object-from-entries + return Object.fromEntries( + Object.keys(parsed).map((key: keyof T) => [key, obj[key]]) + ) as T +} + +// Copied from github.com/colinhacks/zod/blob/6dad90785398885f7b058f5c0760d5ae5476b833/src/types.ts#L2189-L2217 +// and extended to support unions and discriminated unions +export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => { + if (schema instanceof ZodObject) { + const newShape: any = {} + + for (const key in schema.shape) { + const fieldSchema = schema.shape[key] + newShape[key] = ZodOptional.create(zodDeepPartial(fieldSchema)) + } + return new ZodObject({ + ...schema._def, + shape: () => newShape, + }) as any + } else if (schema instanceof ZodIntersection) { + return ZodIntersection.create( + zodDeepPartial(schema._def.left), + zodDeepPartial(schema._def.right) + ) + } else if (schema instanceof ZodUnion) { + type Options = [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]] + const options = schema._def.options as Options + return ZodUnion.create( + options.map((option) => zodDeepPartial(option)) as Options + ) as any + } else if (schema instanceof ZodDiscriminatedUnion) { + const types = Object.values(schema.options) as (AnyZodObject & + ZodRawShape)[] + const discriminator = schema.discriminator + const newTypes = types.map((type) => { + const newShape: any = {} + + for (const key in type.shape) { + const fieldSchema = type.shape[key] + if (key === discriminator) { + newShape[key] = fieldSchema + } else { + newShape[key] = ZodOptional.create(zodDeepPartial(fieldSchema)) + } + } + return new ZodObject({ + ...type._def, + shape: () => newShape, + } as any) + }) as any + return ZodDiscriminatedUnion.create(discriminator, newTypes) as any + } else if (schema instanceof ZodArray) { + return new ZodArray({ + ...schema._def, + type: zodDeepPartial(schema.element), + }) + } else if (schema instanceof ZodOptional) { + return ZodOptional.create(zodDeepPartial(schema.unwrap())) + } else if (schema instanceof ZodNullable) { + return ZodNullable.create(zodDeepPartial(schema.unwrap())) + } else if (schema instanceof ZodTuple) { + return ZodTuple.create( + schema.items.map((item: any) => zodDeepPartial(item)) + ) + } else { + return schema + } +} diff --git a/test/zod-utils.test.ts b/test/zod-utils.test.ts new file mode 100644 index 0000000..515ed2f --- /dev/null +++ b/test/zod-utils.test.ts @@ -0,0 +1,131 @@ +import * as z from 'zod' +import { zodDeepPartial, parseFieldsAsArrays } from '../src/zod-utils' + +describe('zodDeepPartial', () => { + test('partial of primitives', () => { + const deepPartial = zodDeepPartial( + z.object({ + s: z.string(), + n: z.number(), + i: z.bigint(), + b: z.boolean(), + z: z.null(), + u: z.undefined(), + }) + ) + + const correctValues = { + s: 'string', + n: 10, + i: BigInt(2), + b: true, + z: null, + u: undefined, + } + + expect(deepPartial.parse({})).toEqual({}) + expect(deepPartial.parse({ z: null })).toEqual({ z: null }) + expect(deepPartial.parse({ u: undefined })).toEqual({ u: undefined }) + expect(deepPartial.parse(correctValues)).toEqual(correctValues) + expect(() => deepPartial.parse({ s: 10 })).toThrow() + expect(() => deepPartial.parse({ n: 'a' })).toThrow() + expect(() => deepPartial.parse({ i: 'a' })).toThrow() + expect(() => deepPartial.parse({ b: 'a' })).toThrow() + }) + + test('partial of unions', () => { + const deepPartial = zodDeepPartial( + z.union([ + z.object({ s: z.string() }).strict(), + z.object({ n: z.number() }).strict(), + ]) + ) + + expect(deepPartial.parse({ s: 'string' })).toEqual({ s: 'string' }) + expect(deepPartial.parse({ n: 10 })).toEqual({ n: 10 }) + expect(deepPartial.parse({})).toEqual({}) + + expect(() => deepPartial.parse({ n: 'string' })).toThrow() + }) + + test('partial of discriminated union', () => { + const deepPartial = zodDeepPartial( + z.discriminatedUnion('t', [ + z.object({ t: z.literal('s1'), s: z.string() }), + z.object({ t: z.literal('n1'), n: z.number() }), + ]) + ) + + expect(deepPartial.parse({ t: 's1' })).toEqual({ t: 's1' }) + expect(deepPartial.parse({ t: 's1', s: 's' })).toEqual({ t: 's1', s: 's' }) + expect(deepPartial.parse({ t: 'n1', n: 10 })).toEqual({ t: 'n1', n: 10 }) + + expect(() => deepPartial.parse({ n: 10 })).toThrow() + expect(() => deepPartial.parse({ s: 's' })).toThrow() + expect(() => deepPartial.parse({})).toThrow() + }) + + test('partial of intersection', () => { + const deepPartial = zodDeepPartial( + z.object({ t: z.literal('t1') }).and(z.object({ y: z.literal('y1') })) + ) + + expect(deepPartial.parse({})).toEqual({}) + expect(deepPartial.parse({ t: 't1' })).toEqual({ t: 't1' }) + expect(deepPartial.parse({ t: 't1', y: 'y1' })).toEqual({ + t: 't1', + y: 'y1', + }) + + expect(() => deepPartial.parse({ t: 't2' })).toThrow() + expect(() => deepPartial.parse({ y: 'y2' })).toThrow() + }) +}) + +describe('parseFieldsAsArrays', () => { + const schema = z.object({ + strings: z.string().array(), + numbers: z.number().array(), + }) + + test('should parse valid values successfully', () => { + expect(parseFieldsAsArrays({ numbers: 6, strings: 's' }, schema)).toEqual({ + numbers: 6, + strings: 's', + }) + const withEachOperator = { + numbers: 6, + strings: { $each: ['s'], $otherOperator: 5 }, + } + expect(parseFieldsAsArrays(withEachOperator, schema)).toEqual( + withEachOperator + ) + }) + + test('should remove extra fields', () => { + expect( + parseFieldsAsArrays( + { numbers: 6, strings: 's', objects: {}, maps: new Map() }, + schema + ) + ).toEqual({ + numbers: 6, + strings: 's', + }) + }) + + test('should throw for invalid values', () => { + expect(() => + parseFieldsAsArrays({ numbers: 's', strings: 7 }, schema) + ).toThrow() + expect(() => + parseFieldsAsArrays({ numbers: 's', strings: 7 }, schema) + ).toThrow() + expect(() => + parseFieldsAsArrays({ numbers: [7], strings: ['s'] }, schema) + ).toThrow() + expect(() => + parseFieldsAsArrays({ numbers: 7, strings: { $each: [7] } }, schema) + ).toThrow() + }) +})