
[shortfin-sd] Reusable service command buffers #963

Merged: 17 commits merged into main from sdxl_prealloc_buffers on Feb 14, 2025

Conversation

@monorimet (Contributor) commented Feb 13, 2025:

This PR non-trivially restructures the SDXL inference process to interface with reusable command buffers.

At service startup, we attach a set of reusable command buffers (empty device arrays reused throughout the inference process) to each fiber. When an exec request is assigned a fiber (now a "meta_fiber" equipped with command buffers), we pop an appropriate command buffer from the meta_fiber, mount it on the request, and return it to the meta_fiber for reuse once the request completes.

Each meta_fiber is equipped with an explicit number of command buffers; if that count is exceeded by the number of InferenceExecProcesses assigned to the fiber, a new command buffer is initialized ad hoc. This shouldn't happen if the service is managed properly, since we shouldn't oversubscribe fibers too heavily (thundering herd).
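
As a rough sketch of the pop-or-allocate behavior described above (the class and helper names are illustrative stand-ins, not the actual shortfin-sd types):

class MetaFiber:
    # Illustrative only: a fiber "equipped" with preallocated command buffers.
    def __init__(self, fiber, num_command_buffers, make_command_buffer):
        self.fiber = fiber
        self._make_command_buffer = make_command_buffer
        # Preallocate the configured number of reusable command buffers.
        self.command_buffers = [
            make_command_buffer(fiber) for _ in range(num_command_buffers)
        ]

    def pop_command_buffer(self):
        # Reuse an idle preallocated buffer; fall back to ad-hoc allocation
        # only when the fiber is oversubscribed.
        if self.command_buffers:
            return self.command_buffers.pop()
        return self._make_command_buffer(self.fiber)

    def return_command_buffer(self, cb):
        # Called once the exec request completes, making the buffer reusable.
        self.command_buffers.append(cb)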

These changes also simplify the request structure throughout the inference processor, which now operates on a single request rather than a bundle. This greatly simplifies the work required to enable larger batch sizes and reduces how much program I/O we need to handle during execution.

We also remove all `await device` usage where it wasn't required. It should only be used after VAE decode, where we are done with the GPU and need to read the result contents of a device->host transfer. fyi @daveliddell
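
For illustration, the one synchronization point that remains looks roughly like this; the map(read=True) call and the float16 dtype are assumptions made for the sketch, not something this diff pins down:

import numpy as np

# After VAE decode: queue the device->host transfer, then await the device
# only because the host-side bytes are read immediately afterwards.
cb.images_host.copy_from(cb.images)
await device
with cb.images_host.map(read=True) as m:  # read mapping assumed
    image = np.frombuffer(m, dtype=np.float16)  # dtype assumed for illustration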

@monorimet marked this pull request as ready for review on February 13, 2025 at 17:49
@monorimet changed the title from "(Draft) [shortfin-sd] Reusable service command buffers" to "[shortfin-sd] Reusable service command buffers" on Feb 13, 2025
@IanNod (Contributor) left a comment:

A few nit comments, but otherwise I think it looks good.

self.print_debug = True

self.batch_size = 1
Contributor:

Why are we hardcoding this? Also, I don't see it used anywhere.

Contributor Author (monorimet):

self.exec_request.batch_size is used in the inference process. I'll fix the hardcode once larger batch sizes are actually possible.

self.input_ids = input_ids

# Denoise phase.
self.prompt_embeds = prompt_embeds
self.text_embeds = text_embeds
self.sample = sample
self.steps = steps
self.steps_arr = None
Contributor:

Doesn't look to be used anywhere

guidance_host = cb.guidance_scale.for_transfer()
with guidance_host.map(discard=True) as m:
# TODO: do this without numpy
np_arr = np.asarray(self.guidance_scale, dtype="float16")
Contributor:

Is guidance_scale always fp16, or only for fp16/punet?

Contributor Author (monorimet):

Oops. I left it hardcoded to avoid converting the sfnp.DType to an np.dtype. Will fix.
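
One way to drop the hardcode is a small lookup from the command buffer element's sfnp dtype to the matching numpy dtype. The mapping below is an assumption for illustration; the .dtype attribute and constant names follow what shortfin.array appears to expose but are not confirmed by this diff:

import numpy as np
import shortfin.array as sfnp

# Assumed mapping; only the dtypes SDXL actually uses need to be listed.
SFNP_TO_NP_DTYPE = {
    sfnp.float16: np.float16,
    sfnp.float32: np.float32,
    sfnp.sint64: np.int64,
}

np_arr = np.asarray(
    self.guidance_scale, dtype=SFNP_TO_NP_DTYPE[cb.guidance_scale.dtype]
)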

@daveliddell (Contributor) left a comment:

Very elegant solution to the management of the device arrays!

All of my comments are at your discretion to disregard, as always. :-)

# Input IDs for CLIP if they are used as inputs instead of prompts.
if self.input_ids is not None:
# Take a batch of sets of input ids as ndarrays and fill cb.input_ids
host_arrs = [None] * 4
Contributor:

Where did this 4 come from? Should it be len(cb.input_ids)?

@@ -350,366 +359,293 @@ def __init__(
service: GenerateService,
fiber,
Contributor:

A meta fiber is the same thing as an EquippedFiber, right? If so, can you use one or the other name consistently? Also, it would be nice if function arguments taking the thing would be called meta_fiber or equipped_fiber to distinguish it from raw sf fibers.

Contributor Author (monorimet):

Yeah... regrets. Best to nip this in the bud.

await self._postprocess(device=device)
await device
self.exec_request.done.set_success()
self.meta_fiber.command_buffers.append(self.exec_request.command_buffer)
Contributor:

Does returning the command buffer to the meta fiber (equipped fiber) need to be done outside the try block, so that an exception doesn't cause the command buffer to be lost?
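
A minimal sketch of that suggestion, assuming the surrounding structure of run() (phase calls elided):

async def run(self):
    device = self.fiber.device(0)
    cb = self.exec_request.command_buffer
    try:
        # ... _prepare / _encode / _denoise / _decode ...
        await self._postprocess(device=device)
        await device
        self.exec_request.done.set_success()
    finally:
        # Return the command buffer even if a phase raised, so it isn't lost.
        self.meta_fiber.command_buffers.append(cb)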

self.assign_command_buffer(self.exec_request)

device = self.fiber.device(0)
await device
Contributor:

What is this wait for? If it's for metrics, we might need to break the actual "inference process" out of the run function:

async def run(self):
    ...
    device = self.fiber.device(0)
    await device
    await self.run_inference(device)

@measure(...)
async def run_inference(self, device):
    phases = ...
    # _prepare, _encode, etc.

Contributor Author (monorimet):

When we assign a command buffer to the request at the beginning of the inference process, we copy some host arrays (anything user-submitted, like input_ids, as numpy arrays) and await here for the copy to device to finish.
The await after the phases is extraneous and will be removed.

@daveliddell (Contributor) commented Feb 14, 2025:

Do we need to wait for the copy to device to finish before starting a program? I thought that as long as they're sequenced correctly on a fiber, shortfin and/or the GPU command buffer guarantees correct order of execution

Contributor Author (monorimet):

You know, that makes sense to me.... let's try it.

Contributor:

You shouldn't need to wait for it to finish. The only time you should need to await the device is if you copy back to the host and want to inspect the results (on the host).

sfnp.device_array.for_device(
device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64
),
]
host_arrs = [None] * 4
Contributor:

Another hard-coded 4. Should it be len(cb.input_ids)?

@monorimet (Contributor, Author) commented Feb 13, 2025:

For SDXL inference this is a very specific shape that isn't tied to a config variable (basically it's 4 because CLIP takes positive, negative, positive pooled, and negative pooled ids from the tokenizers). I don't think it should be driven by the shape of a command buffer element -- it won't ever change for SDXL.

Contributor Author (monorimet):

Meh... since we're doing h2d here it probably makes sense just to align them.
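
Aligning the staging arrays with the command buffer could look like this (a sketch; the per-element for_transfer() call mirrors the pattern used elsewhere in this diff):

# One host staging array per CLIP input-ids tensor (positive, negative,
# positive pooled, negative pooled), sized off the command buffer itself.
host_arrs = [None] * len(cb.input_ids)
for i, device_arr in enumerate(cb.input_ids):
    host_arrs[i] = device_arr.for_transfer()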

sfnp.fill_randn(sample_host, generator=generator)

cb.sample.copy_from(sample_host)
await device
Contributor:

This is a new await device that wasn't in the original code. Is it needed?

Contributor Author (monorimet):

Yes, as it follows a host-to-device copy.

Contributor:

You shouldn't need an await on h2d as long as you are not mutating the host-side buffer while the pipeline is progressing.
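
Applied to the sample initialization above, that suggestion would look roughly like this (a sketch reusing the calls already present in the diff):

# The host staging buffer is filled once and not mutated afterwards, so the
# queued h2d copy can complete asynchronously; no `await device` is needed.
sample_host = cb.sample.for_transfer()
sfnp.fill_randn(sample_host, generator=generator)
cb.sample.copy_from(sample_host)
# ... continue queuing work on the same fiber; ordering is preserved ...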

(image,) = await fn(latents, fiber=self.fiber)

(cb.images,) = await fn(cb.latents, fiber=self.fiber)
cb.images_host.copy_from(cb.images)
await device
Contributor:

I pulled this await device out of my (now obsolete) PR, as we're not accessing the actual bytes of the buffer, only the buffer metadata, which you can do without waiting.

Contributor Author (monorimet):

The decode step transfers the VAE output to host, then creates a numpy array with the result. Doesn't the numpy array instantiation need the buffer contents and not just its metadata? Hence the await device.

Contributor:

Looks like np.frombuffer does a shallow copy, which would mean the data is free to come in later. Anyway, Stella was the one who pointed out that none of the await devices should be necessary until postprocess, where you actually use the data. I had consistent results doing just one await device before preprocess, not that I was pounding on the server particularly hard. :-)

Contributor:

Yeah, you only need to await when observing the bytes on the host. Just juggling views doesn't touch the contents. With that said, if you were trying to have a "safety await" and didn't know for sure that the numpy API you are using only takes a view, then it would be a reasonable thing to do.

if phases[InferencePhase.DECODE]["required"]:
await self._decode(device=device0, requests=self.exec_requests)
await self._decode(device=device)
Contributor:

Right in here is where I had added (in my now obsolete PR):

            # Postprocessing needs the output data to be on the host.  Even
            # without postprocessing, we're done with the GPU, so we wait for
            # it to finish here.
            await device

And I also removed the await device from below and from _decode. Basically, none of the phases except _postprocess are using the data from the device buffers, so the whole thing can free run up to this point.

worker_idx = self.service.get_worker_index(fiber)
logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})")
self.board(batch["reqs"], fiber=fiber)
fiber = self.service.idle_meta_fibers.pop(0)
Contributor:

Rename this fiber variable to meta_fiber? I thought (improbably) that I had caught a type mismatch on line 307 when I saw meta_fiber=fiber. :-)

Contributor Author (monorimet):

Thanks for the catch!

@monorimet merged commit 16099b1 into main on Feb 14, 2025
37 of 38 checks passed
@monorimet deleted the sdxl_prealloc_buffers branch on February 14, 2025 at 18:00
renxida pushed a commit to renxida/shark-ai that referenced this pull request Feb 20, 2025