diff --git a/packages/server/src/context.ts b/packages/server/src/context.ts index 69fa71eb..7bdb83a7 100644 --- a/packages/server/src/context.ts +++ b/packages/server/src/context.ts @@ -2,8 +2,6 @@ import type { IsNever } from '@orpc/shared' export type Context = Record -export type TypeInitialContext = (type: T) => unknown - export type MergedContext = T & U export function mergeContext( diff --git a/packages/server/src/middleware-decorated.test.ts b/packages/server/src/middleware-decorated.test.ts index 00c5f483..0b3b43e6 100644 --- a/packages/server/src/middleware-decorated.test.ts +++ b/packages/server/src/middleware-decorated.test.ts @@ -36,21 +36,25 @@ describe('decorateMiddleware', () => { const decorated = decorateMiddleware((options, input, output) => { fn(options, input, output) - return options.next({ context: { auth: true } }) + return options.next({ context: { auth: 1, mid1: true } }) }).concat((options, input, output) => { fn2(options, input, output) - return options.next({}) + return options.next({ context: { auth: 2, mid2: true } }) }) as any next.mockReturnValueOnce('__mocked__') const outputFn = vi.fn() - expect((await decorated({ next }, 'input', outputFn))).toBe('__mocked__') + const signal = AbortSignal.timeout(100) + expect((await decorated({ next, context: { origin: true }, signal }, 'input', outputFn))).toBe('__mocked__') expect(fn).toHaveBeenCalledTimes(1) - expect(fn).toHaveBeenCalledWith({ next: expect.any(Function) }, 'input', outputFn) + expect(fn).toHaveBeenCalledWith({ next: expect.any(Function), context: { origin: true }, signal }, 'input', outputFn) expect(fn2).toHaveBeenCalledTimes(1) - expect(fn2).toHaveBeenCalledWith({ next, context: { auth: true } }, 'input', outputFn) + expect(fn2).toHaveBeenCalledWith({ next: expect.any(Function), context: { origin: true, auth: 1, mid1: true }, signal }, 'input', outputFn) + + expect(next).toHaveBeenCalledTimes(1) + expect(next).toHaveBeenCalledWith({ context: { auth: 2, mid2: true, mid1: true } }) }) it('can concat with map input', async () => { @@ -80,6 +84,9 @@ describe('decorateMiddleware', () => { expect(map).toHaveBeenCalledWith('input') expect(fn2).toHaveBeenCalledTimes(1) - expect(fn2).toHaveBeenCalledWith({ context: { auth: true }, next }, { name: 'input' }, outputFn) + expect(fn2).toHaveBeenCalledWith({ context: { auth: true }, next: expect.any(Function) }, { name: 'input' }, outputFn) + + expect(next).toHaveBeenCalledTimes(1) + expect(next).toHaveBeenCalledWith({ context: { auth: true } }) }) }) diff --git a/packages/server/src/middleware-decorated.ts b/packages/server/src/middleware-decorated.ts index 51b63866..13d67642 100644 --- a/packages/server/src/middleware-decorated.ts +++ b/packages/server/src/middleware-decorated.ts @@ -1,6 +1,6 @@ import type { Meta, ORPCErrorConstructorMap } from '@orpc/contract' import type { Context, MergedContext } from './context' -import type { AnyMiddleware, MapInputMiddleware, Middleware, MiddlewareNextFn } from './middleware' +import type { AnyMiddleware, MapInputMiddleware, Middleware } from './middleware' export interface DecoratedMiddleware< TInContext extends Context, @@ -82,11 +82,14 @@ export function decorateMiddleware< : concatMiddleware const concatted = decorateMiddleware((options, input, output, ...rest) => { - const next: MiddlewareNextFn = async (...[nextOptions]) => { - return mapped({ ...options, context: { ...nextOptions?.context, ...options.context } }, input, output, ...rest) - } - - const merged = middleware({ ...options, next } as any, input as any, output as any, ...rest) + const merged = middleware({ + ...options, + next: (...[nextOptions1]: [any]) => mapped({ + ...options, + context: { ...options.context, ...nextOptions1?.context }, + next: (...[nextOptions2]) => options.next({ context: { ...nextOptions1?.context, ...nextOptions2?.context } }) as any, + }, input, output, ...rest), + } as any, input as any, output as any, ...rest) return merged }) diff --git a/packages/server/src/procedure.ts b/packages/server/src/procedure.ts index 626d592a..a1284bd5 100644 --- a/packages/server/src/procedure.ts +++ b/packages/server/src/procedure.ts @@ -1,6 +1,6 @@ import type { ContractProcedureDef, ErrorMap, Meta, ORPCErrorConstructorMap, Schema, SchemaInput, SchemaOutput } from '@orpc/contract' import type { Promisable } from '@orpc/shared' -import type { Context, TypeInitialContext } from './context' +import type { Context } from './context' import type { AnyMiddleware } from './middleware' import { isContractProcedure } from '@orpc/contract' @@ -41,7 +41,7 @@ export interface ProcedureDef< TErrorMap extends ErrorMap, TMeta extends Meta, > extends ContractProcedureDef { - __initialContext?: TypeInitialContext + __initialContext?: (type: TInitialContext) => unknown middlewares: AnyMiddleware[] inputValidationIndex: number outputValidationIndex: number