[shortfin-sd] Reusable service command buffers #963
Conversation
A few nit comments, but otherwise I think it looks good.
self.print_debug = True

self.batch_size = 1
Why are we hardcoding this? Also, I don't see it used anywhere.
self.exec_request.batch_size is used in the inference process. I'll fix the hardcode once larger batch sizes are actually possible.
self.input_ids = input_ids

# Denoise phase.
self.prompt_embeds = prompt_embeds
self.text_embeds = text_embeds
self.sample = sample
self.steps = steps
self.steps_arr = None
Doesn't look to be used anywhere
guidance_host = cb.guidance_scale.for_transfer()
with guidance_host.map(discard=True) as m:
    # TODO: do this without numpy
    np_arr = np.asarray(self.guidance_scale, dtype="float16")
Is guidance_scale always fp16, or only for fp16/punet?
Oops, left it hardcoded to avoid parsing the sfnp.DType into an np.dtype. Will fix.
Very elegant solution to the management of the device arrays!
All of my comments are at your discretion to disregard, as always. :-)
# Input IDs for CLIP if they are used as inputs instead of prompts.
if self.input_ids is not None:
    # Take a batch of sets of input ids as ndarrays and fill cb.input_ids
    host_arrs = [None] * 4
Where did this 4 come from? Should it be len(cb.input_ids)?
@@ -350,366 +359,293 @@ def __init__(
    service: GenerateService,
    fiber,
A meta fiber is the same thing as an EquippedFiber, right? If so, can you use one or the other name consistently? Also, it would be nice if function arguments taking the thing would be called meta_fiber or equipped_fiber to distinguish it from raw sf fibers.
Yeah... regrets. Best to nip this in the bud.
await self._postprocess(device=device)
await device
self.exec_request.done.set_success()
self.meta_fiber.command_buffers.append(self.exec_request.command_buffer)
Does returning the command buffer to the meta fiber (equipped fiber) need to be done outside the try block, so that an exception doesn't cause the command buffer to be lost?
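The concern above can be illustrated with a small sketch in plain Python. This is a toy model, not the PR's actual code: `MetaFiber` here is just a stand-in for the fiber-plus-command-buffers bundle, and the `try`/`finally` placement is the suggested fix, with the return happening even when a phase raises.

```python
class MetaFiber:
    """Toy stand-in for a fiber carrying a pool of command buffers."""

    def __init__(self, buffers):
        self.command_buffers = list(buffers)


def run_request(meta_fiber, phases):
    cb = meta_fiber.command_buffers.pop(0)  # mount a buffer on the request
    try:
        phases(cb)  # run the inference phases
    finally:
        # Return the buffer even if a phase raises, so it is never lost.
        meta_fiber.command_buffers.append(cb)


def failing_phases(cb):
    raise RuntimeError("phase blew up")


meta_fiber = MetaFiber(["cb0"])
try:
    run_request(meta_fiber, failing_phases)
except RuntimeError:
    pass

# The command buffer survived the exception and is available for reuse.
assert meta_fiber.command_buffers == ["cb0"]
```

If the append sits inside the `try` block after the phases (as in the diff above), the same exception would leave `command_buffers` empty.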
self.assign_command_buffer(self.exec_request)

device = self.fiber.device(0)
await device
What is this wait for? If it's for metrics, we might need to break the actual "inference process" out of the run function:

async def run(self):
    ...
    device = self.fiber.device(0)
    await device
    await self.run_inference(device)

@measure(...)
async def run_inference(self, device):
    phases = ...
    # _prepare, _encode, etc.
When we assign a command buffer to the request in the beginning of the inference process, we copy some host arrays (anything user-submitted like input_ids as numpy arrays) and await here for the copy to device to finish.
The await after the phases is extraneous and will be removed.
Do we need to wait for the copy to device to finish before starting a program? I thought that as long as they're sequenced correctly on a fiber, shortfin and/or the GPU command buffer guarantees correct order of execution
You know, that makes sense to me.... let's try it.
You shouldn't need to wait for it to finish. The only time you should need to await the device is if you copy back to the host and want to inspect the results (on the host).
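The ordering guarantee being discussed can be modeled in plain Python. This is a toy model, not shortfin's API: `FakeDevice` pretends that operations submitted to one queue execute in submission order, and that awaiting the device is only needed before the host reads the bytes back.

```python
import asyncio


class FakeDevice:
    """Toy model of in-order device execution: ops queue up without blocking
    the host, and awaiting the device drains the queue (standing in for
    waiting on the device timeline)."""

    def __init__(self):
        self.queue = []
        self.completed = []

    def submit(self, op):
        self.queue.append(op)  # enqueue; the host does not block here

    def __await__(self):
        # Draining the queue models "everything submitted so far is done".
        self.completed.extend(self.queue)
        self.queue.clear()
        return iter(())  # completes immediately in this toy model


async def main():
    dev = FakeDevice()
    dev.submit("h2d_copy")  # no await needed: ordering is guaranteed...
    dev.submit("run_unet")  # ...by submission order on the same queue
    dev.submit("d2h_copy")
    await dev               # await only before the host inspects results
    return dev.completed


result = asyncio.run(main())
assert result == ["h2d_copy", "run_unet", "d2h_copy"]
```

The point of the sketch: the h2d copy, the program launches, and the d2h copy all stay correctly ordered without intermediate awaits; only the final host-side read forces synchronization.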
    sfnp.device_array.for_device(
        device, [req_bs, self.service.model_params.max_seq_len], sfnp.sint64
    ),
]
host_arrs = [None] * 4
Another hard-coded 4. Should it be len(cb.input_ids)?
For SDXL inference this is a very specific shape that isn't tied to a config variable (basically it's 4 because CLIP takes positive, negative, positive pooled, and negative pooled ids from the tokenizers.). I don't think it should be driven by the shape of a command buffer element -- it won't ever change for SDXL
Meh... since we're doing h2d here it probably makes sense just to align them.
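The alignment agreed on above can be sketched as follows. The slot names are hypothetical labels for the four SDXL CLIP inputs named in the thread (positive, negative, positive pooled, negative pooled); the real `cb.input_ids` holds device arrays, not strings.

```python
# Hypothetical labels for the four CLIP input-id slots in SDXL; in the real
# command buffer these are sfnp device arrays, not strings.
clip_input_ids = ["positive", "negative", "positive_pooled", "negative_pooled"]

# Sizing the host staging list from the command buffer entry keeps the two
# aligned, instead of repeating a bare hard-coded 4 in two places.
host_arrs = [None] * len(clip_input_ids)

assert len(host_arrs) == 4
```

For SDXL the count never changes, so the hard-coded 4 is not wrong; deriving it just removes the duplicated magic number.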
sfnp.fill_randn(sample_host, generator=generator)

cb.sample.copy_from(sample_host)
await device
This is a new await device that wasn't in the original code. Is it needed?
Yes, as it follows a host-to-device copy.
You shouldn't need an await on h2d so long as you are not mutating the host-side buffer while the pipeline is progressing.
(image,) = await fn(latents, fiber=self.fiber)

(cb.images,) = await fn(cb.latents, fiber=self.fiber)
cb.images_host.copy_from(cb.images)
await device
I pulled this await device out of my (now obsolete) PR, as we're not accessing the actual bytes of the buffer, only the buffer metadata, which you can do without waiting.
The decode step transfers the VAE output to host, then creates a numpy array with the result. Doesn't the numpy array instantiation need the buffer contents and not just its metadata? Hence the await device
Looks like np.frombuffer does a shallow copy, which would mean the data is free to come in later. Anyway, Stella was the one who pointed out that none of the await devices should be necessary until postprocess, where you actually use the data. I had consistent results doing just one await device before preprocess, not that I was pounding on the server particularly hard. :-)
Yeah, you only need to await when observing the bytes on the host. Just juggling views doesn't touch the contents. With that said, if you were trying to have a "safety await" and didn't know for sure that the numpy API you are using only takes a view, then it would be a reasonable thing to do.
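The view behavior being discussed can be checked directly in numpy, device arrays aside. This small example only demonstrates that np.frombuffer shares memory with its source rather than copying, which is why the data is "free to come in later":

```python
import numpy as np

buf = bytearray(4)                        # writable backing storage, all zeros
arr = np.frombuffer(buf, dtype=np.uint8)  # a view over buf, not a copy

assert arr.tolist() == [0, 0, 0, 0]
buf[0] = 7                                # the data "comes in later"
assert arr.tolist() == [7, 0, 0, 0]       # the view observes the change
```

If np.frombuffer copied, the second assertion would fail; since it is a view, any API built on it sees whatever the buffer holds at read time, so the host must synchronize before actually reading the bytes.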
if phases[InferencePhase.DECODE]["required"]:
    await self._decode(device=device0, requests=self.exec_requests)
    await self._decode(device=device)
Right in here is where I had added (in my now obsolete PR):
# Postprocessing needs the output data to be on the host. Even
# without postprocessing, we're done with the GPU, so we wait for
# it to finish here.
await device
And I also removed the await device from below and from _decode. Basically, none of the phases except _postprocess are using the data from the device buffers, so the whole thing can free run up to this point.
worker_idx = self.service.get_worker_index(fiber)
logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})")
self.board(batch["reqs"], fiber=fiber)
fiber = self.service.idle_meta_fibers.pop(0)
Rename this fiber variable to meta_fiber? I thought (improbably) that I had caught a type mismatch on line 307 when I saw meta_fiber=fiber. :-)
Thanks for the catch!
This PR non-trivially restructures the SDXL inference process to interface with reusable command buffers.

At service startup, we attach to each fiber a set of reusable command buffers (a set of empty device arrays to be reused throughout the inference process). Upon assigning each exec request a fiber (now the "meta_fiber" equipped with command buffers), we pop an appropriate command buffer from the meta fiber, mount the command buffer on the request, and return the command buffer to the meta_fiber for reuse.

Each meta_fiber is equipped with an explicit number of command buffers, and if this count is exceeded by the number of InferenceExecProcesses assigned to the fiber, a new command buffer will be initialized ad hoc. This _shouldn't_ happen if the service is managed properly, as we shouldn't oversubscribe fibers too heavily (thundering herd).

These changes also simplify the request structure throughout the inference processor, operating on a _single_ request as opposed to a bundle. This greatly simplifies the work required to enable larger batch sizes, and reduces how much we need to handle program I/O during execution.

We also remove all `await device` usage where it wasn't required. It should only be used after VAE decode, where we are done with the GPU and need to use the result contents of a device->host transfer.

fyi @daveliddell

Co-authored-by: Ean Garvey <ean.garvey@amd.com>
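The pop/mount/return lifecycle described above can be sketched as a small pool. This is a sketch under stated assumptions, not the PR's implementation: `make_buffer` stands in for allocating a fresh set of sfnp device arrays, and the class names are hypothetical.

```python
import itertools


class CommandBufferPool:
    """Sketch of a per-fiber reusable command-buffer pool.
    make_buffer stands in for allocating a fresh set of device arrays."""

    def __init__(self, make_buffer, initial_count):
        self._make_buffer = make_buffer
        self.buffers = [make_buffer() for _ in range(initial_count)]

    def acquire(self):
        if self.buffers:
            return self.buffers.pop(0)
        # Pool exhausted: allocate ad hoc, as described above when a fiber
        # is oversubscribed. Should be rare under proper scheduling.
        return self._make_buffer()

    def release(self, buf):
        self.buffers.append(buf)


counter = itertools.count()
pool = CommandBufferPool(lambda: f"cb{next(counter)}", initial_count=2)

a = pool.acquire()   # "cb0"
b = pool.acquire()   # "cb1"
c = pool.acquire()   # "cb2": ad-hoc allocation, the pool was empty

pool.release(a)
pool.release(c)
assert pool.buffers == ["cb0", "cb2"]  # returned buffers get reused next
```

Ad-hoc buffers are returned to the same pool on release, so a transient oversubscription permanently grows the pool; whether that is desirable or the extra buffers should instead be dropped is a policy choice this sketch does not make.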