From d2def025c50aab8ec95f7ff634246012e66c0c25 Mon Sep 17 00:00:00 2001 From: Xinyi-ECNU <1668529909@qq.com> Date: Mon, 14 Oct 2024 11:30:28 +0800 Subject: [PATCH] rebase main --- llumnix/__init__.py | 2 +- .../unit_test/backends/vllm/test_migration.py | 27 ++++++++++--------- .../test_dispatch_scheduler.py | 18 ++++--------- .../global_scheduler/test_global_scheduler.py | 3 ++- .../test_migration_scheduler.py | 8 +++--- 5 files changed, 27 insertions(+), 31 deletions(-) diff --git a/llumnix/__init__.py b/llumnix/__init__.py index 9f321fc4..71de719c 100644 --- a/llumnix/__init__.py +++ b/llumnix/__init__.py @@ -37,4 +37,4 @@ "QueueType", ] -__all__.extend(getattr(vllm, "__all__", [])) \ No newline at end of file +__all__.extend(getattr(vllm, "__all__", [])) diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index eb42ee2d..2a8ad19e 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -141,19 +141,19 @@ async def test_correctness(prompt): await test_correctness(prompt) que.cleanup() -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) @pytest.mark.asyncio async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): engine_args = EngineArgs(model="facebook/opt-125m",worker_use_ray=True) id_rank_map = {"0":0,"1":1} migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20) - server_info = init_server_info() - que = init_request_output_queue(server_info) + + output_queue_type = QueueType.RAYQUEUE + que, server_info = request_output_queue_server(output_queue_type) asyncio.create_task(que.run_server_loop()) llumlet_0:Llumlet = Llumlet.from_args( + output_queue_type, False, True, ray.get_runtime_context().get_node_id(), @@ -164,6 +164,7 @@ async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend): engine_args,) llumlet_1:Llumlet = Llumlet.from_args( + output_queue_type, False, True, ray.get_runtime_context().get_node_id(), @@ -193,9 +194,10 @@ async def test_correctness(prompt): origin_output = None finished = False while not finished: - request_output = await request_output_queue.get() - origin_output = request_output.outputs[0] - finished = request_output.finished + request_outputs = await request_output_queue.get() + for request_output in request_outputs: + origin_output = request_output.outputs[0] + finished = request_output.finished request_id1 = random_uuid() request_expected_steps_id1 = 1 @@ -209,11 +211,12 @@ async def test_correctness(prompt): output = None finished = False while not finished: - request_output = await request_output_queue.get() - origin_output = request_output.outputs[0] - finished = request_output.finished - if request_output.request_id != request_id1: - continue + request_outputs = await request_output_queue.get() + for request_output in request_outputs: + origin_output = request_output.outputs[0] + finished = request_output.finished + if request_output.request_id != request_id1: + continue output = request_output.outputs[0] finished = request_output.finished diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index 84676474..8cee3a69 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -21,7 +21,7 @@ def init_dispatch_scheduler(policy='load'): instance_load_calculator = InstanceLoadCalculator('remaining_steps', True) - dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, random.randint(-1,4)) + dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, random.randint(1,4)) return dispatch_scheduler @pytest.fixture @@ -42,17 +42,11 @@ def test_add_instance_and_remove_instance(dispatch_scheduler): assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 dispatch_scheduler.add_instance('instance_3') assert dispatch_scheduler.num_instances == 2 - if dispatch_scheduler.num_dispatch_instances <= 0: - assert len(dispatch_scheduler.available_dispatch_instance_set) == 2 - else: - assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances) + assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances) dispatch_scheduler.remove_instance('instance_2') assert dispatch_scheduler.num_instances == 1 - if dispatch_scheduler.num_dispatch_instances <= 0: - assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 - else: - assert len(dispatch_scheduler.available_dispatch_instance_set) == min(1, dispatch_scheduler.num_dispatch_instances-1) + assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 dispatch_scheduler.remove_instance('instance_3') assert dispatch_scheduler.num_instances == 0 @@ -62,8 +56,7 @@ def test_dispatch_balanced(): dispatch_scheduler = init_dispatch_scheduler('balanced') instance_num_requests = {} for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]: - if dispatch_scheduler.num_dispatch_instances <= 0 or (dispatch_scheduler.num_dispatch_instances > 0 - and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances): + if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: dispatch_scheduler.available_dispatch_instance_set.add(instance_id) instance_num_requests[instance_id] = random.randint(1, 10) dispatch_scheduler.instance_num_requests = instance_num_requests @@ -106,8 +99,7 @@ def test_dispatch_queue(): instance_info.instance_id = instance_id instance_info.num_waiting_requests = random.randint(1, 10) instance_info_dict[instance_id] = instance_info - if dispatch_scheduler.num_dispatch_instances <= 0 or (dispatch_scheduler.num_dispatch_instances > 0 - and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances): + if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances: dispatch_scheduler.available_dispatch_instance_set.add(instance_id) instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py index 9f2e23c8..adb1f1cc 100644 --- a/tests/unit_test/global_scheduler/test_global_scheduler.py +++ b/tests/unit_test/global_scheduler/test_global_scheduler.py @@ -24,7 +24,8 @@ def init_global_scheduler(): global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', math.inf, - 'defrag_constrained', 3.0, True, 'avg_load', 10, 60, False) + 'defrag_constrained', 3.0, True, 'avg_load', + 10, 60, False) global_scheduler = GlobalScheduler(global_scheduler_config) return global_scheduler diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py b/tests/unit_test/global_scheduler/test_migration_scheduler.py index ed3b706d..8fd32105 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -56,16 +56,16 @@ def test_get_migration_instance_infos(pair_migration_type): if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: constraint_prefill_instance_num = math.inf else: - constraint_prefill_instance_num = random.randint(-1, INSTANCE_NUM) + constraint_prefill_instance_num = random.randint(1, INSTANCE_NUM) migration_scheduler = init_migration_scheduler() - if constraint_prefill_instance_num > 0: + if constraint_prefill_instance_num == math.inf: + instance_info.instance_type = InstanceType.NO_CONSTRAINTS + else: if len([info for info in instance_info_dict.values() if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num: instance_info.instance_type = InstanceType.PREFILL else: instance_info.instance_type = InstanceType.DECODE - else: - instance_info.instance_type = InstanceType.NO_CONSTRAINTS instance_info_dict[instance_id] = instance_info migration_scheduler.instance_info = instance_info_dict migration_scheduler._sort_instance_infos(descending=False)