Skip to content

Commit

Permalink
Fix regression in how array input values are transformed (#266)
Browse files Browse the repository at this point in the history
* Fix regression in how array input values are transformed

* Refactor tests
  • Loading branch information
mattt authored May 29, 2024
1 parent e059286 commit d981fc1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 34 deletions.
106 changes: 74 additions & 32 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -185,42 +185,84 @@ describe("Replicate client", () => {
});

describe("predictions.create", () => {
test("Calls the correct API route with the correct payload", async () => {
nock(BASE_URL)
.post("/predictions")
.reply(200, {
id: "ufawqhfynnddngldkgtslldrkq",
model: "replicate/hello-world",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
urls: {
get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
cancel:
"https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
},
created_at: "2022-04-26T22:13:06.224088Z",
started_at: null,
completed_at: null,
status: "starting",
input: {
text: "Alice",
const predictionTestCases = [
{
description: "String input",
input: {
text: "Alice",
},
},
{
description: "Number input",
input: {
text: 123,
},
},
{
description: "Boolean input",
input: {
text: true,
},
},
{
description: "Array input",
input: {
text: ["Alice", "Bob", "Charlie"],
},
},
{
description: "Object input",
input: {
text: {
name: "Alice",
},
output: null,
error: null,
logs: null,
metrics: {},
});
const prediction = await client.predictions.create({
},
},
].map((testCase) => ({
...testCase,
expectedResponse: {
id: "ufawqhfynnddngldkgtslldrkq",
model: "replicate/hello-world",
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
text: "Alice",
urls: {
get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
cancel:
"https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
},
webhook: "http://test.host/webhook",
webhook_events_filter: ["output", "completed"],
});
expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
});
input: testCase.input,
created_at: "2022-04-26T22:13:06.224088Z",
started_at: null,
completed_at: null,
status: "starting",
},
}));

test.each(predictionTestCases)(
"$description",
async ({ input, expectedResponse }) => {
nock(BASE_URL)
.post("/predictions", {
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: input as Record<string, any>,
webhook: "http://test.host/webhook",
webhook_events_filter: ["output", "completed"],
})
.reply(200, expectedResponse);

const response = await client.predictions.create({
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: input as Record<string, any>,
webhook: "http://test.host/webhook",
webhook_events_filter: ["output", "completed"],
});

expect(response.input).toEqual(input);
expect(response.status).toBe(expectedResponse.status);
}
);

const fileTestCases = [
// Skip test case if File type is not available
Expand Down
5 changes: 3 additions & 2 deletions lib/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,10 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) {
// Walk a JavaScript object and transform the leaf values.
async function transform(value, mapper) {
if (Array.isArray(value)) {
let copy = [];
const copy = [];
for (const val of value) {
copy = await transform(val, mapper);
const transformed = await transform(val, mapper);
copy.push(transformed);
}
return copy;
}
Expand Down

0 comments on commit d981fc1

Please sign in to comment.