diff --git a/package.json b/package.json index 9153c90..59e7dd4 100644 --- a/package.json +++ b/package.json @@ -58,7 +58,8 @@ "got@<11.8.5": ">=11.8.5", "semver": ">=7.5.2", "word-wrap": ">=1.2.4", - "braces": ">=3.0.3" + "braces": ">=3.0.3", + "micromatch": ">=4.0.8" } }, "exports": { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 44b1cc1..1f9ef58 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -12,6 +12,7 @@ overrides: semver: '>=7.5.2' word-wrap: '>=1.2.4' braces: '>=3.0.3' + micromatch: '>=4.0.8' importers: @@ -2079,8 +2080,8 @@ packages: merge@2.1.1: resolution: {integrity: sha512-jz+Cfrg9GWOZbQAnDQ4hlVnQky+341Yk5ru8bZSe6sIDTCIg8n9i/u7hSQGSVOF3C7lH6mGtqjkiT9G4wFLL0w==} - micromatch@4.0.5: - resolution: {integrity: sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==} + micromatch@4.0.8: + resolution: {integrity: sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==} engines: {node: '>=8.6'} mimic-fn@1.2.0: @@ -3379,7 +3380,7 @@ snapshots: jest-util: 28.1.3 jest-validate: 28.1.3 jest-watcher: 28.1.3 - micromatch: 4.0.5 + micromatch: 4.0.8 pretty-format: 28.1.3 rimraf: 3.0.2 slash: 3.0.0 @@ -3498,7 +3499,7 @@ snapshots: jest-haste-map: 28.1.3 jest-regex-util: 28.0.2 jest-util: 28.1.3 - micromatch: 4.0.5 + micromatch: 4.0.8 pirates: 4.0.5 slash: 3.0.0 write-file-atomic: 4.0.2 @@ -4468,7 +4469,7 @@ snapshots: '@nodelib/fs.walk': 1.2.8 glob-parent: 5.1.2 merge2: 1.4.1 - micromatch: 4.0.5 + micromatch: 4.0.8 fast-glob@3.3.2: dependencies: @@ -4476,7 +4477,7 @@ snapshots: '@nodelib/fs.walk': 1.2.8 glob-parent: 5.1.2 merge2: 1.4.1 - micromatch: 4.0.5 + micromatch: 4.0.8 fast-json-stable-stringify@2.1.0: {} @@ -5016,7 +5017,7 @@ snapshots: jest-runner: 28.1.3 jest-util: 28.1.3 jest-validate: 28.1.3 - micromatch: 4.0.5 + micromatch: 4.0.8 parse-json: 5.2.0 pretty-format: 28.1.3 slash: 3.0.0 @@ -5077,7 +5078,7 @@ snapshots: jest-regex-util: 28.0.2 jest-util: 28.1.3 jest-worker: 28.1.3 - micromatch: 4.0.5 + micromatch: 4.0.8 walker: 1.0.8 optionalDependencies: fsevents: 2.3.3 @@ -5108,7 +5109,7 @@ snapshots: '@types/stack-utils': 2.0.1 chalk: 4.1.2 graceful-fs: 4.2.10 - micromatch: 4.0.5 + micromatch: 4.0.8 pretty-format: 28.1.3 slash: 3.0.0 stack-utils: 2.0.5 @@ -5120,7 +5121,7 @@ snapshots: '@types/stack-utils': 2.0.1 chalk: 4.1.2 graceful-fs: 4.2.10 - micromatch: 4.0.5 + micromatch: 4.0.8 pretty-format: 29.0.2 slash: 3.0.0 stack-utils: 2.0.5 @@ -5452,7 +5453,7 @@ snapshots: merge@2.1.1: {} - micromatch@4.0.5: + micromatch@4.0.8: dependencies: braces: 3.0.3 picomatch: 2.3.1 diff --git a/src/zod-utils.ts b/src/zod-utils.ts index abd35a2..b6e8ec5 100644 --- a/src/zod-utils.ts +++ b/src/zod-utils.ts @@ -35,6 +35,57 @@ export const parseFieldsAsArrays = >( ) as T } +/** + * Merge all objects into a single one. If a property appears on multiple members, it + * will create a ZodUnion with all possible values for that property. + * + * This is important because, when dealing with a union of multiple objects where all + * fields are optional, zod won't be able to differentiate between them in a union, and + * will end up always choosing the first union item, even if it means stripping some fields. + * @param schemas + */ +export const mergeObjectSchemas = (schemas: z.ZodTypeAny[]): z.ZodTypeAny => { + const objects = schemas.filter( + (schema): schema is z.ZodObject => schema instanceof z.ZodObject + ) + // It could be a union of a ZodObject and a ZodNumber for example + const otherTypes = schemas.filter( + (schema) => !(schema instanceof z.ZodObject) + ) + + // Get all possible ZodTypes for each property + const mergedTypes = objects + .flatMap((obj) => Object.entries(obj.shape) as [string, ZodTypeAny][]) + .reduce((acc, [key, value]): Record => { + if (key in acc) return { ...acc, [key]: [...acc[key], value] } + return { ...acc, [key]: [value] } + }, {} as Record) + + const mergedShape = Object.fromEntries( + Object.entries(mergedTypes).map(([key, values]) => { + // This is not 100% because it will only dedupe by reference + const uniqueValues = Array.from(new Set(values)) + const valueOrUnion = + uniqueValues.length === 1 + ? uniqueValues[0] + : ZodUnion.create( + uniqueValues as [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]] + ) + // If the value doesn't appear in all objects, make sure it's optional + const optionalValue = + values.length < objects.length + ? ZodOptional.create(valueOrUnion) + : valueOrUnion + return [key, optionalValue] + }) + ) + + const allOptions = [...otherTypes, ZodObject.create(mergedShape)] + return allOptions.length === 1 + ? allOptions[0] + : ZodUnion.create(allOptions as [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]]) +} + // Copied from github.com/colinhacks/zod/blob/6dad90785398885f7b058f5c0760d5ae5476b833/src/types.ts#L2189-L2217 // and extended to support unions and discriminated unions export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => { @@ -57,9 +108,8 @@ export const zodDeepPartial = (schema: ZodTypeAny): ZodTypeAny => { } else if (schema instanceof ZodUnion) { type Options = [ZodTypeAny, ZodTypeAny, ...ZodTypeAny[]] const options = schema._def.options as Options - return ZodUnion.create( - options.map((option) => zodDeepPartial(option)) as Options - ) as any + const partialOptions = options.map(zodDeepPartial) + return mergeObjectSchemas(partialOptions) } else if (schema instanceof ZodDiscriminatedUnion) { const types = Object.values(schema.options) as (AnyZodObject & ZodRawShape)[] diff --git a/test/zod-utils.test.ts b/test/zod-utils.test.ts index eacc3ad..55d5c40 100644 --- a/test/zod-utils.test.ts +++ b/test/zod-utils.test.ts @@ -1,5 +1,52 @@ import * as z from 'zod' -import { zodDeepPartial, parseFieldsAsArrays } from '../src/zod-utils' +import { + zodDeepPartial, + parseFieldsAsArrays, + mergeObjectSchemas, +} from '../src/zod-utils' + +describe('mergeObjectSchemas', () => { + test('should keep the properties of aa single object', () => { + const schema = mergeObjectSchemas([ + z.object({ n: z.number(), s: z.string().optional() }), + ]) + expect(schema.parse({ n: 1 })).toEqual({ n: 1 }) + expect(schema.parse({ n: 1, s: 'a' })).toEqual({ n: 1, s: 'a' }) + expect(schema.parse({ n: 1, unrelated: 'a' })).toEqual({ n: 1 }) + expect(() => schema.parse({ s: 'a' })).toThrow() + }) + test('should merge the properties of multiple objects', () => { + const schema = mergeObjectSchemas([ + z.object({ type: z.literal('a'), n: z.number() }), + z.object({ type: z.literal('b'), s: z.string() }), + z.object({ n: z.boolean().optional() }), + ]) + + expect(schema.parse({})).toEqual({}) + expect(schema.parse({ n: 1 })).toEqual({ n: 1 }) + expect(schema.parse({ n: true })).toEqual({ n: true }) + + expect(schema.parse({ type: 'a' })).toEqual({ type: 'a' }) + expect(schema.parse({ type: 'b' })).toEqual({ type: 'b' }) + expect(() => schema.parse({ type: 'c' })).toThrow() + + expect(schema.parse({ s: 'b' })).toEqual({ s: 'b' }) + }) + test('should be compatible with non-object types', () => { + const schema = mergeObjectSchemas([ + z.object({ type: z.literal('a'), n: z.number() }), + z.object({ type: z.literal('b'), s: z.string() }), + z.number(), + z.null(), + ]) + + expect(schema.parse(1)).toEqual(1) + expect(schema.parse(null)).toEqual(null) + expect(() => schema.parse({ type: 'c' })).toThrow() + expect(() => schema.parse({})).toThrow() + expect(() => schema.parse('str')).toThrow() + }) +}) describe('zodDeepPartial', () => { test('partial of primitives', () => { @@ -34,18 +81,59 @@ describe('zodDeepPartial', () => { }) test('partial of unions', () => { - const deepPartial = zodDeepPartial( - z.union([ - z.object({ s: z.string() }).strict(), - z.object({ n: z.number() }).strict(), - ]) - ) + const ZTestUnion = z.union([ + z.object({ s: z.string() }).strict(), + z.object({ n: z.number() }).strict(), + ]) + const deepPartial = zodDeepPartial(ZTestUnion) expect(deepPartial.parse({ s: 'string' })).toEqual({ s: 'string' }) expect(deepPartial.parse({ n: 10 })).toEqual({ n: 10 }) expect(deepPartial.parse({})).toEqual({}) expect(() => deepPartial.parse({ n: 'string' })).toThrow() + + const ZNestedTestUnion = z.object({ + field: z.number(), + union: ZTestUnion, + }) + const nestedDeepPartial = zodDeepPartial(ZNestedTestUnion) + expect( + nestedDeepPartial.parse({ field: 1, union: { s: 'string' } }) + ).toEqual({ field: 1, union: { s: 'string' } }) + expect(nestedDeepPartial.parse({ field: 1, union: { n: 10 } })).toEqual({ + field: 1, + union: { n: 10 }, + }) + expect(nestedDeepPartial.parse({ field: 1, union: {} })).toEqual({ + field: 1, + union: {}, + }) + + const deepPartial2 = zodDeepPartial( + z.union([ + z.object({ + t: z.literal('a'), + s: z.string(), + }), + z.object({ t: z.literal('b'), n: z.number() }), + ]) + ) + + expect(deepPartial2.parse({ s: 'a' })).toEqual({ s: 'a' }) + expect(deepPartial2.parse({ n: 2 })).toEqual({ n: 2 }) + + const ZUnionOfUnion = z.union([ + z.object({ randomField: z.string().optional() }), + ZTestUnion, + ]) + const deepPartial3 = zodDeepPartial(ZUnionOfUnion) + expect(deepPartial3.parse({ s: 'a' })).toEqual({ s: 'a' }) + expect(deepPartial3.parse({ n: 2 })).toEqual({ n: 2 }) + expect(deepPartial3.parse({ randomField: 'r', n: 3 })).toEqual({ + randomField: 'r', + n: 3, + }) }) test('partial of discriminated union', () => {