Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyi-ECNU committed Oct 14, 2024
1 parent 2f781cf commit dea6812
Show file tree
Hide file tree
Showing 17 changed files with 159 additions and 45 deletions.
9 changes: 6 additions & 3 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,14 @@ class EngineManagerArgs:
last_stage_max_blocks: int = None
max_stages: int = None

enable_pd_disagg: bool = None

def __post_init__(self):
    """Validate field defaults and fill unset fields from the global config.

    Every dataclass field must declare ``None`` as its default so that
    "user did not set this" is distinguishable from an explicit value;
    the real defaults live in the global manager config (``_C.MANAGER``).

    Raises:
        ValueError: if any field declares a non-None default.
    """
    # First pass: enforce the all-defaults-are-None invariant.
    for field_info in dataclasses.fields(self):
        if field_info.default is not None:
            raise ValueError(f"The default value of '{field_info.name}' should be None")

    # Second pass: fill any still-unset field from the upper-cased config key
    # (e.g. field `max_stages` is read from _C.MANAGER.MAX_STAGES).
    for attr in dataclasses.fields(self):
        if getattr(self, attr.name) is None:
            setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper()))
Expand Down Expand Up @@ -143,7 +148,6 @@ def add_cli_args(
help='request dispatch policy')
parser.add_argument('--num-available-dispatch-instances',
type=int,
default=None,
help='number of available instances for dispatching')

parser.add_argument('--enable-migration',
Expand Down Expand Up @@ -224,6 +228,5 @@ def add_cli_args(
help='drop migration if the number of stages > max_stages')
parser.add_argument('--enable-pd-disagg',
type=bool,
default=None,
help='enable prefill decoding disaggregation')
return parser
12 changes: 6 additions & 6 deletions llumnix/backends/vllm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,16 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
return seq_group_metadata_list, scheduler_outputs

def _schedule_running(self, running_queue: deque, *args, **kwargs):
    """Schedule the running queue, holding back finished-for-now requests.

    Requests whose ``output_len`` has reached ``expected_steps`` (e.g.
    prefill-only requests awaiting migration under PD disaggregation) are
    excluded from scheduling and re-attached to the remaining-running queue
    afterwards so they stay tracked as running.
    """
    filtered_running_queue = deque()
    remove_running = deque()
    for seq_group in running_queue:
        # A request that already produced its expected number of steps must
        # not be scheduled again; it waits to be migrated out.
        if seq_group.output_len >= seq_group.expected_steps:
            remove_running.append(seq_group)
        else:
            filtered_running_queue.append(seq_group)
    remaining_running, running_scheduled = super()._schedule_running(
        filtered_running_queue, *args, **kwargs)
    # Re-attach the held-back requests to the remaining-running queue.
    for seq_group in remove_running:
        remaining_running.append(seq_group)
    return remaining_running, running_scheduled

def add_seq_group(self, *args, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion llumnix/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from .config import LlumnixConfig as LC

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -78,7 +80,7 @@
# Request dispatch policy
_C.MANAGER.DISPATCH_POLICY = 'load'
# Number of available dispatch instances. -1 indicates that all instances can be used for dispatching
_C.MANAGER.NUM_DISPATCH_INSTANCES = -1
_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf

# -----------------------------------------------------------------------------
# MIGRATION CONFIGURATION
Expand Down
2 changes: 2 additions & 0 deletions llumnix/global_scheduler/dispatch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def remove_instance(self, instance_id: str) -> None:
self.num_instances = len(self.instance_id_set)
if instance_id in self.instance_num_requests:
del self.instance_num_requests[instance_id]
if instance_id in self.available_dispatch_instance_set:
self.available_dispatch_instance_set.remove(instance_id)

def _sort_instance_infos(self,
descending: bool = True) -> None:
Expand Down
4 changes: 2 additions & 2 deletions llumnix/global_scheduler/global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from llumnix.internal_config import GlobalSchedulerConfig
from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints
from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler

logger = init_logger(__name__)
Expand Down Expand Up @@ -69,7 +69,7 @@ def dispatch(self) -> str:
request_expected_steps = 1 if self.enable_pd_disagg else math.inf
return instance_id, request_expected_steps

def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
    """Return (src_instance_id, dst_instance_id) pairs selected for migration.

    Refreshes the migration scheduler's view of the latest instance infos
    before delegating the pairing decision to it.
    """
    self.migration_scheduler.update_instance_infos(self.instance_info)
    migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type)
    return migrate_instance_pairs
Expand Down
7 changes: 4 additions & 3 deletions llumnix/global_scheduler/migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def __init__(self,
self.instance_info: Dict[str, InstanceInfo] = None
self.sorted_instance_infos: List[InstanceInfo] = None

def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
    """Return (src_instance_id, dst_instance_id) migration pairs.

    Instances are sorted by load in ascending order before the configured
    pairing policy matches sources with destinations.
    """
    self._sort_instance_infos(descending=False)
    sorted_src_instance_infos, sorted_dst_instance_infos = self._get_migration_instance_infos(pair_migration_type)
    return self.pair_migration_policy.pair_migration(sorted_src_instance_infos, sorted_dst_instance_infos)

def _get_migration_instance_infos(self, pair_migration_type: PairMigrationConstraints) -> Dict[str, InstanceInfo]:
    """Split the pre-sorted instance infos into migration sources and
    destinations according to the given pairing constraint.

    The filtering policy is looked up per constraint type and configured
    with this scheduler's migrate-out load threshold.
    """
    filter_instance_infos_policy = FilteringInstanceInfosPolicyFactory.get_policy(
        pair_migration_type, migrate_out_load_threshold=self.migrate_out_load_threshold)
    return filter_instance_infos_policy.filter_instances(self.sorted_instance_infos, pair_migration_type)
Expand Down Expand Up @@ -98,7 +98,8 @@ def __init__(self,
PairMigrationConstraints.PREFILL_2_DECODING: (InstanceType.PREFILL, InstanceType.DECODE),
}

def filter_instances(self, sorted_instance_infos: List[InstanceInfo], pair_migration_type: str = None) -> Dict[str, InstanceInfo]:
def filter_instances(self, sorted_instance_infos: List[InstanceInfo],
pair_migration_type: PairMigrationConstraints = None) -> Dict[str, InstanceInfo]:
src_type, dst_type = self.filter_instances_rules[pair_migration_type]
filtered_src_instance_infos = [info for info in sorted_instance_infos if info.instance_type == src_type]
filtered_dst_instance_infos = [info for info in sorted_instance_infos if info.instance_type == dst_type]
Expand Down
11 changes: 8 additions & 3 deletions llumnix/global_scheduler/scaling_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Dict, List, Tuple, Set
from abc import ABC, abstractmethod
from enum import Enum
import math
import numpy as np

from llumnix.logger import init_logger
Expand Down Expand Up @@ -77,18 +78,22 @@ def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
instance_type = None
if self.maximum_prefill_instance_num > 0:
if self.maximum_prefill_instance_num == math.inf:
instance_type = InstanceType.NO_CONSTRAINTS
else:
if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num:
instance_type = InstanceType.PREFILL
else:
instance_type = InstanceType.DECODE
else:
instance_type = InstanceType.NO_CONSTRAINTS
self.instance_type_id_set[instance_type].add(instance_id)
return instance_type

def remove_instance(self, instance_id: str) -> None:
    """Drop an instance from tracking and from its type bucket."""
    self.instance_id_set.remove(instance_id)
    # An instance lives in at most one type bucket; find it (if any) and
    # clear the entry.
    owning_type = next(
        (t for t in InstanceType if instance_id in self.instance_type_id_set[t]),
        None,
    )
    if owning_type is not None:
        self.instance_type_id_set[owning_type].remove(instance_id)
    self.num_instances = len(self.instance_id_set)

def get_empty_instance_info(self) -> InstanceInfo:
Expand Down
5 changes: 3 additions & 2 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import time
import csv
import os
import math
from typing import Dict, List, Tuple, Union, Iterable
from collections import defaultdict
import traceback
Expand Down Expand Up @@ -217,12 +218,12 @@ async def _clear_request_instance_loop(self, interval: float):
async def _push_migrations(self) -> None:
    """Kick off migration rounds appropriate for the deployment mode.

    Fired when instance_info has been updated a certain number of times.
    With PD disaggregation, prefill instances push every migratable request
    to decode instances (no per-round cap), while decode instances balance
    load among themselves one request per round.
    """
    if self.enable_pd_disagg:
        # math.inf: migrate all ready requests from prefill to decode.
        asyncio.create_task(self._migrate(PairMigrationConstraints.PREFILL_2_DECODING, math.inf))
        asyncio.create_task(self._migrate(PairMigrationConstraints.DECODING_2_DECODING, 1))
    else:
        asyncio.create_task(self._migrate(PairMigrationConstraints.NO_CONSTRAINTS, 1))

async def _migrate(self, pair_migration_type:str, migrate_in_num_requests:int) -> None:
async def _migrate(self, pair_migration_type: PairMigrationConstraints, migrate_in_num_requests: int) -> None:
async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> None:
if migrate_instance_pair[0] in self.instance_migrating:
self.instance_migrating[migrate_instance_pair[0]] = False
Expand Down
3 changes: 2 additions & 1 deletion llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]:
dst_instance_id = dst_instance_name[len("instance_"):]
migrated_request_list = []
continue_migrate = True
while continue_migrate and (len(migrated_request_list) < num_requests or num_requests == -1):
while continue_migrate and len(migrated_request_list) < num_requests:
t0 = time.time()
migrate_out_request = self.migration_scheduler.get_migrate_out_request()
if migrate_out_request is not None:
Expand All @@ -157,6 +157,7 @@ def migrate_out(self, dst_instance_name: str, num_requests: int) -> List[str]:
migrate_out_request.stage_timestamps.append(time.time())
self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request)
else:
migrate_out_request.reset_migration_args()
ray.get(migrate_in_ray_actor.execute_migration_method.remote("free_dst_pre_alloc_cache", migrate_out_request.request_id))
continue_migrate = False
t1 = time.time()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_test/backends/vllm/test_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,4 @@ def test_llm_engine_add_requset():
assert len(llm_engine.scheduler.waiting) == 1
assert llm_engine.scheduler.waiting[-1].request_id == "0"
assert llm_engine.scheduler.waiting[-1].expected_steps == math.inf
assert isinstance(llm_engine.scheduler.waiting[-1], LlumnixRequest)
assert isinstance(llm_engine.scheduler.waiting[-1], LlumnixRequest)
93 changes: 88 additions & 5 deletions tests/unit_test/backends/vllm/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

from typing import List
import asyncio
import math
import pytest
import ray

from vllm import EngineArgs, SamplingParams
from llumnix.utils import random_uuid
from vllm.utils import random_uuid

from llumnix.backends.vllm.llm_engine import BackendVLLM
from llumnix.llumlet.llumlet import Llumlet
Expand Down Expand Up @@ -92,14 +93,14 @@ async def test_migration_correctness(setup_ray_env, migration_backend):
llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")])

# empty instance migrate out
res = ray.get(llumlet_0.migrate_out.remote("instance_1"))
res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf))
assert not res

# running without migration
async def test_correctness(prompt):
sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100)
request_id0 = random_uuid()
llumlet_0.generate.remote(request_id0, server_info, prompt, sampling_params)
llumlet_0.generate.remote(request_id0, server_info, math.inf, prompt, sampling_params)
request_output_queue = que
origin_output = None
finished = False
Expand All @@ -110,14 +111,14 @@ async def test_correctness(prompt):
finished = request_output.finished

request_id1 = random_uuid()
ray.get(llumlet_0.generate.remote(request_id1, server_info, prompt, sampling_params))
ray.get(llumlet_0.generate.remote(request_id1, server_info, math.inf, prompt, sampling_params))
# wait prefill done
while True:
running_queue: List[LlumnixRequest] = ray.get(llumlet_0.execute_engine_method.remote("get_running_queue"))
if len(running_queue) > 0 and running_queue[0].inference_type == RequestInferenceType.DECODE:
break
# migrate request
res = ray.get(llumlet_0.migrate_out.remote("instance_1"))
res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf))
assert len(res) == 1

request_output_queue = que
Expand All @@ -140,6 +141,88 @@ async def test_correctness(prompt):
await test_correctness(prompt)
que.cleanup()

@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo', 'nccl'])
@pytest.mark.asyncio
async def test_pd_diaggregation_correctness(setup_ray_env, migration_backend):
    """End-to-end prefill/decoding disaggregation correctness test.

    For each prompt, runs once without a step limit to obtain a reference
    output, then re-runs with expected_steps=1 (prefill hands off after one
    step), migrates the request to instance_1, and asserts the migrated run
    produces identical text and cumulative logprob.
    """
    # NOTE(review): "diaggregation" is a typo for "disaggregation"; the name
    # is kept unchanged so existing test selections by name keep working.
    engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
    id_rank_map = {"0": 0, "1": 1}
    migration_config = MigrationConfig("LCFS", migration_backend, 16, 1, 4, 5, 20)
    server_info = init_server_info()
    que = init_request_output_queue(server_info)
    asyncio.create_task(que.run_server_loop())

    llumlet_0: Llumlet = Llumlet.from_args(
        False,
        True,
        ray.get_runtime_context().get_node_id(),
        "0",
        BackendType.VLLM,
        1,
        migration_config,
        engine_args,
    )
    llumlet_1: Llumlet = Llumlet.from_args(
        False,
        True,
        ray.get_runtime_context().get_node_id(),
        "1",
        BackendType.VLLM,
        1,
        migration_config,
        engine_args,
    )
    # Wait until both instances are up before wiring the migration backend.
    while True:
        res = ray.get([llumlet_0.is_ready.remote(), llumlet_1.is_ready.remote()])
        if all(res):
            break
    ray.get([llumlet_0.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix"),
             llumlet_1.execute_engine_method.remote("_run_workers", "rebuild_migration_backend", id_rank_map, "llumnix")])

    # Migrating out of an empty instance must be a no-op.
    res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf))
    assert not res

    async def test_correctness(prompt):
        sampling_params = SamplingParams(top_k=1, temperature=0, ignore_eos=True, max_tokens=100)

        # Reference run: unlimited expected steps, no migration.
        request_id0 = random_uuid()
        request_expected_steps_id0 = math.inf
        llumlet_0.generate.remote(request_id0, server_info, request_expected_steps_id0, prompt, sampling_params)
        request_output_queue = que
        origin_output = None
        finished = False
        while not finished:
            request_output = await request_output_queue.get()
            origin_output = request_output.outputs[0]
            finished = request_output.finished

        # Disaggregated run: expected_steps=1 forces a prefill->decode handoff.
        request_id1 = random_uuid()
        request_expected_steps_id1 = 1
        ray.get(llumlet_0.generate.remote(request_id1, server_info, request_expected_steps_id1, prompt, sampling_params))
        # Retry until the request becomes migratable and moves to instance_1.
        while True:
            res = ray.get(llumlet_0.migrate_out.remote("instance_1", num_requests=math.inf))
            if len(res) == 1:
                break

        # Collect only outputs for request_id1; other requests' outputs may
        # still be draining through the shared queue.
        output = None
        finished = False
        while not finished:
            request_output = await request_output_queue.get()
            if request_output.request_id != request_id1:
                continue
            output = request_output.outputs[0]
            finished = request_output.finished

        assert output.text == origin_output.text
        assert output.cumulative_logprob == origin_output.cumulative_logprob

    for prompt in TEST_PROMPTS:
        await test_correctness(prompt)
    que.cleanup()

def test_clear_migration_states():
llumlet = MockLlumlet()
llumlet.backend_engine.pre_alloc("0", 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_test/backends/vllm/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_schedule_running():
policy = PolicyFactory.get_policy(policy_name="fcfs")
budget = create_token_budget()
curr_loras = None

_, seq_group_0 = create_dummy_prompt("0", prompt_length=1, expected_steps=math.inf)
scheduler._allocate_and_set_running(seq_group_0)
append_new_token_seq_group(1, seq_group_0, 1)
Expand Down
Loading

0 comments on commit dea6812

Please sign in to comment.