-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Feat: Zod Collection * Comment * Tweaks
- Loading branch information
Showing
4 changed files
with
281 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import { TsReadWriteCollection } from './collection' | ||
import z from 'zod' | ||
import { convertReadWriteCollection } from './converter' | ||
import { OptionalUnlessRequiredId, WithId, Document } from 'mongodb' | ||
import { WithTime } from './time-collection' | ||
import { parseFieldsAsArrays, zodDeepPartial } from './zod-utils' | ||
|
||
type WithIdTime<T> = WithId<WithTime<T>> | ||
|
||
/** | ||
* Create a collection that uses zod to validate inserted data and | ||
* Mongo operators at runtime | ||
* @param collection | ||
* @param schema | ||
*/ | ||
export const convertToZodCollection = <TSchema extends Document>( | ||
collection: TsReadWriteCollection<TSchema, WithIdTime<TSchema>>, | ||
schema: z.ZodType<TSchema> | ||
): TsReadWriteCollection<TSchema, WithIdTime<TSchema>> => | ||
convertReadWriteCollection(collection, { | ||
preInsert: ( | ||
obj: OptionalUnlessRequiredId<TSchema> | ||
): OptionalUnlessRequiredId<TSchema> => { | ||
return schema.parse(obj) as OptionalUnlessRequiredId<TSchema> | ||
}, | ||
preUpdate: (obj) => { | ||
const partialSchema = zodDeepPartial(schema) | ||
return { | ||
...obj, | ||
// TODO: validate other operators | ||
...(obj.$pull | ||
? { $pull: parseFieldsAsArrays(obj.$pull, partialSchema) } | ||
: {}), | ||
...(obj.$push | ||
? { $push: parseFieldsAsArrays(obj.$push, partialSchema) } | ||
: {}), | ||
...(obj.$addToSet | ||
? { $addToSet: parseFieldsAsArrays(obj.$addToSet, partialSchema) } | ||
: {}), | ||
...(obj.$set ? { $set: partialSchema.parse(obj.$set) } : {}), | ||
...(obj.$setOnInsert | ||
? { $setOnInsert: partialSchema.parse(obj.$setOnInsert) } | ||
: {}), | ||
} | ||
}, | ||
preReplace: (obj) => schema.parse(obj), | ||
postFind: (obj) => obj, | ||
preFilter: (obj) => obj, | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import z, { | ||
AnyZodObject, | ||
ZodArray, | ||
ZodDiscriminatedUnion, | ||
ZodIntersection, | ||
ZodNullable, | ||
ZodObject, | ||
ZodOptional, | ||
ZodRawShape, | ||
ZodTuple, | ||
ZodTypeAny, | ||
ZodUnion, | ||
} from 'zod' | ||
|
||
export const parseFieldsAsArrays = <T extends Record<string, unknown>>( | ||
obj: T, | ||
schema: z.ZodTypeAny | ||
): T => { | ||
// convert all fields to arrays and parse them | ||
const parsed = schema.parse( | ||
// eslint-disable-next-line custom-rules/prefer-map-to-object-from-entries | ||
Object.fromEntries( | ||
Object.entries(obj).map(([key, value]) => { | ||
// Handle adding values to an array using { $each: [value], ...otherModifiers } | ||
if (typeof value === 'object' && value !== null && '$each' in value) | ||
return [key, value.$each] | ||
return [key, [value]] | ||
}) | ||
) | ||
) | ||
// returns only fields parsed in their initial flattened shape | ||
// eslint-disable-next-line custom-rules/prefer-map-to-object-from-entries | ||
return Object.fromEntries( | ||
Object.keys(parsed).map((key: keyof T) => [key, obj[key]]) | ||
) as T | ||
} | ||
|
||
// Copied from github.com/colinhacks/zod/blob/6dad90785398885f7b058f5c0760d5ae5476b833/src/types.ts#L2189-L2217 | ||
// and extended to support unions and discriminated unions | ||
export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => { | ||
if (schema instanceof ZodObject) { | ||
const newShape: any = {} | ||
|
||
for (const key in schema.shape) { | ||
const fieldSchema = schema.shape[key] | ||
newShape[key] = ZodOptional.create(zodDeepPartial(fieldSchema)) | ||
} | ||
return new ZodObject({ | ||
...schema._def, | ||
shape: () => newShape, | ||
}) as any | ||
} else if (schema instanceof ZodIntersection) { | ||
return ZodIntersection.create( | ||
zodDeepPartial(schema._def.left), | ||
zodDeepPartial(schema._def.right) | ||
) | ||
} else if (schema instanceof ZodUnion) { | ||
type Options = [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]] | ||
const options = schema._def.options as Options | ||
return ZodUnion.create( | ||
options.map((option) => zodDeepPartial(option)) as Options | ||
) as any | ||
} else if (schema instanceof ZodDiscriminatedUnion) { | ||
const types = Object.values(schema.options) as (AnyZodObject & | ||
ZodRawShape)[] | ||
const discriminator = schema.discriminator | ||
const newTypes = types.map((type) => { | ||
const newShape: any = {} | ||
|
||
for (const key in type.shape) { | ||
const fieldSchema = type.shape[key] | ||
if (key === discriminator) { | ||
newShape[key] = fieldSchema | ||
} else { | ||
newShape[key] = ZodOptional.create(zodDeepPartial(fieldSchema)) | ||
} | ||
} | ||
return new ZodObject({ | ||
...type._def, | ||
shape: () => newShape, | ||
} as any) | ||
}) as any | ||
return ZodDiscriminatedUnion.create(discriminator, newTypes) as any | ||
} else if (schema instanceof ZodArray) { | ||
return new ZodArray({ | ||
...schema._def, | ||
type: zodDeepPartial(schema.element), | ||
}) | ||
} else if (schema instanceof ZodOptional) { | ||
return ZodOptional.create(zodDeepPartial(schema.unwrap())) | ||
} else if (schema instanceof ZodNullable) { | ||
return ZodNullable.create(zodDeepPartial(schema.unwrap())) | ||
} else if (schema instanceof ZodTuple) { | ||
return ZodTuple.create( | ||
schema.items.map((item: any) => zodDeepPartial(item)) | ||
) | ||
} else { | ||
return schema | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import * as z from 'zod' | ||
import { zodDeepPartial, parseFieldsAsArrays } from '../src/zod-utils' | ||
|
||
describe('zodDeepPartial', () => { | ||
test('partial of primitives', () => { | ||
const deepPartial = zodDeepPartial( | ||
z.object({ | ||
s: z.string(), | ||
n: z.number(), | ||
i: z.bigint(), | ||
b: z.boolean(), | ||
z: z.null(), | ||
u: z.undefined(), | ||
}) | ||
) | ||
|
||
const correctValues = { | ||
s: 'string', | ||
n: 10, | ||
i: BigInt(2), | ||
b: true, | ||
z: null, | ||
u: undefined, | ||
} | ||
|
||
expect(deepPartial.parse({})).toEqual({}) | ||
expect(deepPartial.parse({ z: null })).toEqual({ z: null }) | ||
expect(deepPartial.parse({ u: undefined })).toEqual({ u: undefined }) | ||
expect(deepPartial.parse(correctValues)).toEqual(correctValues) | ||
expect(() => deepPartial.parse({ s: 10 })).toThrow() | ||
expect(() => deepPartial.parse({ n: 'a' })).toThrow() | ||
expect(() => deepPartial.parse({ i: 'a' })).toThrow() | ||
expect(() => deepPartial.parse({ b: 'a' })).toThrow() | ||
}) | ||
|
||
test('partial of unions', () => { | ||
const deepPartial = zodDeepPartial( | ||
z.union([ | ||
z.object({ s: z.string() }).strict(), | ||
z.object({ n: z.number() }).strict(), | ||
]) | ||
) | ||
|
||
expect(deepPartial.parse({ s: 'string' })).toEqual({ s: 'string' }) | ||
expect(deepPartial.parse({ n: 10 })).toEqual({ n: 10 }) | ||
expect(deepPartial.parse({})).toEqual({}) | ||
|
||
expect(() => deepPartial.parse({ n: 'string' })).toThrow() | ||
}) | ||
|
||
test('partial of discriminated union', () => { | ||
const deepPartial = zodDeepPartial( | ||
z.discriminatedUnion('t', [ | ||
z.object({ t: z.literal('s1'), s: z.string() }), | ||
z.object({ t: z.literal('n1'), n: z.number() }), | ||
]) | ||
) | ||
|
||
expect(deepPartial.parse({ t: 's1' })).toEqual({ t: 's1' }) | ||
expect(deepPartial.parse({ t: 's1', s: 's' })).toEqual({ t: 's1', s: 's' }) | ||
expect(deepPartial.parse({ t: 'n1', n: 10 })).toEqual({ t: 'n1', n: 10 }) | ||
|
||
expect(() => deepPartial.parse({ n: 10 })).toThrow() | ||
expect(() => deepPartial.parse({ s: 's' })).toThrow() | ||
expect(() => deepPartial.parse({})).toThrow() | ||
}) | ||
|
||
test('partial of intersection', () => { | ||
const deepPartial = zodDeepPartial( | ||
z.object({ t: z.literal('t1') }).and(z.object({ y: z.literal('y1') })) | ||
) | ||
|
||
expect(deepPartial.parse({})).toEqual({}) | ||
expect(deepPartial.parse({ t: 't1' })).toEqual({ t: 't1' }) | ||
expect(deepPartial.parse({ t: 't1', y: 'y1' })).toEqual({ | ||
t: 't1', | ||
y: 'y1', | ||
}) | ||
|
||
expect(() => deepPartial.parse({ t: 't2' })).toThrow() | ||
expect(() => deepPartial.parse({ y: 'y2' })).toThrow() | ||
}) | ||
}) | ||
|
||
describe('parseFieldsAsArrays', () => { | ||
const schema = z.object({ | ||
strings: z.string().array(), | ||
numbers: z.number().array(), | ||
}) | ||
|
||
test('should parse valid values successfully', () => { | ||
expect(parseFieldsAsArrays({ numbers: 6, strings: 's' }, schema)).toEqual({ | ||
numbers: 6, | ||
strings: 's', | ||
}) | ||
const withEachOperator = { | ||
numbers: 6, | ||
strings: { $each: ['s'], $otherOperator: 5 }, | ||
} | ||
expect(parseFieldsAsArrays(withEachOperator, schema)).toEqual( | ||
withEachOperator | ||
) | ||
}) | ||
|
||
test('should remove extra fields', () => { | ||
expect( | ||
parseFieldsAsArrays( | ||
{ numbers: 6, strings: 's', objects: {}, maps: new Map() }, | ||
schema | ||
) | ||
).toEqual({ | ||
numbers: 6, | ||
strings: 's', | ||
}) | ||
}) | ||
|
||
test('should throw for invalid values', () => { | ||
expect(() => | ||
parseFieldsAsArrays({ numbers: 's', strings: 7 }, schema) | ||
).toThrow() | ||
expect(() => | ||
parseFieldsAsArrays({ numbers: 's', strings: 7 }, schema) | ||
).toThrow() | ||
expect(() => | ||
parseFieldsAsArrays({ numbers: [7], strings: ['s'] }, schema) | ||
).toThrow() | ||
expect(() => | ||
parseFieldsAsArrays({ numbers: 7, strings: { $each: [7] } }, schema) | ||
).toThrow() | ||
}) | ||
}) |