[BugFix] Change max_concurrency of llumlet, worker and proxy actor to avoid potential deadlock and performance degradation (#106)
s5u13b authored Feb 19, 2025
1 parent 5066d32 commit 8c882a7
Showing 9 changed files with 61 additions and 14 deletions.
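
The whole commit revolves around Ray's max_concurrency actor option. By default a Ray actor executes one method call at a time, so a method that synchronously waits on another call to the same actor can never make progress; setting max_concurrency above 1 makes the actor threaded, letting both calls run. A minimal standalone sketch of those semantics (illustrative only, not code from this commit):

import ray

ray.init()

@ray.remote(max_concurrency=2)  # threaded actor: two calls may run at once
class Worker:
    def ping(self):
        return "pong"

    def call_self(self, self_handle):
        # Waits on another task of this very actor. Under the default
        # max_concurrency=1, ping() could never start while call_self()
        # holds the actor's only execution slot: a deadlock.
        return ray.get(self_handle.ping.remote())

worker = Worker.remote()
assert ray.get(worker.call_self.remote(worker)) == "pong"
ray.shutdown()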
1 change: 1 addition & 0 deletions llumnix/backends/vllm/executor.py
@@ -84,6 +84,7 @@ def _init_workers_ray(self, placement_group: PlacementGroup,
                 num_cpus=0,
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
+                max_concurrency=2,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
 
4 changes: 2 additions & 2 deletions llumnix/backends/vllm/migration_backend.py
@@ -24,7 +24,7 @@
 
 logger = init_logger(__name__)
 
-@ray.remote(num_cpus=1)
+@ray.remote(num_cpus=1, max_concurrency=2)
 class ProxyActor:
     def exec_method(self, is_driver_worker, handle, *args, **kwargs):
         try:
@@ -83,7 +83,7 @@ def destory_backend(self) -> None:
         pass
 
     def warmup(self) -> bool:
-        self.actor.exec_method.remote(self.is_driver_worker, "do_send", [0])
+        self.actor.exec_method.remote(self.is_driver_worker, self.worker_handle_list[self.worker_rank], "do_send", None, [0])
         logger.info("Rayrpc migration backend warmup successfully.")
         return True
 
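The warmup change above is a signature fix as much as a concurrency fix: exec_method takes (is_driver_worker, handle, *args, **kwargs), so the old call appears to have bound the method name "do_send" to the handle parameter. A reduced, runnable sketch of that positional shift (hypothetical stand-in, a plain function instead of a Ray actor):

def exec_method(is_driver_worker, handle, *args):
    # Stand-in for ProxyActor.exec_method: report how arguments bind.
    return {"is_driver_worker": is_driver_worker, "handle": handle, "args": args}

# Old call shape: the method name slides into the handle slot.
print(exec_method(False, "do_send", [0]))
# {'is_driver_worker': False, 'handle': 'do_send', 'args': ([0],)}

# New call shape: worker handle first, then the method name and its arguments.
print(exec_method(False, "<worker handle>", "do_send", None, [0]))
# {'is_driver_worker': False, 'handle': '<worker handle>', 'args': ('do_send', None, [0])}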
2 changes: 0 additions & 2 deletions llumnix/llumlet/llumlet.py
@@ -101,12 +101,10 @@ def from_args(cls,
         if backend_type == backend_type.BLADELLM:
             world_size = get_engine_world_size(engine_args, backend_type)
             num_gpus = world_size
-        # TODO(s5u13b): Check the max_concurrency.
         llumlet_class = ray.remote(num_cpus=1,
                                    num_gpus=num_gpus,
                                    name=get_instance_name(instance_id),
                                    namespace='llumnix',
-                                   max_concurrency=4,
                                    lifetime="detached")(cls).options(
                                        scheduling_strategy=PlacementGroupSchedulingStrategy(
                                            placement_group=placement_group,
6 changes: 5 additions & 1 deletion tests/unit_test/backends/vllm/test_migration.py
@@ -55,16 +55,19 @@ def init_llumlet(request_output_queue_type, instance_id, instance_args, engine_a
                               engine_args=engine_args)
     return llumlet
 
+
 class MockBackendVLLM(BackendVLLM):
     def __init__(self):
         self.engine = MockEngine()
 
+
 class MockLlumlet(Llumlet):
     def __init__(self):
         self.instance_id = "0"
         self.backend_engine = MockBackendVLLM()
 
-@ray.remote(num_cpus=1, max_concurrency=4)
+
+@ray.remote(num_cpus=1)
 class MockLlumletDoNotSchedule(Llumlet):
     def __init__(self, *args, **kwargs):
         instance_id = kwargs["instance_id"]
@@ -98,6 +101,7 @@ async def step_async_try_schedule():
 
         self.backend_engine.engine.step_async = step_async_try_schedule
 
+
 @pytest.mark.parametrize("migration_backend", ['rayrpc', 'gloo', 'nccl'])
 @pytest.mark.parametrize("migration_request_status", ['waiting', 'running'])
 @pytest.mark.asyncio
2 changes: 1 addition & 1 deletion tests/unit_test/backends/vllm/test_migration_backend.py
@@ -26,9 +26,9 @@
 from tests.conftest import ray_env
 from .test_worker import create_worker
 
+
 class MockMigrationWorker(MigrationWorker):
     def set_gpu_cache(self, data):
-        print(f"data shape:::{self.gpu_cache[0][0].shape, data[0].shape}")
         for layer_idx in range(self.cache_engine[0].num_attention_layers):
             self.gpu_cache[0][layer_idx].copy_(data[layer_idx])
         torch.cuda.synchronize()
5 changes: 3 additions & 2 deletions tests/unit_test/backends/vllm/test_simulator.py
@@ -20,6 +20,7 @@
 
 from .utils import create_dummy_prompt, initialize_scheduler
 
+
 class MockBackendSim(BackendSimVLLM):
 
     def _get_lantecy_mem(self, *args, **kwargs):
@@ -28,6 +29,7 @@ def _get_lantecy_mem(self, *args, **kwargs):
         latency_mem.decode_model_params = (0,0,0)
         return latency_mem
 
+
 @pytest.mark.asyncio
 async def test_executor():
     engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
@@ -84,8 +86,7 @@ def __init__(self):
             pass
     dummy_actor = ray.remote(num_cpus=1,
                              name="instance_0",
-                             namespace='llumnix',
-                             max_concurrency=4)(DummyActor)
+                             namespace='llumnix')(DummyActor)
     dummy_actor = dummy_actor.remote()
     placement_group = initialize_placement_group(get_placement_group_name("0"), num_cpus=2, num_gpus=0, detached=True)
     sim_backend = MockBackendSim(instance_id="0",
50 changes: 46 additions & 4 deletions tests/unit_test/backends/vllm/test_worker.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 import pytest
 import torch
 import ray
@@ -24,23 +25,43 @@
 from llumnix.arg_utils import InstanceArgs
 from llumnix.utils import random_uuid
 from llumnix.utils import initialize_placement_group, get_placement_group_name
+from llumnix.backends.vllm.worker import MigrationWorker
 
 # pylint: disable=unused-import
 from tests.conftest import ray_env
 
+
+class MockMigrationWorker(MigrationWorker):
+    def __init__(self, *args, **kwargs):
+        self.do_a_started = False
+        self.do_a_finished = False
+        super().__init__(*args, **kwargs)
+
+    def do_a(self):
+        self.do_a_started = True
+        self.do_a_finished = False
+        time.sleep(3.0)
+        self.do_a_finished = True
+
+    def do_b(self):
+        return self.do_a_started, self.do_a_finished
+
+
 def create_worker(rank: int, local_rank: int, engine_config: EngineConfig,
                   worker_module_name="llumnix.backends.vllm.worker",
-                  worker_class_name="MigrationWorker"):
+                  worker_class_name="MigrationWorker",
+                  max_concurrency=1):
     worker = ray.remote(
         num_cpus=0,
-        num_gpus=1
+        num_gpus=1,
+        max_concurrency=max_concurrency
     )(RayWorkerWrapper).remote(
         worker_module_name=worker_module_name,
        worker_class_name=worker_class_name,
         trust_remote_code=True
     )
 
-    worker.init_worker.remote(
+    ray.get(worker.init_worker.remote(
         model_config=engine_config.model_config,
         parallel_config=engine_config.parallel_config,
         scheduler_config=engine_config.scheduler_config,
@@ -52,7 +73,7 @@ def create_worker(rank: int, local_rank: int, engine_config: EngineConfig,
         distributed_init_method=get_distributed_init_method(get_ip(), get_open_port()),
         lora_config=engine_config.lora_config,
         is_driver_worker = False
-    )
+    ))
 
     return worker
 
@@ -126,3 +147,24 @@ def test_rebuild_migration_backend(ray_env, backend):
     assert ray.get(worker0.execute_method.remote('rebuild_migration_backend', instance_rank=instance_rank,
                                                  group_name=random_uuid()))
     assert ray.get(worker0.execute_method.remote('warmup'))
+
+def test_max_concurrency(ray_env):
+    engine_config = EngineArgs(model='facebook/opt-125m', max_model_len=8, enforce_eager=True).create_engine_config()
+    worker_no_concurrency = create_worker(rank=0, local_rank=0, engine_config=engine_config,
+                                          worker_module_name="tests.unit_test.backends.vllm.test_worker",
+                                          worker_class_name="MockMigrationWorker",
+                                          max_concurrency=1)
+
+    worker_no_concurrency.execute_method.remote('do_a')
+    do_a_started, do_a_finished = ray.get(worker_no_concurrency.execute_method.remote('do_b'))
+    assert do_a_started and do_a_finished
+
+    worker_with_concurrency = create_worker(rank=0, local_rank=0, engine_config=engine_config,
+                                            worker_module_name="tests.unit_test.backends.vllm.test_worker",
+                                            worker_class_name="MockMigrationWorker",
+                                            max_concurrency=2)
+
+    worker_with_concurrency.execute_method.remote('do_a')
+    time.sleep(1.0)
+    do_a_started, do_a_finished = ray.get(worker_with_concurrency.execute_method.remote('do_b'))
+    assert do_a_started and not do_a_finished
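
The new test_max_concurrency pins down the queueing behavior the fix depends on: with max_concurrency=1, do_b runs only after do_a's three-second sleep finishes, so it observes do_a_finished == True; with max_concurrency=2, do_b runs alongside do_a and sees it still in flight. The same guarantee reduced to a plain Ray actor (a hypothetical Serial class, not from the repository):

import time

import ray

ray.init()

@ray.remote  # default max_concurrency=1: calls execute strictly one at a time
class Serial:
    def slow(self):
        time.sleep(3.0)

    def fast(self):
        return "ran after slow()"

actor = Serial.remote()
actor.slow.remote()               # occupies the actor's single slot for ~3 s
start = time.time()
ray.get(actor.fast.remote())      # queued behind slow(): returns only after it
assert time.time() - start > 2.5
ray.shutdown()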
2 changes: 1 addition & 1 deletion tests/unit_test/global_scheduler/test_manager.py
@@ -41,7 +41,7 @@
 from tests.conftest import ray_env
 
 
-@ray.remote(num_cpus=1, max_concurrency=4)
+@ray.remote(num_cpus=1)
 class MockLlumlet:
     def __init__(self, instance_id):
         self.instance_id = instance_id
3 changes: 2 additions & 1 deletion tests/unit_test/llumlet/test_engine_step_exception.py
@@ -28,7 +28,8 @@
 # pylint: disable=unused-import
 from tests.conftest import ray_env
 
-@ray.remote(num_cpus=1, max_concurrency=4)
+
+@ray.remote(num_cpus=1)
 class MockLlumlet(Llumlet):
     def __init__(self, *args, **kwargs) -> None:
         instance_id = kwargs["instance_id"]
