From d981fc14fbb45748ecf032d735cb8e4f55b1aa15 Mon Sep 17 00:00:00 2001 From: Mattt Date: Wed, 29 May 2024 11:04:09 -0700 Subject: [PATCH] Fix regression in how array input values are transformed (#266) * Fix regression in how array input values are transformed * Refactor tests --- index.test.ts | 106 +++++++++++++++++++++++++++++++++++--------------- lib/util.js | 5 ++- 2 files changed, 77 insertions(+), 34 deletions(-) diff --git a/index.test.ts b/index.test.ts index 7502969..834b786 100644 --- a/index.test.ts +++ b/index.test.ts @@ -185,42 +185,84 @@ describe("Replicate client", () => { }); describe("predictions.create", () => { - test("Calls the correct API route with the correct payload", async () => { - nock(BASE_URL) - .post("/predictions") - .reply(200, { - id: "ufawqhfynnddngldkgtslldrkq", - model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - urls: { - get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", - }, - created_at: "2022-04-26T22:13:06.224088Z", - started_at: null, - completed_at: null, - status: "starting", - input: { - text: "Alice", + const predictionTestCases = [ + { + description: "String input", + input: { + text: "Alice", + }, + }, + { + description: "Number input", + input: { + text: 123, + }, + }, + { + description: "Boolean input", + input: { + text: true, + }, + }, + { + description: "Array input", + input: { + text: ["Alice", "Bob", "Charlie"], + }, + }, + { + description: "Object input", + input: { + text: { + name: "Alice", }, - output: null, - error: null, - logs: null, - metrics: {}, - }); - const prediction = await client.predictions.create({ + }, + }, + ].map((testCase) => ({ + ...testCase, + expectedResponse: { + id: "ufawqhfynnddngldkgtslldrkq", + model: "replicate/hello-world", version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - text: "Alice", + urls: { + get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + cancel: + "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - }); - expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq"); - }); + input: testCase.input, + created_at: "2022-04-26T22:13:06.224088Z", + started_at: null, + completed_at: null, + status: "starting", + }, + })); + + test.each(predictionTestCases)( + "$description", + async ({ input, expectedResponse }) => { + nock(BASE_URL) + .post("/predictions", { + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }) + .reply(200, expectedResponse); + + const response = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); + + expect(response.input).toEqual(input); + expect(response.status).toBe(expectedResponse.status); + } + ); const fileTestCases = [ // Skip test case if File type is not available diff --git a/lib/util.js b/lib/util.js index e164899..3745d9f 100644 --- a/lib/util.js +++ b/lib/util.js @@ -310,9 +310,10 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) { // Walk a JavaScript object and transform the leaf values. async function transform(value, mapper) { if (Array.isArray(value)) { - let copy = []; + const copy = []; for (const val of value) { - copy = await transform(val, mapper); + const transformed = await transform(val, mapper); + copy.push(transformed); } return copy; }