diff --git a/llumnix/arg_utils.py b/llumnix/arg_utils.py index 5e6d71ea..404b3477 100644 --- a/llumnix/arg_utils.py +++ b/llumnix/arg_utils.py @@ -61,9 +61,14 @@ class EngineManagerArgs: last_stage_max_blocks: int = None max_stages: int = None - enable_pd_disagg: bool = False + enable_pd_disagg: bool = None def __post_init__(self): + # Check if all fields default to None + for field_info in dataclasses.fields(self): + if field_info.default is not None: + raise ValueError(f"The default value of '{field_info.name}' should be None") + for attr in dataclasses.fields(self): if getattr(self, attr.name) is None: setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper())) @@ -143,7 +148,6 @@ def add_cli_args( help='request dispatch policy') parser.add_argument('--num-available-dispatch-instances', type=int, - default=None, help='number of available instances for dispatching') parser.add_argument('--enable-migration', @@ -224,6 +228,5 @@ def add_cli_args( help='drop migration if the number of stages > max_stages') parser.add_argument('--enable-pd-disagg', type=bool, - default=None, help='enable prefill decoding disaggregation') return parser diff --git a/llumnix/backends/vllm/scheduler.py b/llumnix/backends/vllm/scheduler.py index 8feac4ab..7e9064d7 100644 --- a/llumnix/backends/vllm/scheduler.py +++ b/llumnix/backends/vllm/scheduler.py @@ -209,16 +209,16 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: return seq_group_metadata_list, scheduler_outputs def _schedule_running(self, running_queue: deque, *args, **kwargs): - filtered_running_queue = [] - remove_running = [] - for seq_group in list(running_queue): + filtered_running_queue = deque() + remove_running = deque() + for seq_group in running_queue: if seq_group.output_len >= seq_group.expected_steps: - remove_running.append(seq_group) + remove_running.extend([seq_group]) else: - filtered_running_queue.append(seq_group) + filtered_running_queue.extend([seq_group]) remaining_running, 
running_scheduled = super()._schedule_running(filtered_running_queue, *args, **kwargs) for seq_group in remove_running: - remaining_running.append(seq_group) + remaining_running.extend([seq_group]) return remaining_running, running_scheduled def add_seq_group(self, *args, **kwargs): diff --git a/llumnix/config/default.py b/llumnix/config/default.py index 7db5a003..3ff053cc 100644 --- a/llumnix/config/default.py +++ b/llumnix/config/default.py @@ -11,6 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math + from .config import LlumnixConfig as LC # ----------------------------------------------------------------------------- @@ -78,7 +80,7 @@ # Request dispatch policy _C.MANAGER.DISPATCH_POLICY = 'load' # Number of available dispatch instances. -1 indicates that all instances can be used for dispatching -_C.MANAGER.NUM_DISPATCH_INSTANCES = -1 +_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf # ----------------------------------------------------------------------------- # MIGRATION CONFIGURATION diff --git a/llumnix/global_scheduler/dispatch_scheduler.py b/llumnix/global_scheduler/dispatch_scheduler.py index dcf376de..175bdbde 100644 --- a/llumnix/global_scheduler/dispatch_scheduler.py +++ b/llumnix/global_scheduler/dispatch_scheduler.py @@ -69,6 +69,8 @@ def remove_instance(self, instance_id: str) -> None: self.num_instances = len(self.instance_id_set) if instance_id in self.instance_num_requests: del self.instance_num_requests[instance_id] + if instance_id in self.available_dispatch_instance_set: + self.available_dispatch_instance_set.remove(instance_id) def _sort_instance_infos(self, descending: bool = True) -> None: diff --git a/llumnix/global_scheduler/global_scheduler.py b/llumnix/global_scheduler/global_scheduler.py index c69cdece..79d6e88e 100644 --- a/llumnix/global_scheduler/global_scheduler.py +++ b/llumnix/global_scheduler/global_scheduler.py @@ -18,7 +18,7 @@ from llumnix.internal_config 
import GlobalSchedulerConfig from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler -from llumnix.global_scheduler.migration_scheduler import MigrationScheduler +from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler logger = init_logger(__name__) @@ -69,7 +69,7 @@ def dispatch(self) -> str: request_expected_steps = 1 if self.enable_pd_disagg else math.inf return instance_id, request_expected_steps - def pair_migration(self, pair_migration_type:str) -> List[Tuple[str, str]]: + def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]: self.migration_scheduler.update_instance_infos(self.instance_info) migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type) return migrate_instance_pairs diff --git a/llumnix/global_scheduler/migration_scheduler.py b/llumnix/global_scheduler/migration_scheduler.py index 34071668..3445b210 100644 --- a/llumnix/global_scheduler/migration_scheduler.py +++ b/llumnix/global_scheduler/migration_scheduler.py @@ -56,12 +56,12 @@ def __init__(self, self.instance_info: Dict[str, InstanceInfo] = None self.sorted_instance_infos: List[InstanceInfo] = None - def pair_migration(self, pair_migration_type:str) -> List[Tuple[str, str]]: + def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]: self._sort_instance_infos(descending=False) sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type) return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) - def _get_migration_instance_infos(self, pair_migration_type:str) -> Dict[str, InstanceInfo]: + def _get_migration_instance_infos(self, pair_migration_type: PairMigrationConstraints) -> Dict[str, 
InstanceInfo]: filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(pair_migration_type, migrate_out_load_threshold=self.migrate_out_load_threshold) return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos,pair_migration_type) @@ -98,7 +98,8 @@ def __init__(self, PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE), } - def filter_instances(self, sorted_instance_infos: List[InstanceInfo], pair_migration_type: str = None) -> Dict[str, InstanceInfo]: + def filter_instances(self, sorted_instance_infos: List[InstanceInfo], + pair_migration_type: PairMigrationConstraints = None) -> Dict[str, InstanceInfo]: src_type, dst_type = self.filter_instances_rules[pair_migration_type] filtered_src_instance_infos = [info for info in sorted_instance_infos if info.instance_type == src_type] filtered_dst_instance_infos = [info for info in sorted_instance_infos if info.instance_type == dst_type] diff --git a/llumnix/global_scheduler/scaling_scheduler.py b/llumnix/global_scheduler/scaling_scheduler.py index bd15f99c..edcc9627 100644 --- a/llumnix/global_scheduler/scaling_scheduler.py +++ b/llumnix/global_scheduler/scaling_scheduler.py @@ -14,6 +14,7 @@ from typing import Dict, List, Tuple, Set from abc import ABC, abstractmethod from enum import Enum +import math import numpy as np from llumnix.logger import init_logger @@ -77,18 +78,22 @@ def add_instance(self, instance_id: str) -> None: self.instance_id_set.add(instance_id) self.num_instances = len(self.instance_id_set) instance_type = None - if self.maximum_prefill_instance_num > 0: + if self.maximum_prefill_instance_num == math.inf: + instance_type = InstanceType.NO_CONSTRAINTS + else: if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num: instance_type = InstanceType.PREFILL else: instance_type = InstanceType.DECODE - else: - instance_type = InstanceType.NO_CONSTRAINTS 
self.instance_type_id_set[instance_type].add(instance_id) return instance_type def remove_instance(self, instance_id: str) -> None: self.instance_id_set.remove(instance_id) + for instance_type in InstanceType: + if instance_id in self.instance_type_id_set[instance_type]: + self.instance_type_id_set[instance_type].remove(instance_id) + break self.num_instances = len(self.instance_id_set) def get_empty_instance_info(self) -> InstanceInfo: diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index fa04bd43..35a6a574 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -15,6 +15,7 @@ import time import csv import os +import math from typing import Dict, List, Tuple, Union, Iterable from collections import defaultdict import traceback @@ -217,12 +218,12 @@ async def _clear_request_instance_loop(self, interval: float): async def _push_migrations(self) -> None: # Push migrate when the instance_info have updated a certain number of times. if self.enable_pd_disagg: - asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING, -1)) + asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING, math.inf)) asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING, 1)) else: asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS, 1)) - async def _migrate(self, pair_migration_type:str, migrate_in_num_requests:int) -> None: + async def _migrate(self, pair_migration_type: PairMigrationConstraints, migrate_in_num_requests: int) -> None: async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> None: if migrate_instance_pair[0] in self.instance_migrating: self.instance_migrating[migrate_instance_pair[0]] = False diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 657c77ef..5d220676 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -141,7 +141,7 @@ def migrate_out(self, dst_instance_name: 
str, num_requests: int) -> List[str]: dst_instance_id = dst_instance_name[len("instance_"):] migrated_request_list = [] continue_migrate = True - while continue_migrate and (len(migrated_request_list) < num_requests or num_requests == -1): + while continue_migrate and len(migrated_request_list) < num_requests: t0 = time.time() migrate_out_request = self.migration_scheduler.get_migrate_out_request() if migrate_out_request is not None: @@ -157,6 +157,7 @@ def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]: migrate_out_request.stage_timestamps.append(time.time()) self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) else: + migrate_out_request.reset_migration_args() ray.get(migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id)) continue_migrate = False t1 = time.time() diff --git a/tests/unit_test/backends/vllm/test_llm_engine.py b/tests/unit_test/backends/vllm/test_llm_engine.py index cfcd2f1f..a6c6e3a1 100644 --- a/tests/unit_test/backends/vllm/test_llm_engine.py +++ b/tests/unit_test/backends/vllm/test_llm_engine.py @@ -114,4 +114,4 @@ def test_llm_engine_add_requset(): assert len(llm_engine.scheduler.waiting) == 1 assert llm_engine.scheduler.waiting[-1].request_id == "0" assert llm_engine.scheduler.waiting[-1].expected_steps == math.inf - assert isinstance(llm_engine.scheduler.waiting[-1], LlumnixRequest) \ No newline at end of file + assert isinstance(llm_engine.scheduler.waiting[-1], LlumnixRequest) diff --git a/tests/unit_test/backends/vllm/test_migration.py b/tests/unit_test/backends/vllm/test_migration.py index 8e2bc371..eb42ee2d 100644 --- a/tests/unit_test/backends/vllm/test_migration.py +++ b/tests/unit_test/backends/vllm/test_migration.py @@ -13,11 +13,12 @@ from typing import List import asyncio +import math import pytest import ray from vllm import EngineArgs, SamplingParams -from llumnix.utils import random_uuid +from vllm.utils import 
random_uuid from llumnix.backends.vllm.llm_engine import BackendVLLM from llumnix.llumlet.llumlet import Llumlet @@ -92,14 +93,14 @@ async def test_migration_correctness(setup_ray_env, migration_backend): llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")]) # empty instance migrate out - res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) assert not res # running without migration async def test_correctness(prompt): sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100) request_id0 = random_uuid() - llumlet_0.generate.remote(request_id0, server_info, prompt, sampling_params) + llumlet_0.generate.remote(request_id0, server_info, math.inf, prompt, sampling_params) request_output_queue = que origin_output = None finished = False @@ -110,14 +111,14 @@ async def test_correctness(prompt): finished = request_output.finished request_id1 = random_uuid() - ray.get(llumlet_0.generate.remote(request_id1, server_info, prompt, sampling_params)) + ray.get(llumlet_0.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params)) # wait prefill done while True: running_queue: List[LlumnixRequest] = ray.get(llumlet_0.execute_engine_method.remote("get_running_queue")) if len(running_queue) > 0 and running_queue[0].inference_type == RequestInferenceType.DECODE: break # migrate request - res = ray.get(llumlet_0.migrate_out.remote("instance_1")) + res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) assert len(res) == 1 request_output_queue = que @@ -140,6 +141,88 @@ async def test_correctness(prompt): await test_correctness(prompt) que.cleanup() +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl']) +@pytest.mark.asyncio +async def 
test_pd_disaggregation_correctness(setup_ray_env, migration_backend): + engine_args = EngineArgs(model="facebook/opt-125m",worker_use_ray=True) + id_rank_map = {"0":0,"1":1} + migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20) + server_info = init_server_info() + que = init_request_output_queue(server_info) + asyncio.create_task(que.run_server_loop()) + + llumlet_0:Llumlet = Llumlet.from_args( + False, + True, + ray.get_runtime_context().get_node_id(), + "0", + BackendType.VLLM, + 1, + migration_config, + engine_args,) + + llumlet_1:Llumlet = Llumlet.from_args( + False, + True, + ray.get_runtime_context().get_node_id(), + "1", + BackendType.VLLM, + 1, + migration_config, + engine_args, + ) + while True: + res = ray.get([llumlet_0.is_ready.remote(),llumlet_1.is_ready.remote()]) + if all(res): + break + ray.get([llumlet_0.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix"), + llumlet_1.execute_engine_method.remote("_run_workers","rebuild_migration_backend", id_rank_map, "llumnix")]) + # empty instance migrate out + res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf)) + assert not res + + # running without migration + async def test_correctness(prompt): + sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100) + request_id0 = random_uuid() + request_expected_steps_id0 = math.inf + llumlet_0.generate.remote(request_id0, server_info, request_expected_steps_id0, prompt, sampling_params) + request_output_queue = que + origin_output = None + finished = False + while not finished: + request_output = await request_output_queue.get() + origin_output = request_output.outputs[0] + finished = request_output.finished + + request_id1 = random_uuid() + request_expected_steps_id1 = 1 + ray.get(llumlet_0.generate.remote(request_id1, server_info, request_expected_steps_id1, prompt, sampling_params)) + # migrate request for decoding + while True: + res =
ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests = math.inf)) + if len(res) == 1: + break + request_output_queue = que + output = None + finished = False + while not finished: + request_output = await request_output_queue.get() + origin_output = request_output.outputs[0] + finished = request_output.finished + if request_output.request_id != request_id1: + continue + output = request_output.outputs[0] + finished = request_output.finished + + assert output.text == origin_output.text + assert output.cumulative_logprob == origin_output.cumulative_logprob + for prompt in TEST_PROMPTS: + await test_correctness(prompt) + que.cleanup() + def test_clear_migration_states(): llumlet = MockLlumlet() llumlet.backend_engine.pre_alloc("0", 1) diff --git a/tests/unit_test/backends/vllm/test_scheduler.py b/tests/unit_test/backends/vllm/test_scheduler.py index 102b3019..1c1af7ac 100644 --- a/tests/unit_test/backends/vllm/test_scheduler.py +++ b/tests/unit_test/backends/vllm/test_scheduler.py @@ -156,7 +156,7 @@ def test_schedule_running(): policy = PolicyFactory.get_policy(policy_name="fcfs") budget = create_token_budget() curr_loras = None - + _, seq_group_0 = create_dummy_prompt("0", prompt_length=1, expected_steps=math.inf) scheduler._allocate_and_set_running(seq_group_0) append_new_token_seq_group(1, seq_group_0, 1) diff --git a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py index ae787b6a..84676474 100644 --- a/tests/unit_test/global_scheduler/test_dispatch_scheduler.py +++ b/tests/unit_test/global_scheduler/test_dispatch_scheduler.py @@ -33,15 +33,27 @@ def test_add_instance_and_remove_instance(dispatch_scheduler): dispatch_scheduler.add_instance('instance_1') assert dispatch_scheduler.num_instances == 1 assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 + dispatch_scheduler.remove_instance('instance_1') + assert dispatch_scheduler.num_instances == 0 + assert 
len(dispatch_scheduler.available_dispatch_instance_set) == 0 + dispatch_scheduler.add_instance('instance_2') + assert dispatch_scheduler.num_instances == 1 + assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 + dispatch_scheduler.add_instance('instance_3') assert dispatch_scheduler.num_instances == 2 if dispatch_scheduler.num_dispatch_instances <= 0: assert len(dispatch_scheduler.available_dispatch_instance_set) == 2 else: assert len(dispatch_scheduler.available_dispatch_instance_set) == min(2, dispatch_scheduler.num_dispatch_instances) - dispatch_scheduler.remove_instance('instance_1') - assert dispatch_scheduler.num_instances == 1 + dispatch_scheduler.remove_instance('instance_2') + assert dispatch_scheduler.num_instances == 1 + if dispatch_scheduler.num_dispatch_instances <= 0: + assert len(dispatch_scheduler.available_dispatch_instance_set) == 1 + else: + assert len(dispatch_scheduler.available_dispatch_instance_set) == min(1, dispatch_scheduler.num_dispatch_instances-1) + dispatch_scheduler.remove_instance('instance_3') assert dispatch_scheduler.num_instances == 0 def test_dispatch_balanced(): @@ -76,7 +88,7 @@ def test_dispatch_load(): instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict - available_instance_dict = {key: instance_info_dict[key] for key in instance_info_dict + available_instance_dict = {key: value for key, value in instance_info_dict.items() if key in dispatch_scheduler.available_dispatch_instance_set} min_instance_id = next(key for key, value in sorted(available_instance_dict.items(), key=lambda item: item[1].instance_load_dispatch_scale)) @@ -100,7 +112,7 @@ def test_dispatch_queue(): instance_num_requests[instance_id] = 0 dispatch_scheduler.instance_num_requests = instance_num_requests dispatch_scheduler.instance_info = instance_info_dict - available_instance_dict = {key: instance_info_dict[key] for key in 
instance_info_dict + available_instance_dict = {key: value for key, value in instance_info_dict.items() if key in dispatch_scheduler.available_dispatch_instance_set} min_instance_id = next(key for key, value in sorted(available_instance_dict.items(), key=lambda item: item[1].num_waiting_requests)) diff --git a/tests/unit_test/global_scheduler/test_global_scheduler.py b/tests/unit_test/global_scheduler/test_global_scheduler.py index dca37e4f..9f2e23c8 100644 --- a/tests/unit_test/global_scheduler/test_global_scheduler.py +++ b/tests/unit_test/global_scheduler/test_global_scheduler.py @@ -23,7 +23,8 @@ def init_global_scheduler(): - global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', -1, 'defrag_constrained', 3.0, True, 'avg_load', 10, 60, False) + global_scheduler_config = GlobalSchedulerConfig(0, 'remaining_steps', 'load', math.inf, + 'defrag_constrained', 3.0, True, 'avg_load', 10, 60, False) global_scheduler = GlobalScheduler(global_scheduler_config) return global_scheduler diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py index a2d0d7b7..b744ced6 100644 --- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py +++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py @@ -12,7 +12,7 @@ # limitations under the License. 
import time - +import math import ray import pytest import numpy as np @@ -171,7 +171,7 @@ def test_generate_and_abort(setup_ray_env, engine_manager, llumlet): num_requests = ray.get(llumlet.get_num_requests.remote()) assert num_requests == 0 server_info = ServerInfo(None, None, None, None, None) - ray.get(engine_manager.generate.remote(request_id, server_info, -1, None, None)) + ray.get(engine_manager.generate.remote(request_id, server_info, math.inf, None, None)) num_requests = ray.get(llumlet.get_num_requests.remote()) assert num_requests == 1 ray.get(engine_manager.abort.remote(request_id)) @@ -189,8 +189,8 @@ def test_get_request_instance(setup_ray_env): llumlet, llumlet_1 = llumlets[0], llumlets[1] request_id = random_uuid() request_id_1 = random_uuid() - ray.get(llumlet.generate.remote(request_id, None, -1, None, None)) - ray.get(llumlet_1.generate.remote(request_id_1, None, -1, None, None)) + ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None)) + ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None)) num_requests = ray.get(llumlet.get_num_requests.remote()) num_requests_1 = ray.get(llumlet_1.get_num_requests.remote()) assert num_requests == 1 @@ -227,8 +227,8 @@ def test_update_instance_info_loop_and_migrate(setup_ray_env, engine_manager): llumlet, llumlet_1 = llumlets[0], llumlets[1] request_id = random_uuid() request_id_1 = random_uuid() - ray.get(llumlet.generate.remote(request_id, None, -1, None, None)) - ray.get(llumlet_1.generate.remote(request_id_1, None, -1, None, None)) + ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None)) + ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None)) instance_info_migrate_out = get_instance_info_migrate_out(instance_id) instance_info_migrate_in = get_instance_info_migrate_in(instance_id_1) ray.get(llumlet.set_instance_info.remote(instance_info_migrate_out)) diff --git a/tests/unit_test/global_scheduler/test_migration_scheduler.py 
b/tests/unit_test/global_scheduler/test_migration_scheduler.py index 9dcd699e..ed3b706d 100644 --- a/tests/unit_test/global_scheduler/test_migration_scheduler.py +++ b/tests/unit_test/global_scheduler/test_migration_scheduler.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import random import pytest import numpy as np @@ -53,12 +54,13 @@ def test_get_migration_instance_infos(pair_migration_type): instance_info.instance_load_migrate = MIGRATE_OUT_LOAD_THRESHOLD + random.uniform(-1, 1) instance_info.num_killed_requests = random.randint(0, 1) if pair_migration_type == PairMigrationConstraints.NO_CONSTRAINTS: - constraint_prefill_instance_num = -1 + constraint_prefill_instance_num = math.inf else: constraint_prefill_instance_num = random.randint(-1, INSTANCE_NUM) migration_scheduler = init_migration_scheduler() if constraint_prefill_instance_num > 0: - if len([info for info in instance_info_dict.values() if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num: + if len([info for info in instance_info_dict.values() + if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num: instance_info.instance_type = InstanceType.PREFILL else: instance_info.instance_type = InstanceType.DECODE @@ -105,11 +107,11 @@ def test_pair_migration(policy): instance_info_dict[instance_id] = instance_info migration_scheduler.instance_info = instance_info_dict migration_scheduler._sort_instance_infos(descending=False) - sorted_src_instance_infos = [i for i in reversed(migration_scheduler.sorted_instance_infos) - if i.instance_type == InstanceType.NO_CONSTRAINTS + sorted_src_instance_infos = [i for i in reversed(migration_scheduler.sorted_instance_infos) + if i.instance_type == InstanceType.NO_CONSTRAINTS and (i.num_killed_requests > 0 or i.instance_load_migrate > migration_scheduler.migrate_out_load_threshold)] sorted_dst_instance_infos = [i for i in 
migration_scheduler.sorted_instance_infos - if i.instance_type == InstanceType.NO_CONSTRAINTS + if i.instance_type == InstanceType.NO_CONSTRAINTS and (i.num_killed_requests == 0 and i.instance_load_migrate < migration_scheduler.migrate_out_load_threshold)] migrate_instance_pairs = migration_scheduler.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos) for migrate_out_instance, migrate_in_instance in migrate_instance_pairs: diff --git a/tests/unit_test/llumlet/test_migration_coordinator.py b/tests/unit_test/llumlet/test_migration_coordinator.py index 8635c468..fd925cfe 100644 --- a/tests/unit_test/llumlet/test_migration_coordinator.py +++ b/tests/unit_test/llumlet/test_migration_coordinator.py @@ -13,6 +13,7 @@ from unittest.mock import MagicMock, patch +import math import ray from llumnix.llumlet.migration_coordinator import MigrationCoordinator @@ -96,7 +97,7 @@ def test_migrate_out_multistage(_, setup_ray_env): # Create mock objects backend_engine = MagicMock(spec=BackendInterface) migrate_in_ray_actor = MagicMock() - migrate_out_request = MockRequest("1", 1, -1) + migrate_out_request = MockRequest("1", 1, math.inf) # Create an instance of MigrationCoordinator max_stages = 3