Skip to content

Commit

Permalink
Fix: Deep partial stripping out fields in union (#53)
Browse files Browse the repository at this point in the history
* Improve test

* Implement mergeObjectSchemas

* Comment

* Fix vulnerability

* Add test for nested object types

* Add test for nested union
  • Loading branch information
cau777 authored Sep 23, 2024
1 parent f7b68ae commit b26a9d4
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 22 deletions.
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
"got@<11.8.5": ">=11.8.5",
"semver": ">=7.5.2",
"word-wrap": ">=1.2.4",
"braces": ">=3.0.3"
"braces": ">=3.0.3",
"micromatch": ">=4.0.8"
}
},
"exports": {
Expand Down
23 changes: 12 additions & 11 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 53 additions & 3 deletions src/zod-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,57 @@ export const parseFieldsAsArrays = <T extends Record<string, unknown>>(
) as T
}

/**
* Merge all objects into a single one. If a property appears on multiple members, it
* will create a ZodUnion with all possible values for that property.
*
* This is important because, when dealing with a union of multiple objects where all
* fields are optional, zod won't be able to differentiate between them in a union, and
* will end up always choosing the first union item, even if it means stripping some fields.
* @param schemas
*/
export const mergeObjectSchemas = (schemas: z.ZodTypeAny[]): z.ZodTypeAny => {
const objects = schemas.filter(
(schema): schema is z.ZodObject<any> => schema instanceof z.ZodObject
)
// It could be a union of a ZodObject and a ZodNumber for example
const otherTypes = schemas.filter(
(schema) => !(schema instanceof z.ZodObject)
)

// Get all possible ZodTypes for each property
const mergedTypes = objects
.flatMap((obj) => Object.entries(obj.shape) as [string, ZodTypeAny][])
.reduce((acc, [key, value]): Record<string, ZodTypeAny[]> => {
if (key in acc) return { ...acc, [key]: [...acc[key], value] }
return { ...acc, [key]: [value] }
}, {} as Record<string, ZodTypeAny[]>)

const mergedShape = Object.fromEntries(
Object.entries(mergedTypes).map(([key, values]) => {
// This is not 100% because it will only dedupe by reference
const uniqueValues = Array.from(new Set(values))
const valueOrUnion =
uniqueValues.length === 1
? uniqueValues[0]
: ZodUnion.create(
uniqueValues as [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
)
// If the value doesn't appear in all objects, make sure it's optional
const optionalValue =
values.length < objects.length
? ZodOptional.create(valueOrUnion)
: valueOrUnion
return [key, optionalValue]
})
)

const allOptions = [...otherTypes, ZodObject.create(mergedShape)]
return allOptions.length === 1
? allOptions[0]
: ZodUnion.create(allOptions as [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]])
}

// Copied from github.com/colinhacks/zod/blob/6dad90785398885f7b058f5c0760d5ae5476b833/src/types.ts#L2189-L2217
// and extended to support unions and discriminated unions
export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => {
Expand All @@ -57,9 +108,8 @@ export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => {
} else if (schema instanceof ZodUnion) {
type Options = [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
const options = schema._def.options as Options
return ZodUnion.create(
options.map((option) => zodDeepPartial(option)) as Options
) as any
const partialOptions = options.map(zodDeepPartial)
return mergeObjectSchemas(partialOptions)
} else if (schema instanceof ZodDiscriminatedUnion) {
const types = Object.values(schema.options) as (AnyZodObject &
ZodRawShape)[]
Expand Down
102 changes: 95 additions & 7 deletions test/zod-utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,52 @@
import * as z from 'zod'
import { zodDeepPartial, parseFieldsAsArrays } from '../src/zod-utils'
import {
zodDeepPartial,
parseFieldsAsArrays,
mergeObjectSchemas,
} from '../src/zod-utils'

describe('mergeObjectSchemas', () => {
test('should keep the properties of aa single object', () => {
const schema = mergeObjectSchemas([
z.object({ n: z.number(), s: z.string().optional() }),
])
expect(schema.parse({ n: 1 })).toEqual({ n: 1 })
expect(schema.parse({ n: 1, s: 'a' })).toEqual({ n: 1, s: 'a' })
expect(schema.parse({ n: 1, unrelated: 'a' })).toEqual({ n: 1 })
expect(() => schema.parse({ s: 'a' })).toThrow()
})
test('should merge the properties of multiple objects', () => {
const schema = mergeObjectSchemas([
z.object({ type: z.literal('a'), n: z.number() }),
z.object({ type: z.literal('b'), s: z.string() }),
z.object({ n: z.boolean().optional() }),
])

expect(schema.parse({})).toEqual({})
expect(schema.parse({ n: 1 })).toEqual({ n: 1 })
expect(schema.parse({ n: true })).toEqual({ n: true })

expect(schema.parse({ type: 'a' })).toEqual({ type: 'a' })
expect(schema.parse({ type: 'b' })).toEqual({ type: 'b' })
expect(() => schema.parse({ type: 'c' })).toThrow()

expect(schema.parse({ s: 'b' })).toEqual({ s: 'b' })
})
test('should be compatible with non-object types', () => {
const schema = mergeObjectSchemas([
z.object({ type: z.literal('a'), n: z.number() }),
z.object({ type: z.literal('b'), s: z.string() }),
z.number(),
z.null(),
])

expect(schema.parse(1)).toEqual(1)
expect(schema.parse(null)).toEqual(null)
expect(() => schema.parse({ type: 'c' })).toThrow()
expect(() => schema.parse({})).toThrow()
expect(() => schema.parse('str')).toThrow()
})
})

describe('zodDeepPartial', () => {
test('partial of primitives', () => {
Expand Down Expand Up @@ -34,18 +81,59 @@ describe('zodDeepPartial', () => {
})

test('partial of unions', () => {
const deepPartial = zodDeepPartial(
z.union([
z.object({ s: z.string() }).strict(),
z.object({ n: z.number() }).strict(),
])
)
const ZTestUnion = z.union([
z.object({ s: z.string() }).strict(),
z.object({ n: z.number() }).strict(),
])
const deepPartial = zodDeepPartial(ZTestUnion)

expect(deepPartial.parse({ s: 'string' })).toEqual({ s: 'string' })
expect(deepPartial.parse({ n: 10 })).toEqual({ n: 10 })
expect(deepPartial.parse({})).toEqual({})

expect(() => deepPartial.parse({ n: 'string' })).toThrow()

const ZNestedTestUnion = z.object({
field: z.number(),
union: ZTestUnion,
})
const nestedDeepPartial = zodDeepPartial(ZNestedTestUnion)
expect(
nestedDeepPartial.parse({ field: 1, union: { s: 'string' } })
).toEqual({ field: 1, union: { s: 'string' } })
expect(nestedDeepPartial.parse({ field: 1, union: { n: 10 } })).toEqual({
field: 1,
union: { n: 10 },
})
expect(nestedDeepPartial.parse({ field: 1, union: {} })).toEqual({
field: 1,
union: {},
})

const deepPartial2 = zodDeepPartial(
z.union([
z.object({
t: z.literal('a'),
s: z.string(),
}),
z.object({ t: z.literal('b'), n: z.number() }),
])
)

expect(deepPartial2.parse({ s: 'a' })).toEqual({ s: 'a' })
expect(deepPartial2.parse({ n: 2 })).toEqual({ n: 2 })

const ZUnionOfUnion = z.union([
z.object({ randomField: z.string().optional() }),
ZTestUnion,
])
const deepPartial3 = zodDeepPartial(ZUnionOfUnion)
expect(deepPartial3.parse({ s: 'a' })).toEqual({ s: 'a' })
expect(deepPartial3.parse({ n: 2 })).toEqual({ n: 2 })
expect(deepPartial3.parse({ randomField: 'r', n: 3 })).toEqual({
randomField: 'r',
n: 3,
})
})

test('partial of discriminated union', () => {
Expand Down

0 comments on commit b26a9d4

Please sign in to comment.