[Misc][Simulator] Update vllm simulator backend (#42)
ZeldaHuang authored Oct 8, 2024
1 parent f4a617c commit e9cf870
Showing 14 changed files with 195 additions and 91 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -57,6 +57,7 @@ Visit our [documentation](./docs/) to get started:
- [QuickStart](./docs/Quickstart.md)
- [Supported Models](./docs/Supported_Models.md)
- [Fault Tolerance](./docs/Fault_Tolerance.md)
- [Simulator](./docs/Simulator.md)

## Performance
We evaluate the performance of the KV-cache-aware load-balancing scheduler and migration mechanism of Llumnix with 16 Llama2-7B/Qwen1.5-7B instances, each using an A10 GPU (24GB).
57 changes: 57 additions & 0 deletions docs/Simulator.md
@@ -0,0 +1,57 @@
# Getting Started
Llumnix can generate latency data from logs. After running a real benchmark with `--log-instance-info`, you will find a `$LOG_FILENAME.csv` file.

After running profiling with `python -m llumnix.backends.profiling`, you get a `$PROFILING_RESULT_FILE_PATH.pkl` file.

Then, you can run the simulator with `--profiling-result-file-path PROFILING_RESULT_FILE_PATH`. A complete profiling invocation example is shown after the option descriptions below.


```
usage: -m llumnix.backends.profiling [-h]
[--database PROFILING_RESULT_FILE_PATH]
[--log-csv-path CSV_FILE_PATH]
[--model MODEL_NAME]
[--tp TENSOR_PARALLEL_SIZE]
[--pp PIPELINE_PARALLEL_SIZE]
[--gpu-memory-utilization GPU_MEMORY_UTILIZATION]
[--block-size BLOCK_SIZE]
[--max-num-batched-tokens MAX_NUM_BATCHED_TOKENS]
[--num-gpu-blocks NUM_GPU_BLOCKS]
[--new-data]
```

`--database`
- Path to profiling result file.

`--log-csv-path`
- Path to the CSV file produced by a real Llumnix benchmark.

`--model`
- Name of the model (the same as the Hugging Face model name when using vLLM, e.g. `Meta-Llama-3-8B-Instruct`).

`--tp`
- Number of tensor parallel replicas.
- Default: 1

`--pp`
- Number of pipeline parallel stages.
- Default: 1

`--gpu-memory-utilization`
- The fraction of GPU memory to be used for the model executor.
- Default: 0.9

`--block-size`
- Token block size for contiguous chunks of tokens.
- Default: 16

`--max-num-batched-tokens`
- Maximum number of batched tokens per iteration.
- Default: 8000

`--num-gpu-blocks`
- Number of GPU blocks profiled by the inference engine.

`--new-data`
- Create a new profiling result file; otherwise, write to an existing file.
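
For reference, a complete profiling invocation might look like the following. The log path and the `--num-gpu-blocks` value are placeholders; substitute the CSV produced by your own benchmark run and the KV-cache block count reported by your engine:

```
python -m llumnix.backends.profiling \
    --database profiling.pkl \
    --log-csv-path ./benchmark_instance.csv \
    --model Meta-Llama-3-8B-Instruct \
    --tp 1 \
    --pp 1 \
    --gpu-memory-utilization 0.9 \
    --block-size 16 \
    --max-num-batched-tokens 8000 \
    --num-gpu-blocks 4096 \
    --new-data
```

The resulting `profiling.pkl` can then be passed to Llumnix via `--profiling-result-file-path profiling.pkl` to run the simulator backend.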

5 changes: 0 additions & 5 deletions llumnix/arg_utils.py
@@ -53,7 +53,6 @@ class EngineManagerArgs:
log_instance_info: bool = None
profiling_result_file_path: str = None

gpu_type: str = None
migration_backend_init_timeout: float = None
migration_backend: str = None
migration_cache_blocks: int = None
@@ -193,10 +192,6 @@ def add_cli_args(
type=str,
help='profiling result file path')

parser.add_argument('--gpu-type',
type=str,
help='gpu type specified when using simulator')

parser.add_argument('--migration-backend',
type=str,
choices=['gloo','nccl','rpc'],
48 changes: 20 additions & 28 deletions llumnix/backends/profiling.py
@@ -28,7 +28,7 @@

# 2D parallel configuration
# (gpu, tensor parallel, pipeline parallel)
SimParallelConfig = namedtuple("SimParallelConfig", ("gpu", "tp", "pp"))
SimParallelConfig = namedtuple("SimParallelConfig", ("tp", "pp"))
# vllm blocks gpu cache configuration
SimCacheConfig = namedtuple("SimCacheConfig", ("gpu_memory_utilization", "block_size", "max_num_batched_tokens"))

@@ -56,7 +56,7 @@ class LatencyMemData:

def add_latency_result(self, inference_type: RequestInferenceType, batch_size: int, tot_seq_len: int, latency: List[float]):
if inference_type == RequestInferenceType.PREFILL:
self.prefill_latency[batch_size] = latency
self.prefill_latency[tot_seq_len] = latency
else:
self.decode_latency[(batch_size, tot_seq_len)] = latency

@@ -146,20 +146,17 @@ def update(self, result: ProfilingResult):

def _extract_data(self, row):
"""Extract the profiling results from a row of the profiling CSV file."""
inference_type = RequestInferenceType.PREFILL if row["inference_type"] == "prefill" else RequestInferenceType.DECODE
# assert pp==1
stage_latencies = [float(row["latency"])]
batch_size = _pad_to_alignment(int(row["bs"]), 8)
tot_seq_len = 0
seq_lens_str = row["seq_lens"].strip('"[]"').split(",")
for len_str in seq_lens_str:
if len_str != "":
tot_seq_len += int(len_str)
tot_seq_len = _pad_to_alignment(tot_seq_len, 8)
profiling_data = row["profiling_data"].strip('"()"').split(",")
inference_type = RequestInferenceType.PREFILL if profiling_data[0] == "'prefill'" else RequestInferenceType.DECODE
batch_size = _pad_to_alignment(int(profiling_data[1]), 8)
tot_seq_len =_pad_to_alignment(int(profiling_data[2]), 8)
stage_latencies = [float(profiling_data[3])]

return stage_latencies, inference_type, batch_size, tot_seq_len

def update_from_instance_log(self, file_name: str, model: str, parallel_config: SimParallelConfig):
df = pd.read_csv(file_name+"_instance.csv")
df = pd.read_csv(file_name)
df = df[df['bs'] > 0]
# read lines
if model not in self.results:
@@ -180,7 +177,7 @@ def model_decode(x, a, b, c):
bs, tot_seq_len = x
return a * bs + b * tot_seq_len + c

def get_latency_mem(backend_type: BackendType, profiling_database: ProfilingDatabase, gpu_type: str, **backend_args):
def get_latency_mem(backend_type: BackendType, profiling_database: ProfilingDatabase, **backend_args):
assert BackendType.is_sim_backend(backend_type)
if backend_type == BackendType.SIM_VLLM:
# TODO(ZeldaHuang): support multi-lora, more device, vision language model
@@ -194,7 +191,7 @@ def get_latency_mem(backend_type: BackendType, profiling_database: ProfilingData
model_name = model_name[:-1]
model_name = os.path.basename(model_name)
profiling_result: ProfilingResult = profiling_database.get(model_name)
sim_parallel_config = SimParallelConfig(gpu_type, parallel_config.tensor_parallel_size,
sim_parallel_config = SimParallelConfig(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
assert sim_parallel_config in profiling_result.para_dict.keys(), "sim parallel config not in database"
latency_mem: LatencyMemData = profiling_result.para_dict[sim_parallel_config]
@@ -205,29 +202,24 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--database", type=str, default="profiling.pkl")
parser.add_argument("--log-path", type=str)
parser.add_argument("--model", type=str)
parser.add_argument("--gpu", type=str, default="a10")
parser.add_argument("--log-csv-path", type=str, required=True)
parser.add_argument("--model", type=str, help="filename of your model, like 'Meta-Llama-3-8B-Instruct'")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--block-size", type=int, default=16)
parser.add_argument("--max-num-batched-tokens", type=int, default=8000)
parser.add_argument("--num-gpu-blocks", type=int, default=0)
parser.add_argument("--num-gpu-blocks", type=int, required=True, help="kv cache blocks number")
parser.add_argument("--new-data", action="store_true")
parser.add_argument("--fit", action="store_true")

args = parser.parse_args()
args_parallel_config = SimParallelConfig(args.gpu, args.tp, args.pp)
args_parallel_config = SimParallelConfig(args.tp, args.pp)
args_cache_config = SimCacheConfig(args.gpu_memory_utilization, args.block_size, args.max_num_batched_tokens)
database = ProfilingDatabase(args.database, args.new_data)
if args.log_path:
database.update_from_instance_log(args.log_path, args.model, args_parallel_config)
if args.fit:
model_result = database.get(args.model)
model_result.fit_from_database(parallel_config=args_parallel_config)
if args.num_gpu_blocks:
model_result = database.get(args.model)
model_result.add_cache_result(args_parallel_config, args_cache_config, args.num_gpu_blocks)
database.update_from_instance_log(args.log_csv_path, args.model, args_parallel_config)
model_result = database.get(args.model)
model_result.fit_from_database(parallel_config=args_parallel_config)
model_result = database.get(args.model)
model_result.add_cache_result(args_parallel_config, args_cache_config, args.num_gpu_blocks)

database.materialize()
1 change: 1 addition & 0 deletions llumnix/backends/vllm/executor.py
@@ -204,6 +204,7 @@ def execute_model(
decode_bs += meta_data.token_chunk_size
decode_seq_len += list(meta_data.seq_data.values())[0].get_len()
decode_bs = _pad_to_alignment(decode_bs, 8)
prefill_seq_len = _pad_to_alignment(prefill_seq_len, 8)
latency = 0
if prefill_seq_len:
latency += self.latency_mem.prefill_latency[prefill_seq_len][0] if prefill_seq_len in self.latency_mem.prefill_latency \
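
To make the keying change above easier to follow, here is a small illustrative Python sketch (not repository code): after this commit, the simulated prefill latency is looked up by the padded total number of prefill tokens in the batch, matching how `LatencyMemData.add_latency_result` now stores it in the `profiling.py` hunk. The padding helper and the numbers below are assumptions for illustration.

```python
# Illustrative sketch only -- mirrors the prefill lookup structure in the hunk
# above; it is not code from the repository.

def pad_to_alignment(x: int, multiple: int = 8) -> int:
    # Round x up to the nearest multiple of `multiple` (assumed behaviour of the repo's helper).
    return ((x + multiple - 1) // multiple) * multiple

# Prefill latencies recorded by the profiler:
# padded total prefill tokens -> [stage latencies in seconds].
prefill_latency = {1024: [0.063], 2048: [0.118]}  # made-up values

def estimate_prefill_latency(prefill_seq_len: int) -> float:
    # Pad the batch's total prefill length the same way the profiler did, then look it up.
    return prefill_latency[pad_to_alignment(prefill_seq_len)][0]

print(estimate_prefill_latency(1020))  # 0.063 -- 1020 pads up to 1024
```
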
8 changes: 5 additions & 3 deletions llumnix/backends/vllm/llm_engine.py
@@ -168,8 +168,10 @@ def step(self) -> None:
instance_info.instance_id = self.instance_id
instance_info.step_id = next(self.step_counter)
instance_info.timestamp = time.time()
instance_info.latency = self.model_executor.last_inference_latency

instance_info.profiling_data=(instance_info.inference_type.value,
instance_info.num_seqs,
sum(instance_info.running_seq_lens),
self.model_executor.last_inference_latency)
seq_groups = self.scheduler.running
if seq_groups:
tot_blocks = []
@@ -189,7 +191,7 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:
instance_info.instance_id = self.instance_info.instance_id
instance_info.step_id = self.instance_info.step_id
instance_info.timestamp = self.instance_info.timestamp
instance_info.latency = self.instance_info.latency
instance_info.profiling_data = self.instance_info.profiling_data
instance_info.num_blocks_last_running_request = self.instance_info.num_blocks_last_running_request
self.instance_info = instance_info

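The `profiling_data` tuple written above is what `ProfilingDatabase._extract_data` (see the `profiling.py` hunk earlier in this commit) parses back out of the instance CSV. A small illustrative sketch of that round trip, with made-up sample values:

```python
# Illustrative sketch only: parsing a `profiling_data` CSV cell, mirroring
# _extract_data in llumnix/backends/profiling.py. The sample values are made up.
row_value = "('prefill', 8, 1024, 0.042)"  # (inference type, num seqs, total seq len, latency)

fields = row_value.strip('"()"').split(",")
inference_type = "prefill" if fields[0] == "'prefill'" else "decode"
batch_size = int(fields[1])       # 8
tot_seq_len = int(fields[2])      # 1024
stage_latency = float(fields[3])  # 0.042

print(inference_type, batch_size, tot_seq_len, stage_latency)
```
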
38 changes: 20 additions & 18 deletions llumnix/backends/vllm/scheduler.py
@@ -21,8 +21,9 @@

from llumnix.instance_info import InstanceInfo
from llumnix.logger import init_logger
from llumnix.llumlet.request import LlumnixRequest, RequestInferenceType
from llumnix.llumlet.request import RequestInferenceType
from llumnix.backends.vllm.utils import scheduler_lock
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix

logger = init_logger(__name__)

@@ -55,15 +56,15 @@ def __init__(self, *args, **kwargs) -> None:
enable_caching=self.cache_config.enable_prefix_caching)
self.pre_alloc_cache_dict: Dict[str, BlockTable] = {}
self.scheduler_lock = threading.Lock()
self.migrating_out_request_last_stage: List[LlumnixRequest] = []
self.migrating_out_request_last_stage: List[SequenceGroupLlumnix] = []

def add_update_instance_info_callback(self, update_instance_info_callback):
self.update_instance_info_callback = update_instance_info_callback
self.update_instance_info_callback(self._get_instance_info())
self.update_instance_info_callback(self._get_instance_info([]))

def _preempt(
self,
seq_group: LlumnixRequest,
seq_group: SequenceGroupLlumnix,
blocks_to_swap_out: Dict[int, int],
preemption_mode: Optional[PreemptionMode] = None,
) -> PreemptionMode:
@@ -90,7 +91,7 @@ def get_all_request_ids(self) -> List[str]:
return request_ids

@scheduler_lock
def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]:
def get_request_incremental_blocks(self, backend_request: SequenceGroupLlumnix, pre_stage_num_blocks: int) -> List[int]:
seq = backend_request.get_seqs()[0]
blocks = self.block_manager.get_block_table(seq)
return blocks[pre_stage_num_blocks:]
@@ -104,13 +105,13 @@ def remove_running_request(self, request_id: str) -> None:
seq.status = SequenceStatus.WAITING
break

def add_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None:
def add_migrating_out_request_last_stage(self, backend_request: SequenceGroupLlumnix) -> None:
self.migrating_out_request_last_stage.append(backend_request)

def remove_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None:
def remove_migrating_out_request_last_stage(self, backend_request: SequenceGroupLlumnix) -> None:
self.migrating_out_request_last_stage.remove(backend_request)

def pop_migrating_out_requests_last_stage(self) -> List[LlumnixRequest]:
def pop_migrating_out_requests_last_stage(self) -> List[SequenceGroupLlumnix]:
migrating_out_request_last_stage = self.migrating_out_request_last_stage.copy()
self.migrating_out_request_last_stage.clear()
return migrating_out_request_last_stage
@@ -125,13 +126,13 @@ def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
return blocks

@scheduler_lock
def add_running_request(self, backend_request: LlumnixRequest) -> None:
def add_running_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
self.running.append(backend_request)

@scheduler_lock
def is_request_running(self, backend_request: LlumnixRequest) -> bool:
def is_request_running(self, backend_request: SequenceGroupLlumnix) -> bool:
return backend_request in self.running

@scheduler_lock
@@ -149,12 +150,12 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
self.block_manager._free_block_table(blocks)

@scheduler_lock
def free_src_request(self, backend_request: LlumnixRequest) -> None:
def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
logger.info("free seq {}".format(seq.seq_id))
self.free_seq(seq)

def _get_instance_info(self) -> InstanceInfo:
def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) -> InstanceInfo:
num_total_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = num_total_gpu_blocks - num_free_gpu_blocks
@@ -188,21 +189,22 @@ def _get_instance_info(self) -> InstanceInfo:
waiting_time_first_waiting_request=waiting_time_first_waiting_request,
num_blocks_all_waiting_requests=num_blocks_all_waiting_requests,
)
for seq_group in self.running:
for seq_group in scheduled_seq_groups:
instance_info.running_seq_lens.extend([seq.get_len() for seq in seq_group.get_seqs()])
instance_info.num_seq = len(instance_info.running_seq_lens)
if self.running:
instance_info.inference_type = self.running[-1].inference_type
instance_info.num_seqs = len(instance_info.running_seq_lens)
if scheduled_seq_groups:
instance_info.inference_type = scheduled_seq_groups[-1].inference_type
# TODO(ZeldaHuang) adapt chunked-prefill
instance_info.num_batched_tokens = sum([seq_group.request_len for seq_group in self.running])\
instance_info.num_batched_tokens = sum([seq_group.request_len for seq_group in scheduled_seq_groups])\
if instance_info.inference_type == RequestInferenceType.PREFILL else len(instance_info.running_seq_lens)
instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()]
return instance_info

@scheduler_lock
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_group_metadata_list, scheduler_outputs = super().schedule()
self.update_instance_info_callback(self._get_instance_info())
self.update_instance_info_callback(self._get_instance_info([scheduled_seq_group.seq_group \
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups]))
return seq_group_metadata_list, scheduler_outputs

def add_seq_group(self, *args, **kwargs):
2 changes: 1 addition & 1 deletion llumnix/backends/vllm/sequence.py
@@ -34,7 +34,7 @@ def output_len(self) -> int:
return self.get_seqs()[0].get_output_len()

@property
def inference_type(self) -> bool:
def inference_type(self) -> RequestInferenceType:
if self.is_prefill():
return RequestInferenceType.PREFILL
return RequestInferenceType.DECODE
23 changes: 14 additions & 9 deletions llumnix/backends/vllm/simulator.py
@@ -12,9 +12,9 @@
# limitations under the License.

import os
import threading
from typing import List

from vllm.utils import Counter
from vllm.engine.arg_utils import EngineArgs

from llumnix.logger import init_logger
@@ -30,10 +30,9 @@ class BackendSimVLLM(BackendVLLM):
# pylint: disable=super-init-not-called
def __init__(
self,
instance_id: int,
instance_id: str,
migration_config: MigrationConfig,
profiling_result_file_path: str,
gpu_type: str,
engine_args: EngineArgs,
) -> None:
# load database
@@ -48,18 +47,24 @@ def __init__(
model_name = os.path.basename(model_name)
# get latency mem
profiling_result: ProfilingResult = profiling_database.get(model_name)
sim_parallel_config = SimParallelConfig(gpu_type, parallel_config.tensor_parallel_size,
assert profiling_result is not None, f"can't find {model_name} in profiling database"
sim_parallel_config = SimParallelConfig(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
assert sim_parallel_config in profiling_result.para_dict.keys(), "sim parallel config not in database"
latency_mem: LatencyMemData = profiling_result.para_dict[sim_parallel_config]

self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(migration_config=migration_config,
latency_mem=latency_mem, engine_args=engine_args)
# multi-instance args
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args,
migration_config=migration_config,
instance_id=instance_id,
latency_mem=latency_mem)
self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.migration_config = migration_config
self.instance_id = instance_id
self.step_counter = Counter()
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
)
self._thread.start()

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
self.engine.model_executor.send_blocks(len(src_blocks))