From 8c882a72b40240a5c8430baaa11d25a0e033dc98 Mon Sep 17 00:00:00 2001
From: Biao Sun
Date: Wed, 19 Feb 2025 15:18:29 +0800
Subject: [PATCH] [BugFix] Change max_concurrency of llumlet, worker and proxy
 actor to avoid potential deadlock and performance degradation (#106)

Ray actors execute methods one at a time by default (max_concurrency=1).
Set max_concurrency=2 for the vLLM workers and the migration ProxyActor so
that an actor blocked in a long-running call (e.g. a migration transfer)
can still serve a concurrent request instead of deadlocking, and drop the
max_concurrency=4 override from the Llumlet actor, where the extra threads
only risk degrading performance. Also fix the argument list of the rayrpc
warmup call and cover the worker concurrency behavior with a unit test.
---
 llumnix/backends/vllm/executor.py             |  1 +
 llumnix/backends/vllm/migration_backend.py    |  4 +-
 llumnix/llumlet/llumlet.py                    |  2 -
 .../unit_test/backends/vllm/test_migration.py |  6 ++-
 .../backends/vllm/test_migration_backend.py   |  2 +-
 .../unit_test/backends/vllm/test_simulator.py |  5 +-
 tests/unit_test/backends/vllm/test_worker.py  | 50 +++++++++++++++++--
 .../global_scheduler/test_manager.py          |  2 +-
 .../llumlet/test_engine_step_exception.py     |  3 +-
 9 files changed, 61 insertions(+), 14 deletions(-)

diff --git a/llumnix/backends/vllm/executor.py b/llumnix/backends/vllm/executor.py
index 66be7ff3..820efc4e 100644
--- a/llumnix/backends/vllm/executor.py
+++ b/llumnix/backends/vllm/executor.py
@@ -84,6 +84,7 @@ def _init_workers_ray(self, placement_group: PlacementGroup,
             num_cpus=0,
             num_gpus=num_gpus,
             scheduling_strategy=scheduling_strategy,
+            max_concurrency=2,
             **ray_remote_kwargs,
         )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
 
diff --git a/llumnix/backends/vllm/migration_backend.py b/llumnix/backends/vllm/migration_backend.py
index 9af2304c..32182a8f 100644
--- a/llumnix/backends/vllm/migration_backend.py
+++ b/llumnix/backends/vllm/migration_backend.py
@@ -24,7 +24,7 @@
 
 logger = init_logger(__name__)
 
-@ray.remote(num_cpus=1)
+@ray.remote(num_cpus=1, max_concurrency=2)
 class ProxyActor:
     def exec_method(self, is_driver_worker, handle, *args, **kwargs):
         try:
@@ -83,7 +83,7 @@ def destory_backend(self) -> None:
         pass
 
     def warmup(self) -> bool:
-        self.actor.exec_method.remote(self.is_driver_worker, "do_send", [0])
+        self.actor.exec_method.remote(self.is_driver_worker, self.worker_handle_list[self.worker_rank], "do_send", None, [0])
         logger.info("Rayrpc migration backend warmup successfully.")
         return True
 
diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py
index f016b77c..185aea64 100644
--- a/llumnix/llumlet/llumlet.py
+++ b/llumnix/llumlet/llumlet.py
@@ -101,12 +101,10 @@ def from_args(cls,
             if backend_type == backend_type.BLADELLM:
                 world_size = get_engine_world_size(engine_args, backend_type)
                 num_gpus = world_size
-            # TODO(s5u13b): Check the max_concurrency.
             llumlet_class = ray.remote(num_cpus=1,
                                        num_gpus=num_gpus,
                                        name=get_instance_name(instance_id),
                                        namespace='llumnix',
-                                       max_concurrency=4,
                                        lifetime="detached")(cls).options(
                                            scheduling_strategy=PlacementGroupSchedulingStrategy(
                                                placement_group=placement_group,
diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py
index 5313b46f..7797eb28 100644
--- a/tests/unit_test/backends/vllm/test_migration.py
+++ b/tests/unit_test/backends/vllm/test_migration.py
@@ -55,16 +55,19 @@ def init_llumlet(request_output_queue_type, instance_id, instance_args, engine_a
                       engine_args=engine_args)
     return llumlet
 
+
 class MockBackendVLLM(BackendVLLM):
     def __init__(self):
         self.engine = MockEngine()
 
+
 class MockLlumlet(Llumlet):
     def __init__(self):
         self.instance_id = "0"
         self.backend_engine = MockBackendVLLM()
 
-@ray.remote(num_cpus=1, max_concurrency=4)
+
+@ray.remote(num_cpus=1)
 class MockLlumletDoNotSchedule(Llumlet):
     def __init__(self, *args, **kwargs):
         instance_id = kwargs["instance_id"]
@@ -98,6 +101,7 @@ async def step_async_try_schedule():
 
         self.backend_engine.engine.step_async = step_async_try_schedule
 
+
 @pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl'])
 @pytest.mark.parametrize("migration_request_status", ['waiting', 'running'])
 @pytest.mark.asyncio
diff --git a/tests/unit_test/backends/vllm/test_migration_backend.py b/tests/unit_test/backends/vllm/test_migration_backend.py
index 2be5a006..3c0b1e22 100644
--- a/tests/unit_test/backends/vllm/test_migration_backend.py
+++ b/tests/unit_test/backends/vllm/test_migration_backend.py
@@ -26,9 +26,9 @@
 from tests.conftest import ray_env
 from .test_worker import create_worker
 
+
 class MockMigrationWorker(MigrationWorker):
     def set_gpu_cache(self, data):
-        print(f"data shape:::{self.gpu_cache[0][0].shape, data[0].shape}")
         for layer_idx in range(self.cache_engine[0].num_attention_layers):
             self.gpu_cache[0][layer_idx].copy_(data[layer_idx])
             torch.cuda.synchronize()
diff --git a/tests/unit_test/backends/vllm/test_simulator.py b/tests/unit_test/backends/vllm/test_simulator.py
index 4ddb8613..19afce1c 100644
--- a/tests/unit_test/backends/vllm/test_simulator.py
+++ b/tests/unit_test/backends/vllm/test_simulator.py
@@ -20,6 +20,7 @@
 
 from .utils import create_dummy_prompt, initialize_scheduler
 
+
 class MockBackendSim(BackendSimVLLM):
 
     def _get_lantecy_mem(self, *args, **kwargs):
@@ -28,6 +29,7 @@ def _get_lantecy_mem(self, *args, **kwargs):
         latency_mem.decode_model_params = (0,0,0)
         return latency_mem
 
+
 @pytest.mark.asyncio
 async def test_executor():
     engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
@@ -84,8 +86,7 @@ def __init__(self):
             pass
 
     dummy_actor = ray.remote(num_cpus=1, name="instance_0",
-                             namespace='llumnix',
-                             max_concurrency=4)(DummyActor)
+                             namespace='llumnix')(DummyActor)
     dummy_actor = dummy_actor.remote()
     placement_group = initialize_placement_group(get_placement_group_name("0"), num_cpus=2, num_gpus=0, detached=True)
     sim_backend = MockBackendSim(instance_id="0",
diff --git a/tests/unit_test/backends/vllm/test_worker.py b/tests/unit_test/backends/vllm/test_worker.py
index 9fb6c974..5fbb4fb1 100644
--- a/tests/unit_test/backends/vllm/test_worker.py
+++ b/tests/unit_test/backends/vllm/test_worker.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 import pytest
 import torch
 import ray
@@ -24,23 +25,43 @@
 from llumnix.arg_utils import InstanceArgs
 from llumnix.utils import random_uuid
 from llumnix.utils import initialize_placement_group, get_placement_group_name
+from llumnix.backends.vllm.worker import MigrationWorker
 
 # pylint: disable=unused-import
 from tests.conftest import ray_env
 
+
+class MockMigrationWorker(MigrationWorker):
+    def __init__(self, *args, **kwargs):
+        self.do_a_started = False
+        self.do_a_finished = False
+        super().__init__(*args, **kwargs)
+
+    def do_a(self):
+        self.do_a_started = True
+        self.do_a_finished = False
+        time.sleep(3.0)
+        self.do_a_finished = True
+
+    def do_b(self):
+        return self.do_a_started, self.do_a_finished
+
+
 def create_worker(rank: int, local_rank: int, engine_config: EngineConfig,
                   worker_module_name="llumnix.backends.vllm.worker",
-                  worker_class_name="MigrationWorker"):
+                  worker_class_name="MigrationWorker",
+                  max_concurrency=1):
     worker = ray.remote(
         num_cpus=0,
-        num_gpus=1
+        num_gpus=1,
+        max_concurrency=max_concurrency
     )(RayWorkerWrapper).remote(
         worker_module_name=worker_module_name,
         worker_class_name=worker_class_name,
         trust_remote_code=True
     )
 
-    worker.init_worker.remote(
+    ray.get(worker.init_worker.remote(
         model_config=engine_config.model_config,
         parallel_config=engine_config.parallel_config,
         scheduler_config=engine_config.scheduler_config,
@@ -52,7 +73,7 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig,
         distributed_init_method=get_distributed_init_method(get_ip(), get_open_port()),
         lora_config=engine_config.lora_config,
         is_driver_worker = False
-    )
+    ))
 
     return worker
 
@@ -126,3 +147,24 @@ def test_rebuild_migration_backend(ray_env, backend):
         assert ray.get(worker0.execute_method.remote('rebuild_migration_backend',
                                                      instance_rank=instance_rank, group_name=random_uuid()))
         assert ray.get(worker0.execute_method.remote('warmup'))
+
+def test_max_concurrency(ray_env):
+    engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
+    worker_no_concurrency = create_worker(rank=0, local_rank=0, engine_config=engine_config,
+                                          worker_module_name="tests.unit_test.backends.vllm.test_worker",
+                                          worker_class_name="MockMigrationWorker",
+                                          max_concurrency=1)
+
+    worker_no_concurrency.execute_method.remote('do_a')
+    do_a_started, do_a_finished = ray.get(worker_no_concurrency.execute_method.remote('do_b'))
+    assert do_a_started and do_a_finished
+
+    worker_with_concurrency = create_worker(rank=0, local_rank=0, engine_config=engine_config,
+                                            worker_module_name="tests.unit_test.backends.vllm.test_worker",
+                                            worker_class_name="MockMigrationWorker",
+                                            max_concurrency=2)
+
+    worker_with_concurrency.execute_method.remote('do_a')
+    time.sleep(1.0)
+    do_a_started, do_a_finished = ray.get(worker_with_concurrency.execute_method.remote('do_b'))
+    assert do_a_started and not do_a_finished
diff --git a/tests/unit_test/global_scheduler/test_manager.py b/tests/unit_test/global_scheduler/test_manager.py
index 8d32ddb0..3d8e06f9 100644
--- a/tests/unit_test/global_scheduler/test_manager.py
+++ b/tests/unit_test/global_scheduler/test_manager.py
@@ -41,7 +41,7 @@
 from tests.conftest import ray_env
 
 
-@ray.remote(num_cpus=1, max_concurrency=4)
+@ray.remote(num_cpus=1)
 class MockLlumlet:
     def __init__(self, instance_id):
         self.instance_id = instance_id
diff --git a/tests/unit_test/llumlet/test_engine_step_exception.py b/tests/unit_test/llumlet/test_engine_step_exception.py
index 92a85a1f..dd4b8423 100644
--- a/tests/unit_test/llumlet/test_engine_step_exception.py
+++ b/tests/unit_test/llumlet/test_engine_step_exception.py
@@ -28,7 +28,8 @@
 # pylint: disable=unused-import
 from tests.conftest import ray_env
 
-@ray.remote(num_cpus=1, max_concurrency=4)
+
+@ray.remote(num_cpus=1)
 class MockLlumlet(Llumlet):
     def __init__(self, *args, **kwargs) -> None:
         instance_id = kwargs["instance_id"]
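
The deadlock this patch guards against follows from Ray's default actor semantics: an actor executes one method at a time, so an actor that blocks inside a call while waiting on a peer that calls back into it can never make progress. Below is a minimal, self-contained sketch of that re-entrant pattern and how max_concurrency=2 resolves it; this is plain Ray, not llumnix code, and Waiter/Peer are invented names for illustration.

import ray

ray.init()

# With the default max_concurrency=1, Waiter.wait_for_peer() would occupy
# the actor's only execution thread while blocked in ray.get(), so the
# re-entrant status() call below would queue forever: a deadlock.
@ray.remote(max_concurrency=2)
class Waiter:
    def wait_for_peer(self, peer):
        # Blocks until the peer has called back into this actor.
        return ray.get(peer.call_back.remote())

    def status(self):
        return "alive"

@ray.remote
class Peer:
    def __init__(self, waiter):
        self.waiter = waiter

    def call_back(self):
        # Re-enters Waiter while wait_for_peer() is still running; this is
        # only served because Waiter has a second execution thread.
        return ray.get(self.waiter.status.remote())

waiter = Waiter.remote()
peer = Peer.remote(waiter)
print(ray.get(waiter.wait_for_peer.remote(peer)))  # -> "alive"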
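The new test_max_concurrency test asserts the same property at the worker level. A standalone sketch of the behavior it checks, stripped of the RayWorkerWrapper and vLLM machinery (Probe is an invented stand-in for MockMigrationWorker):

import time

import ray

ray.init()

# max_concurrency=2 turns Probe into a threaded actor: is_busy() is served
# on a second thread while long_call() is still sleeping. With the default
# of 1, is_busy() would queue behind long_call() and only observe
# busy == False after it had finished, so the final assert would fail.
@ray.remote(max_concurrency=2)
class Probe:
    def __init__(self):
        self.busy = False

    def long_call(self):
        # Stands in for a blocking operation such as a migration transfer.
        self.busy = True
        time.sleep(3.0)
        self.busy = False

    def is_busy(self):
        return self.busy

probe = Probe.remote()
probe.long_call.remote()                 # not awaited; occupies one thread
time.sleep(1.0)                          # let long_call() start
assert ray.get(probe.is_busy.remote())   # answered concurrently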