Skip to content

Commit

Permalink
rebase main
Browse files: browse the repository at this point in the history
  • Loading branch information
Xinyi-ECNU committed Oct 14, 2024
1 parent dea6812 commit d2def02
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 31 deletions.
2 changes: 1 addition & 1 deletion llumnix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@
"QueueType",
]

__all__.extend(getattr(vllm, "__all__", []))
__all__.extend(getattr(vllm, "__all__", []))
27 changes: 15 additions & 12 deletions tests/unit_test/backends/vllm/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,19 @@ async def test_correctness(prompt):
await test_correctness(prompt)
que.cleanup()

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
@pytest.mark.asyncio
async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend):
engine_args = EngineArgs(model="facebook/opt-125m",worker_use_ray=True)
id_rank_map = {"0":0,"1":1}
migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20)
server_info = init_server_info()
que = init_request_output_queue(server_info)

output_queue_type = QueueType.RAYQUEUE
que, server_info = request_output_queue_server(output_queue_type)
asyncio.create_task(que.run_server_loop())

llumlet_0:Llumlet = Llumlet.from_args(
output_queue_type,
False,
True,
ray.get_runtime_context().get_node_id(),
Expand All @@ -164,6 +164,7 @@ async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend):
engine_args,)

llumlet_1:Llumlet = Llumlet.from_args(
output_queue_type,
False,
True,
ray.get_runtime_context().get_node_id(),
Expand Down Expand Up @@ -193,9 +194,10 @@ async def test_correctness(prompt):
origin_output = None
finished = False
while not finished:
request_output = await request_output_queue.get()
origin_output = request_output.outputs[0]
finished = request_output.finished
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
origin_output = request_output.outputs[0]
finished = request_output.finished

request_id1 = random_uuid()
request_expected_steps_id1 = 1
Expand All @@ -209,11 +211,12 @@ async def test_correctness(prompt):
output = None
finished = False
while not finished:
request_output = await request_output_queue.get()
origin_output = request_output.outputs[0]
finished = request_output.finished
if request_output.request_id != request_id1:
continue
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
origin_output = request_output.outputs[0]
finished = request_output.finished
if request_output.request_id != request_id1:
continue
output = request_output.outputs[0]
finished = request_output.finished

Expand Down
18 changes: 5 additions & 13 deletions tests/unit_test/global_scheduler/test_dispatch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def init_dispatch_scheduler(policy='load'):
instance_load_calculator = InstanceLoadCalculator('remaining_steps', True)
dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, random.randint(-1,4))
dispatch_scheduler = DispatchScheduler(policy, instance_load_calculator, random.randint(1,4))
return dispatch_scheduler

@pytest.fixture
Expand All @@ -42,17 +42,11 @@ def test_add_instance_and_remove_instance(dispatch_scheduler):
assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
dispatch_scheduler.add_instance('instance_3')
assert dispatch_scheduler.num_instances == 2
if dispatch_scheduler.num_dispatch_instances <= 0:
assert len(dispatch_scheduler.available_dispatch_instance_set) == 2
else:
assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances)
assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances)

dispatch_scheduler.remove_instance('instance_2')
assert dispatch_scheduler.num_instances == 1
if dispatch_scheduler.num_dispatch_instances <= 0:
assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
else:
assert len(dispatch_scheduler.available_dispatch_instance_set) == min(1, dispatch_scheduler.num_dispatch_instances-1)
assert len(dispatch_scheduler.available_dispatch_instance_set) == 1
dispatch_scheduler.remove_instance('instance_3')
assert dispatch_scheduler.num_instances == 0

Expand All @@ -62,8 +56,7 @@ def test_dispatch_balanced():
dispatch_scheduler = init_dispatch_scheduler('balanced')
instance_num_requests = {}
for instance_id in [f'instance_{i}' for i in range(1, INSTANCE_NUM + 1)]:
if dispatch_scheduler.num_dispatch_instances <= 0 or (dispatch_scheduler.num_dispatch_instances > 0
and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances):
if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances:
dispatch_scheduler.available_dispatch_instance_set.add(instance_id)
instance_num_requests[instance_id] = random.randint(1, 10)
dispatch_scheduler.instance_num_requests = instance_num_requests
Expand Down Expand Up @@ -106,8 +99,7 @@ def test_dispatch_queue():
instance_info.instance_id = instance_id
instance_info.num_waiting_requests = random.randint(1, 10)
instance_info_dict[instance_id] = instance_info
if dispatch_scheduler.num_dispatch_instances <= 0 or (dispatch_scheduler.num_dispatch_instances > 0
and len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances):
if len(dispatch_scheduler.available_dispatch_instance_set) < dispatch_scheduler.num_dispatch_instances:
dispatch_scheduler.available_dispatch_instance_set.add(instance_id)
instance_num_requests[instance_id] = 0
dispatch_scheduler.instance_num_requests = instance_num_requests
Expand Down
3 changes: 2 additions & 1 deletion tests/unit_test/global_scheduler/test_global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

def init_global_scheduler():
global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', math.inf,
'defrag_constrained', 3.0, True, 'avg_load', 10, 60, False)
'defrag_constrained', 3.0, True, 'avg_load',
10, 60, False)
global_scheduler = GlobalScheduler(global_scheduler_config)
return global_scheduler

Expand Down
8 changes: 4 additions & 4 deletions tests/unit_test/global_scheduler/test_migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ def test_get_migration_instance_infos(pair_migration_type):
if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS:
constraint_prefill_instance_num = math.inf
else:
constraint_prefill_instance_num = random.randint(-1, INSTANCE_NUM)
constraint_prefill_instance_num = random.randint(1, INSTANCE_NUM)
migration_scheduler = init_migration_scheduler()
if constraint_prefill_instance_num > 0:
if constraint_prefill_instance_num == math.inf:
instance_info.instance_type = InstanceType.NO_CONSTRAINTS
else:
if len([info for info in instance_info_dict.values()
if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num:
instance_info.instance_type = InstanceType.PREFILL
else:
instance_info.instance_type = InstanceType.DECODE
else:
instance_info.instance_type = InstanceType.NO_CONSTRAINTS
instance_info_dict[instance_id] = instance_info
migration_scheduler.instance_info = instance_info_dict
migration_scheduler._sort_instance_infos(descending=False)
Expand Down

0 comments on commit d2def02

Please sign in to comment.