[shortfin-sd] Reusable service command buffers #963
@@ -10,6 +10,7 @@
 import shortfin as sf
 import shortfin.array as sfnp
+import numpy as np

 from .io_struct import GenerateReqInput
@@ -41,14 +42,14 @@ class InferenceExecRequest(sf.Message):
     def __init__(
         self,
-        prompt: str | None = None,
-        neg_prompt: str | None = None,
+        prompt: str | list[str] | None = None,
+        neg_prompt: str | list[str] | None = None,
         height: int | None = None,
         width: int | None = None,
         steps: int | None = None,
-        guidance_scale: float | sfnp.device_array | None = None,
-        seed: int | None = None,
-        input_ids: list[list[int]] | None = None,
+        guidance_scale: float | list[float] | sfnp.device_array | None = None,
+        seed: int | list[int] | None = None,
+        input_ids: list[list[int]] | list[list[list[int]]] | list[sfnp.device_array] | None = None,
         sample: sfnp.device_array | None = None,
         prompt_embeds: sfnp.device_array | None = None,
         text_embeds: sfnp.device_array | None = None,
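For orientation, the widened signature accepts either scalars or per-prompt lists for the batched fields. A minimal construction sketch (argument values are illustrative, not taken from this PR):

```python
# Hypothetical two-prompt request; each list-typed field carries one entry per prompt.
req = InferenceExecRequest(
    prompt=["a photo of a corgi", "a photo of a tabby cat"],
    neg_prompt=["low quality", "low quality"],
    height=1024,
    width=1024,
    steps=20,
    guidance_scale=[7.5, 7.5],
    seed=[0, 1],
)
```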
@@ -58,8 +59,9 @@ def __init__(
         image_array: sfnp.device_array | None = None,
     ):
         super().__init__()
+        self.command_buffer = None
         self.print_debug = True

         self.batch_size = 1
         self.phases = {}
         self.phase = None
         self.height = height
@@ -74,14 +76,15 @@ def __init__(
         self.seed = seed

         # Encode phase.
-        # This is a list of sequenced positive and negative token ids and pooler token ids.
+        # This is a list of sequenced positive and negative token ids and pooler token ids (tokenizer outputs)
        self.input_ids = input_ids

         # Denoise phase.
         self.prompt_embeds = prompt_embeds
         self.text_embeds = text_embeds
         self.sample = sample
         self.steps = steps
+        self.steps_arr = None
         self.timesteps = timesteps
         self.time_ids = time_ids
         self.guidance_scale = guidance_scale

Review comment on `self.steps_arr = None`: Doesn't look to be used anywhere.
@@ -104,36 +107,50 @@ def __init__(

         self.post_init()

-    @staticmethod
-    def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
-        gen_inputs = [
-            "prompt",
-            "neg_prompt",
-            "height",
-            "width",
-            "steps",
-            "guidance_scale",
-            "seed",
-            "input_ids",
-        ]
-        rec_inputs = {}
-        for item in gen_inputs:
-            received = getattr(gen_req, item, None)
-            if isinstance(received, list):
-                if index >= (len(received)):
-                    if len(received) == 1:
-                        rec_input = received[0]
-                    else:
-                        logging.error(
-                            "Inputs in request must be singular or as many as the list of prompts."
-                        )
-                else:
-                    rec_input = received[index]
-            else:
-                rec_input = received
-            rec_inputs[item] = rec_input
-        return InferenceExecRequest(**rec_inputs)
+    def set_command_buffer(self, cb):
+        # Input IDs for CLIP if they are used as inputs instead of prompts.
+        if self.input_ids:
+            # Take a batch of sets of input ids as ndarrays and fill cb.input_ids
+            host_arrs = [None] * 4

Review comment on `host_arrs = [None] * 4`: Where did this 4 come from? Should it be len(cb.input_ids)?
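A minimal sketch of the reviewer's suggestion, sizing the staging list from the command buffer instead of hardcoding 4. `stage_input_ids` is a hypothetical free-function rendering of the loop in this hunk; `cb` and the array methods (`for_transfer`, `view`, `map`, `copy_from`) are as used in the diff:

```python
def stage_input_ids(cb, input_ids):
    # One host staging array per input-ids device array on the command buffer,
    # rather than assuming exactly four token sets.
    host_arrs = [None] * len(cb.input_ids)
    for idx, arr in enumerate(cb.input_ids):
        host_arrs[idx] = arr.for_transfer()
        for i in range(cb.sample.shape[0]):  # batch dimension
            with host_arrs[idx].view(i).map(write=True, discard=True) as m:
                m.fill(input_ids[i][idx])
        cb.input_ids[idx].copy_from(host_arrs[idx])
```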
+            for idx, arr in enumerate(cb.input_ids):
+                host_arrs[idx] = arr.for_transfer()
+                for i in range(cb.sample.shape[0]):
+                    with host_arrs[idx].view(i).map(write=True, discard=True) as m:
+                        # TODO: fix this attr redundancy
+                        np_arr = self.input_ids[i][idx]
+                        m.fill(np_arr)
+                cb.input_ids[idx].copy_from(host_arrs[idx])
+
+        # Same for noisy latents if they are explicitly provided as a numpy array.
+        if self.sample:
+            sample_host = cb.sample.for_transfer()
+            with sample_host.map(discard=True) as m:
+                m.fill(self.sample.tobytes())
+            cb.sample.copy_from(sample_host)
+
+        # Copy other inference parameters for denoise to device arrays.
+        steps_arr = list(range(0, self.steps))
+        steps_host = cb.steps_arr.for_transfer()
+        steps_host.items = steps_arr
+        cb.steps_arr.copy_from(steps_host)
+
+        num_step_host = cb.num_steps.for_transfer()
+        num_step_host.items = [self.steps]
+        cb.num_steps.copy_from(num_step_host)
+
+        guidance_host = cb.guidance_scale.for_transfer()
+        with guidance_host.map(discard=True) as m:
+            # TODO: do this without numpy
+            np_arr = np.asarray(self.guidance_scale, dtype="float16")

Review thread on `np.asarray(self.guidance_scale, dtype="float16")`:
- Reviewer: Is guidance_scale always fp16, or only for fp16/punet?
- Author: Oops, left it hardcoded to avoid parsing the sfnp.DType to np.dtype. Will fix.
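One possible shape for the fix the author describes: derive the numpy dtype from the command buffer's own array instead of pinning "float16". The `_SFNP_TO_NP` table and the `.dtype` attribute lookup are assumptions about the shortfin API, sketched rather than confirmed:

```python
import numpy as np
import shortfin.array as sfnp

# Assumed mapping from sfnp dtypes to numpy dtypes; extend as models require.
_SFNP_TO_NP = {
    sfnp.float16: np.float16,
    sfnp.float32: np.float32,
}

def guidance_to_host(guidance_scale, cb):
    # Match the device array's dtype rather than hardcoding fp16.
    np_dtype = _SFNP_TO_NP[cb.guidance_scale.dtype]
    return np.asarray(guidance_scale, dtype=np_dtype)
```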
+            m.fill(np_arr)
+        cb.guidance_scale.copy_from(guidance_host)
+
+        self.command_buffer = cb
+        return
+
     def post_init(self):
         """Determines necessary inference phases and tags them with static program parameters."""
         for p in reversed(list(InferencePhase)):
(The from_batch helper removed above reappears here, relocated below reset, with its return split into a local variable.)

@@ -179,6 +196,37 @@ def reset(self, phase: InferencePhase):
         self.done = sf.VoidFuture()
         self.return_host_array = True

+    @staticmethod
+    def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest":
+        gen_inputs = [
+            "prompt",
+            "neg_prompt",
+            "height",
+            "width",
+            "steps",
+            "guidance_scale",
+            "seed",
+            "input_ids",
+        ]
+        rec_inputs = {}
+        for item in gen_inputs:
+            received = getattr(gen_req, item, None)
+            if isinstance(received, list):
+                if index >= (len(received)):
+                    if len(received) == 1:
+                        rec_input = received[0]
+                    else:
+                        logging.error(
+                            "Inputs in request must be singular or as many as the list of prompts."
+                        )
+                else:
+                    rec_input = received[index]
+            else:
+                rec_input = received
+            rec_inputs[item] = rec_input
+        req = InferenceExecRequest(**rec_inputs)
+        return req
+
+
 class StrobeMessage(sf.Message):
     """Sent to strobe a queue with fake activity (generate a wakeup)."""
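The broadcast rule from_batch implements: scalars apply to every prompt, a one-element list is reused when the index runs past it, and a full-length list is indexed per prompt. Note that the error branch only logs, so rec_input can be unbound (or stale from a prior field) on the mismatch path. A standalone restatement of the rule, raising where the original logs (hypothetical helper, not part of the PR):

```python
def pick(received, index):
    """Select the value of one request field for prompt `index`."""
    if isinstance(received, list):
        if index >= len(received):
            if len(received) == 1:
                return received[0]  # singleton list broadcasts to all prompts
            # The diff only logs here, leaving rec_input unset on this path.
            raise ValueError(
                "Inputs in request must be singular or as many as the list of prompts."
            )
        return received[index]  # full-length list: per-prompt value
    return received  # scalar: shared by the whole batch

assert pick(20, 1) == 20            # scalar steps
assert pick([42], 1) == 42          # singleton seed broadcasts
assert pick(["a", "b"], 1) == "b"   # per-prompt list
```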
Review thread on `self.batch_size = 1`:
- Reviewer: Why are we hardcoding this? Also, I don't see it used anywhere.
- Author: self.exec_request.batch_size is used in the inference process. I'll fix the hardcode once larger batch sizes are actually possible.
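If larger batches do land, one hypothetical direction for replacing the hardcode is deriving the size from the per-prompt fields the constructor already receives (sketch only, not part of this PR):

```python
def infer_batch_size(prompt) -> int:
    # Hypothetical: one image per prompt; a scalar prompt means a batch of one.
    return len(prompt) if isinstance(prompt, list) else 1
```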