Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Deep partial stripping out fields in union #53

Merged
merged 6 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
"got@<11.8.5": ">=11.8.5",
"semver": ">=7.5.2",
"word-wrap": ">=1.2.4",
"braces": ">=3.0.3"
"braces": ">=3.0.3",
"micromatch": ">=4.0.8"
}
},
"exports": {
Expand Down
23 changes: 12 additions & 11 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 53 additions & 3 deletions src/zod-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,57 @@ export const parseFieldsAsArrays = <T extends Record<string, unknown>>(
) as T
}

/**
* Merge all objects into a single one. If a property appears on multiple members, it
* will create a ZodUnion with all possible values for that property.
*
* This is important because, when dealing with a union of multiple objects where all
* fields are optional, zod won't be able to differentiate between them in a union, and
* will end up always choosing the first union item, even if it means stripping some fields.
* @param schemas
*/
export const mergeObjectSchemas = (schemas: z.ZodTypeAny[]): z.ZodTypeAny => {
const objects = schemas.filter(
(schema): schema is z.ZodObject<any> => schema instanceof z.ZodObject
)
// It could be a union of a ZodObject and a ZodNumber for example
const otherTypes = schemas.filter(
(schema) => !(schema instanceof z.ZodObject)
)

// Get all possible ZodTypes for each property
const mergedTypes = objects
.flatMap((obj) => Object.entries(obj.shape) as [string, ZodTypeAny][])
.reduce((acc, [key, value]): Record<string, ZodTypeAny[]> => {
if (key in acc) return { ...acc, [key]: [...acc[key], value] }
return { ...acc, [key]: [value] }
}, {} as Record<string, ZodTypeAny[]>)

const mergedShape = Object.fromEntries(
Object.entries(mergedTypes).map(([key, values]) => {
// This is not 100% because it will only dedupe by reference
const uniqueValues = Array.from(new Set(values))
const valueOrUnion =
uniqueValues.length === 1
? uniqueValues[0]
: ZodUnion.create(
uniqueValues as [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
)
// If the value doesn't appear in all objects, make sure it's optional
const optionalValue =
values.length < objects.length
? ZodOptional.create(valueOrUnion)
: valueOrUnion
return [key, optionalValue]
})
)

const allOptions = [...otherTypes, ZodObject.create(mergedShape)]
return allOptions.length === 1
? allOptions[0]
: ZodUnion.create(allOptions as [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]])
}

// Copied from github.com/colinhacks/zod/blob/6dad90785398885f7b058f5c0760d5ae5476b833/src/types.ts#L2189-L2217
// and extended to support unions and discriminated unions
export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => {
Expand All @@ -57,9 +108,8 @@ export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => {
} else if (schema instanceof ZodUnion) {
type Options = [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]
const options = schema._def.options as Options
return ZodUnion.create(
options.map((option) => zodDeepPartial(option)) as Options
) as any
const partialOptions = options.map(zodDeepPartial)
return mergeObjectSchemas(partialOptions)
} else if (schema instanceof ZodDiscriminatedUnion) {
const types = Object.values(schema.options) as (AnyZodObject &
ZodRawShape)[]
Expand Down
102 changes: 95 additions & 7 deletions test/zod-utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,52 @@
import * as z from 'zod'
import { zodDeepPartial, parseFieldsAsArrays } from '../src/zod-utils'
import {
zodDeepPartial,
parseFieldsAsArrays,
mergeObjectSchemas,
} from '../src/zod-utils'

describe('mergeObjectSchemas', () => {
test('should keep the properties of aa single object', () => {
const schema = mergeObjectSchemas([
z.object({ n: z.number(), s: z.string().optional() }),
])
expect(schema.parse({ n: 1 })).toEqual({ n: 1 })
expect(schema.parse({ n: 1, s: 'a' })).toEqual({ n: 1, s: 'a' })
expect(schema.parse({ n: 1, unrelated: 'a' })).toEqual({ n: 1 })
expect(() => schema.parse({ s: 'a' })).toThrow()
})
test('should merge the properties of multiple objects', () => {
const schema = mergeObjectSchemas([
z.object({ type: z.literal('a'), n: z.number() }),
z.object({ type: z.literal('b'), s: z.string() }),
z.object({ n: z.boolean().optional() }),
])

expect(schema.parse({})).toEqual({})
expect(schema.parse({ n: 1 })).toEqual({ n: 1 })
expect(schema.parse({ n: true })).toEqual({ n: true })

expect(schema.parse({ type: 'a' })).toEqual({ type: 'a' })
expect(schema.parse({ type: 'b' })).toEqual({ type: 'b' })
expect(() => schema.parse({ type: 'c' })).toThrow()

expect(schema.parse({ s: 'b' })).toEqual({ s: 'b' })
})
test('should be compatible with non-object types', () => {
const schema = mergeObjectSchemas([
z.object({ type: z.literal('a'), n: z.number() }),
z.object({ type: z.literal('b'), s: z.string() }),
z.number(),
z.null(),
])

expect(schema.parse(1)).toEqual(1)
expect(schema.parse(null)).toEqual(null)
expect(() => schema.parse({ type: 'c' })).toThrow()
expect(() => schema.parse({})).toThrow()
expect(() => schema.parse('str')).toThrow()
})
})

describe('zodDeepPartial', () => {
test('partial of primitives', () => {
Expand Down Expand Up @@ -34,18 +81,59 @@ describe('zodDeepPartial', () => {
})

test('partial of unions', () => {
const deepPartial = zodDeepPartial(
z.union([
z.object({ s: z.string() }).strict(),
z.object({ n: z.number() }).strict(),
])
)
const ZTestUnion = z.union([
z.object({ s: z.string() }).strict(),
z.object({ n: z.number() }).strict(),
])
const deepPartial = zodDeepPartial(ZTestUnion)

expect(deepPartial.parse({ s: 'string' })).toEqual({ s: 'string' })
expect(deepPartial.parse({ n: 10 })).toEqual({ n: 10 })
expect(deepPartial.parse({})).toEqual({})

expect(() => deepPartial.parse({ n: 'string' })).toThrow()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Can we add a test for nested object types? Here we are only fixing the issue for unions passed directly, but if the problematic union is on an object property, I think we would still have the same issue:
/** PR fixes parsing this type */
const ZTestUnion = z.union([
  z.object({ a: z.number().optional() }),
  z.object({ b: z.number().optional() })
[)

const ZNestedTestUnion = z.object({
  field: z.number()
  union: ZTestUnion
})
  1. If a test using the above object does not work, we should probably think of applying mergeObjectSchemas recursively. I'm mostly worried about affecting Typescript performance with that (I don't think it will cause performance issues on runtime though).
  2. Given the above is true, we should apply this on zodDeepPartial even if the result is an object

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this test and it passes. This works because zodDeepPartial is recursive: https://github.com/cau777/ts-mongo/blob/a6011a6b376c276db4c0d0b32bb81871da79e1e4/src/zod-utils.ts#L97-L97

Copy link
Collaborator

@nicolassanmar nicolassanmar Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Makes sense

const ZNestedTestUnion = z.object({
field: z.number(),
union: ZTestUnion,
})
const nestedDeepPartial = zodDeepPartial(ZNestedTestUnion)
expect(
nestedDeepPartial.parse({ field: 1, union: { s: 'string' } })
).toEqual({ field: 1, union: { s: 'string' } })
expect(nestedDeepPartial.parse({ field: 1, union: { n: 10 } })).toEqual({
field: 1,
union: { n: 10 },
})
expect(nestedDeepPartial.parse({ field: 1, union: {} })).toEqual({
field: 1,
union: {},
})

const deepPartial2 = zodDeepPartial(
z.union([
z.object({
t: z.literal('a'),
s: z.string(),
}),
z.object({ t: z.literal('b'), n: z.number() }),
])
)

expect(deepPartial2.parse({ s: 'a' })).toEqual({ s: 'a' })
expect(deepPartial2.parse({ n: 2 })).toEqual({ n: 2 })
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add a test for a union that includes this union?:

const ZTestUnion = 
      z.union([
        z.object({
          t: z.literal('a'),
          s: z.string(),
        }),
        z.object({ t: z.literal('b'), n: z.number() }),
      ])
    
const ZUnionOfUnion = z.union([z.object({ randomField: z.string().optional()}), ZTestUnion])
const deepPartial3 = zodDeepPartial(ZUnionOfUnion)
expect(deepPartial3.parse({s: 'hi'})).toEqual({s: 'hi'})


const ZUnionOfUnion = z.union([
z.object({ randomField: z.string().optional() }),
ZTestUnion,
])
const deepPartial3 = zodDeepPartial(ZUnionOfUnion)
expect(deepPartial3.parse({ s: 'a' })).toEqual({ s: 'a' })
expect(deepPartial3.parse({ n: 2 })).toEqual({ n: 2 })
expect(deepPartial3.parse({ randomField: 'r', n: 3 })).toEqual({
randomField: 'r',
n: 3,
})
})

test('partial of discriminated union', () => {
Expand Down