Skip to content

Commit

Permalink
Feat: Zod Collection (#46)
Browse files Browse the repository at this point in the history
* Feat: Zod Collection

* Comment

* Tweaks
  • Loading branch information
cau777 authored Jun 13, 2024
1 parent 1c1359d commit 778cc43
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export * from './converter'
export * from './middleware'
export * from './time-collection'
export * from './types'
export * from './zod-collection'
49 changes: 49 additions & 0 deletions src/zod-collection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import { TsReadWriteCollection } from './collection'
import z from 'zod'
import { convertReadWriteCollection } from './converter'
import { OptionalUnlessRequiredId, WithId, Document } from 'mongodb'
import { WithTime } from './time-collection'
import { parseFieldsAsArrays, zodDeepPartial } from './zod-utils'

type WithIdTime<T> = WithId<WithTime<T>>

/**
* Create a collection that uses zod to validate inserted data and
* Mongo operators at runtime
* @param collection
* @param schema
*/
export const convertToZodCollection = <TSchema extends Document>(
collection: TsReadWriteCollection<TSchema, WithIdTime<TSchema>>,
schema: z.ZodType<TSchema>
): TsReadWriteCollection<TSchema, WithIdTime<TSchema>> =>
convertReadWriteCollection(collection, {
preInsert: (
obj: OptionalUnlessRequiredId<TSchema>
): OptionalUnlessRequiredId<TSchema> => {
return schema.parse(obj) as OptionalUnlessRequiredId<TSchema>
},
preUpdate: (obj) => {
const partialSchema = zodDeepPartial(schema)
return {
...obj,
// TODO: validate other operators
...(obj.$pull
? { $pull: parseFieldsAsArrays(obj.$pull, partialSchema) }
: {}),
...(obj.$push
? { $push: parseFieldsAsArrays(obj.$push, partialSchema) }
: {}),
...(obj.$addToSet
? { $addToSet: parseFieldsAsArrays(obj.$addToSet, partialSchema) }
: {}),
...(obj.$set ? { $set: partialSchema.parse(obj.$set) } : {}),
...(obj.$setOnInsert
? { $setOnInsert: partialSchema.parse(obj.$setOnInsert) }
: {}),
}
},
preReplace: (obj) => schema.parse(obj),
postFind: (obj) => obj,
preFilter: (obj) => obj,
})
100 changes: 100 additions & 0 deletions src/zod-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import z, {
AnyZodObject,
ZodArray,
ZodDiscriminatedUnion,
ZodIntersection,
ZodNullable,
ZodObject,
ZodOptional,
ZodRawShape,
ZodTuple,
ZodTypeAny,
ZodUnion,
} from 'zod'

export const parseFieldsAsArrays = <T extends Record<string, unknown>>(
obj: T,
schema: z.ZodTypeAny
): T => {
// convert all fields to arrays and parse them
const parsed = schema.parse(
// eslint-disable-next-line custom-rules/prefer-map-to-object-from-entries
Object.fromEntries(
Object.entries(obj).map(([key, value]) => {
// Handle adding values to an array using { $each: [value], ...otherModifiers }
if (typeof value === 'object' && value !== null && '$each' in value)
return [key, value.$each]
return [key, [value]]
})
)
)
// returns only fields parsed in their initial flattened shape
// eslint-disable-next-line custom-rules/prefer-map-to-object-from-entries
return Object.fromEntries(
Object.keys(parsed).map((key: keyof T) => [key, obj[key]])
) as T
}

// Copied from github.com/colinhacks/zod/blob/6dad90785398885f7b058f5c0760d5ae5476b833/src/types.ts#L2189-L2217
// and extended to support unions and discriminated unions
export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => {
if (schema instanceof ZodObject) {
const newShape: any = {}

for (const key in schema.shape) {
const fieldSchema = schema.shape[key]
newShape[key] = ZodOptional.create(zodDeepPartial(fieldSchema))
}
return new ZodObject({
...schema._def,
shape: () => newShape,
}) as any
} else if (schema instanceof ZodIntersection) {
return ZodIntersection.create(
zodDeepPartial(schema._def.left),
zodDeepPartial(schema._def.right)
)
} else if (schema instanceof ZodUnion) {
type Options = [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
const options = schema._def.options as Options
return ZodUnion.create(
options.map((option) => zodDeepPartial(option)) as Options
) as any
} else if (schema instanceof ZodDiscriminatedUnion) {
const types = Object.values(schema.options) as (AnyZodObject &
ZodRawShape)[]
const discriminator = schema.discriminator
const newTypes = types.map((type) => {
const newShape: any = {}

for (const key in type.shape) {
const fieldSchema = type.shape[key]
if (key === discriminator) {
newShape[key] = fieldSchema
} else {
newShape[key] = ZodOptional.create(zodDeepPartial(fieldSchema))
}
}
return new ZodObject({
...type._def,
shape: () => newShape,
} as any)
}) as any
return ZodDiscriminatedUnion.create(discriminator, newTypes) as any
} else if (schema instanceof ZodArray) {
return new ZodArray({
...schema._def,
type: zodDeepPartial(schema.element),
})
} else if (schema instanceof ZodOptional) {
return ZodOptional.create(zodDeepPartial(schema.unwrap()))
} else if (schema instanceof ZodNullable) {
return ZodNullable.create(zodDeepPartial(schema.unwrap()))
} else if (schema instanceof ZodTuple) {
return ZodTuple.create(
schema.items.map((item: any) => zodDeepPartial(item))
)
} else {
return schema
}
}
131 changes: 131 additions & 0 deletions test/zod-utils.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import * as z from 'zod'
import { zodDeepPartial, parseFieldsAsArrays } from '../src/zod-utils'

describe('zodDeepPartial', () => {
test('partial of primitives', () => {
const deepPartial = zodDeepPartial(
z.object({
s: z.string(),
n: z.number(),
i: z.bigint(),
b: z.boolean(),
z: z.null(),
u: z.undefined(),
})
)

const correctValues = {
s: 'string',
n: 10,
i: BigInt(2),
b: true,
z: null,
u: undefined,
}

expect(deepPartial.parse({})).toEqual({})
expect(deepPartial.parse({ z: null })).toEqual({ z: null })
expect(deepPartial.parse({ u: undefined })).toEqual({ u: undefined })
expect(deepPartial.parse(correctValues)).toEqual(correctValues)
expect(() => deepPartial.parse({ s: 10 })).toThrow()
expect(() => deepPartial.parse({ n: 'a' })).toThrow()
expect(() => deepPartial.parse({ i: 'a' })).toThrow()
expect(() => deepPartial.parse({ b: 'a' })).toThrow()
})

test('partial of unions', () => {
const deepPartial = zodDeepPartial(
z.union([
z.object({ s: z.string() }).strict(),
z.object({ n: z.number() }).strict(),
])
)

expect(deepPartial.parse({ s: 'string' })).toEqual({ s: 'string' })
expect(deepPartial.parse({ n: 10 })).toEqual({ n: 10 })
expect(deepPartial.parse({})).toEqual({})

expect(() => deepPartial.parse({ n: 'string' })).toThrow()
})

test('partial of discriminated union', () => {
const deepPartial = zodDeepPartial(
z.discriminatedUnion('t', [
z.object({ t: z.literal('s1'), s: z.string() }),
z.object({ t: z.literal('n1'), n: z.number() }),
])
)

expect(deepPartial.parse({ t: 's1' })).toEqual({ t: 's1' })
expect(deepPartial.parse({ t: 's1', s: 's' })).toEqual({ t: 's1', s: 's' })
expect(deepPartial.parse({ t: 'n1', n: 10 })).toEqual({ t: 'n1', n: 10 })

expect(() => deepPartial.parse({ n: 10 })).toThrow()
expect(() => deepPartial.parse({ s: 's' })).toThrow()
expect(() => deepPartial.parse({})).toThrow()
})

test('partial of intersection', () => {
const deepPartial = zodDeepPartial(
z.object({ t: z.literal('t1') }).and(z.object({ y: z.literal('y1') }))
)

expect(deepPartial.parse({})).toEqual({})
expect(deepPartial.parse({ t: 't1' })).toEqual({ t: 't1' })
expect(deepPartial.parse({ t: 't1', y: 'y1' })).toEqual({
t: 't1',
y: 'y1',
})

expect(() => deepPartial.parse({ t: 't2' })).toThrow()
expect(() => deepPartial.parse({ y: 'y2' })).toThrow()
})
})

describe('parseFieldsAsArrays', () => {
const schema = z.object({
strings: z.string().array(),
numbers: z.number().array(),
})

test('should parse valid values successfully', () => {
expect(parseFieldsAsArrays({ numbers: 6, strings: 's' }, schema)).toEqual({
numbers: 6,
strings: 's',
})
const withEachOperator = {
numbers: 6,
strings: { $each: ['s'], $otherOperator: 5 },
}
expect(parseFieldsAsArrays(withEachOperator, schema)).toEqual(
withEachOperator
)
})

test('should remove extra fields', () => {
expect(
parseFieldsAsArrays(
{ numbers: 6, strings: 's', objects: {}, maps: new Map() },
schema
)
).toEqual({
numbers: 6,
strings: 's',
})
})

test('should throw for invalid values', () => {
expect(() =>
parseFieldsAsArrays({ numbers: 's', strings: 7 }, schema)
).toThrow()
expect(() =>
parseFieldsAsArrays({ numbers: 's', strings: 7 }, schema)
).toThrow()
expect(() =>
parseFieldsAsArrays({ numbers: [7], strings: ['s'] }, schema)
).toThrow()
expect(() =>
parseFieldsAsArrays({ numbers: 7, strings: { $each: [7] } }, schema)
).toThrow()
})
})

0 comments on commit 778cc43

Please sign in to comment.