[Manager] Optimize watch instance deployment states implementation & Add enable_port_offset_store arg (#92)
s5u13b authored Jan 15, 2025
1 parent 3e319f0 commit 1f49b36
Showing 6 changed files with 75 additions and 30 deletions.
4 changes: 4 additions & 0 deletions docs/Arguments.md
@@ -53,6 +53,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--enable-pd-disagg]
[--num-dispatch-instances NUM_DISPATCH_INSTANCES]
[--enable-port-increment]
[--enable-port-offset-store]
```

`--host`
@@ -237,6 +238,9 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
`--enable-port-increment`
- Enable port increment when deploying multiple servers.

`--enable-port-offset-store`
- Enable storing the port offset when deploying multiple servers.

# Unsupported vLLM feature options

`--device`
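For illustration, a hypothetical launch command pairing the two new flags documented above (the host value and the rest of the deployment are assumptions; as the `check_args` change below enforces, `--enable-port-offset-store` only takes effect together with `--enable-port-increment`):

```
python -m llumnix.entrypoints.vllm.api_server \
    --host 0.0.0.0 \
    --enable-port-increment \
    --enable-port-offset-store
```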
8 changes: 8 additions & 0 deletions llumnix/arg_utils.py
@@ -153,6 +153,8 @@ class ManagerArgs:
num_dispatch_instances: int = None

enable_port_increment: bool = None
enable_port_offset_store: bool = None


def __post_init__(self):
# Check if all fields default to None
@@ -222,6 +224,9 @@ def check_args(cls, args: 'ManagerArgs', parser: argparse.ArgumentParser):
assert not args.simulator_mode or args.profiling_result_file_path is not None, \
"Set profiling_result_file_path args when enable simulator mode"

assert not args.enable_port_offset_store or args.enable_port_increment, \
"Set enable_port_increment when enable_port_offset_store"

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--initial-instances',
@@ -357,6 +362,9 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--enable-port-increment',
action='store_true',
help='enable port increment when deploying multiple servers')
parser.add_argument('--enable-port-offset-store',
action='store_true',
help='enable storing the port offset when deploying multiple servers')

return parser

2 changes: 2 additions & 0 deletions llumnix/config/default.py
@@ -75,6 +75,8 @@
_C.MANAGER.PROFILING_RESULT_FILE_PATH = None
# Enable port increment when deploying multiple servers
_C.MANAGER.ENABLE_PORT_INCREMENT = False
# Enable store port offset when deploying multiple servers
_C.MANAGER.ENABLE_PORT_OFFSET_STORE = False

# -----------------------------------------------------------------------------
# DISPATCH CONFIGURATION
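The same two switches also have config-file counterparts; a minimal override sketch, assuming the yacs-style node `_C` from `llumnix/config/default.py` can be imported and mutated directly (whether llumnix freezes or clones this node before use is not shown in this diff):

```python
# Hypothetical override of the defaults added above; the attribute names come from this
# diff, but importing and mutating _C directly is an assumption about how the config is used.
from llumnix.config.default import _C as cfg

cfg.MANAGER.ENABLE_PORT_INCREMENT = True
cfg.MANAGER.ENABLE_PORT_OFFSET_STORE = True
```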
1 change: 0 additions & 1 deletion llumnix/llumlet/llumlet.py
@@ -194,7 +194,6 @@ def get_all_request_ids(self) -> List[str]:
return self.backend_engine.get_all_request_ids()

def generate(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs) -> None:
# This should not be used for logging, as it is monotonic time.
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.llumlet_generate_timestamp = time.time()
self.backend_engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)
64 changes: 45 additions & 19 deletions llumnix/manager.py
@@ -52,7 +52,8 @@
AUTO_SCALE_UP_INTERVAL = 1.0
WAIT_PLACEMENT_GROUP_TIMEOUT = 5.0
CHECK_DEPLOYMENT_STATES_INTERVAL = 30.0
WATCH_DEPLOYMENT_INTERVAL = 40.0
WATCH_DEPLOYMENT_INTERVAL = 10.0
WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE = 120.0

# TODO(s5u13b): Handle exception of ray operations.
# TODO(s5u13b): Add exeception handling wrapper.
@@ -132,8 +133,12 @@ def __init__(self,
asyncio.create_task(self._update_instance_info_loop(self.polling_interval))
asyncio.create_task(self._clear_request_instance_loop(CLEAR_REQUEST_INSTANCE_INTERVAL))

value = get_actor_data_from_ray_internal_kv("manager", "port_offset")
self.port_offset = 0 if value is None else int(value)
if self.manager_args.enable_port_increment:
self.port_offset = 0
if self.manager_args.enable_port_offset_store:
value = get_actor_data_from_ray_internal_kv("manager", "port_offset")
if value is not None:
self.port_offset = int(value)
if hasattr(self, "launch_mode") and self.launch_mode == LaunchMode.GLOBAL:
assert self.entrypoints_args is not None and self.engine_args is not None
self.last_timeout_instance_id = None
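The two KV helpers used above are not part of this diff. A plausible sketch of what they might wrap, assuming they sit on top of Ray's (private) internal KV API and a simple `"<actor>.<key>"` key scheme — both of which are assumptions, not llumnix's actual implementation:

```python
# Assumed sketch only; llumnix's real helpers are defined elsewhere in the repository.
from typing import Optional, Union

from ray.experimental.internal_kv import _internal_kv_get, _internal_kv_put


def get_actor_data_from_ray_internal_kv(actor_name: str, key: str) -> Optional[str]:
    # Return the value stored under "<actor_name>.<key>", or None if nothing was stored.
    value = _internal_kv_get(f"{actor_name}.{key}")
    return value.decode() if value is not None else None


def put_actor_data_to_ray_internal_kv(actor_name: str, key: str, value: Union[int, str]) -> None:
    # Overwrite the stored value so it survives a manager restart while the Ray cluster stays up.
    _internal_kv_put(f"{actor_name}.{key}", str(value), overwrite=True)
```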
@@ -312,7 +317,7 @@ async def _auto_scale_up_loop(self, interval: float) -> None:
try:
await asyncio.wait_for(new_pg.ready(), WAIT_PLACEMENT_GROUP_TIMEOUT)
except asyncio.TimeoutError:
logger.info("[_auto_scale_up_loop] waiting for new placement group ready timeout")
logger.debug("[_auto_scale_up_loop] waiting for new placement group ready timeout")
# After timeout, the new placement group might be pending,
# created(without server and instance), rescheduling.
self.last_timeout_instance_id = new_instance_id
@@ -430,7 +435,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_b
no_pending_instance = self.pending_rebuild_migration_instances == 0

for ins_id in instance_ids:
self._clear_instance_ray_resources(ins_id)
self._clear_instance_ray_states(ins_id)
if ins_id in self.instances:
indeed_update = True
if ins_id in self.instances:
@@ -460,7 +465,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migration_b

return self.num_instances

def _clear_instance_ray_resources(self, instance_id: str):
def _clear_instance_ray_states(self, instance_id: str):
if not remove_placement_group(instance_id):
logger.debug("[clear_instance_ray_resources] failed to remove placement group {}".format(instance_id))
if not kill_server(instance_id):
@@ -544,7 +549,8 @@ def _init_server(self,
entrypoints_args.port += self.port_offset
entrypoints_args.request_output_queue_port += self.port_offset
self.port_offset += 1
put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset)
if self.manager_args.enable_port_offset_store:
put_actor_data_to_ray_internal_kv("manager", "port_offset", self.port_offset)
fastapi_server = FastAPIServerActor.from_args(server_name, placement_group, entrypoints_args)
return fastapi_server

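Together with the `__init__` change above, `_init_server` implements a simple counter: each new server gets its API port and request-output-queue port shifted by the current offset, the offset is then incremented and, when the store is enabled, written back so a restarted manager resumes where it left off. A toy, self-contained restatement of that arithmetic (not project code; the port numbers are made up):

```python
from typing import Optional


class PortAllocator:
    """Toy stand-in for the manager's port-offset bookkeeping (illustration only)."""

    def __init__(self, base_port: int, restored_offset: Optional[int] = None):
        self.base_port = base_port
        # With enable_port_offset_store, a restarted manager restores the last offset
        # instead of starting again from 0.
        self.offset = restored_offset if restored_offset is not None else 0

    def next_port(self) -> int:
        port = self.base_port + self.offset
        self.offset += 1  # the real manager also persists this value when the store is enabled
        return port


alloc = PortAllocator(base_port=8000)
assert [alloc.next_port() for _ in range(3)] == [8000, 8001, 8002]

# After a simulated manager restart with the offset store enabled, the counter resumes at 3,
# so the next server gets 8003 instead of colliding with the still-running server on 8000.
alloc = PortAllocator(base_port=8000, restored_offset=3)
assert alloc.next_port() == 8003
```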
@@ -605,23 +611,33 @@ async def done_scale_up():
asyncio.create_task(done_scale_up())

async def _check_deployment_states_loop(self, interval: float) -> None:
async def watch_deployment(instance_id: str):
async def watch_instance_deployment_states(instance_id: str):
# There might be some delays of calling _init_server_and_instance, so sleep first.
await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL)
curr_pgs, curr_servers, curr_instances = self.get_curr_deployment()
if instance_id in curr_pgs and (instance_id not in curr_servers or instance_id not in curr_instances):
logger.warning("[_check_deployment_states_loop] instance {} deployment states incorrect, "
"states: (pg {}, server {}, instance {})"
.format(instance_id, instance_id in curr_pgs, instance_id in curr_servers, instance_id in curr_instances))
wait_pending_instance_time = 0.0
while True:
instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))])
instance_pending_creation = len(instance_state) == 1 and instance_state[0]["state"] == "PENDING_CREATION"
if not instance_pending_creation:
break
await asyncio.sleep(WATCH_DEPLOYMENT_INTERVAL)
wait_pending_instance_time += WATCH_DEPLOYMENT_INTERVAL
if wait_pending_instance_time >= WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE:
break
pg_created, server_alive, instance_alive = self._get_instance_deployment_states(instance_id)
if pg_created and (not server_alive or not instance_alive):
logger.warning("instance {} deployment states incorrect, states: (pg {}, server {}, instance {})"
.format(instance_id, pg_created, server_alive, instance_alive))
self.scale_down(instance_id)

while True:
try:
curr_pgs, curr_servers, curr_instances = self.get_curr_deployment()
curr_pgs, curr_servers, curr_instances = self._get_cluster_deployment()
assert len(curr_pgs) >= max(len(curr_servers), len(curr_instances))
tasks = []
for instance_id in curr_pgs:
if instance_id not in curr_servers or instance_id not in curr_instances:
tasks.append(asyncio.create_task(watch_deployment(instance_id)))
tasks.append(asyncio.create_task(watch_instance_deployment_states(instance_id)))
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.sleep(interval)
# pylint: disable=broad-except
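The rewritten watcher above no longer sleeps one long interval; it polls the instance actor every `WATCH_DEPLOYMENT_INTERVAL` (10 s) while it is still `PENDING_CREATION`, gives up after `WATCH_DEPLOYMENT_INTERVAL_PENDING_INSTANCE` (120 s), and only then inspects the per-instance deployment states. A simplified, standalone sketch of that wait step, assuming Ray's state API (`ray.util.state.list_actors`) returns records indexable by `"state"` as the diff itself does — a restatement for clarity, not the project's helper:

```python
import asyncio

from ray.util.state import list_actors


async def wait_while_pending(actor_name: str, interval: float = 10.0, max_wait: float = 120.0) -> None:
    """Poll until the named actor leaves PENDING_CREATION, or until max_wait seconds elapse."""
    waited = 0.0
    while waited < max_wait:
        states = list_actors(filters=[("name", "=", actor_name)])
        still_pending = len(states) == 1 and states[0]["state"] == "PENDING_CREATION"
        if not still_pending:
            return
        await asyncio.sleep(interval)
        waited += interval
```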
@@ -655,7 +671,7 @@ def check_instance_error_done_callback(idx: int, instance_id: str, fut):

return results

def get_curr_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, FastAPIServerActor], Dict[str, Llumlet]]:
def _get_cluster_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, FastAPIServerActor], Dict[str, Llumlet]]:
curr_pgs: Dict[str, PlacementGroup] = {}
curr_servers: Dict[str, PlacementGroup] = {}
curr_instances: Dict[str, Llumlet] = {}
@@ -676,6 +692,16 @@ def get_curr_deployment(self) -> Tuple[Dict[str, PlacementGroup], Dict[str, Fast

return curr_pgs, curr_servers, curr_instances

def _get_instance_deployment_states(self, instance_id: str):
pg_state = list_placement_groups(filters=[("name", "=", get_placement_group_name(instance_id))])
pg_created = len(pg_state) == 1 and pg_state[0]["state"] == "CREATED"
server_state = list_actors(filters=[("name", "=", get_server_name(instance_id))])
server_alive = len(server_state) == 1 and server_state[0]["state"] == "ALIVE"
instance_state = list_actors(filters=[("name", "=", get_instance_name(instance_id))])
instance_alive = len(instance_state) == 1 and instance_state[0]["state"] == "ALIVE"

return pg_created, server_alive, instance_alive

async def _get_request_instance(self) -> None:
def get_request_instance_done_callback(instance_id: str, fut):
ret = fut.result()[0]
Expand All @@ -690,12 +716,12 @@ def get_request_instance_done_callback(instance_id: str, fut):
instance_ids = []
tasks = []
for instance_id, instance_actor_handle in self.instances.items():
task = asyncio.gather(instance_actor_handle.get_instance_info.remote(), return_exceptions=True)
task = asyncio.gather(instance_actor_handle.get_all_request_ids.remote(), return_exceptions=True)
task.add_done_callback(partial(get_request_instance_done_callback, instance_id))
tasks.append(task)
await asyncio.gather(*tasks, return_exceptions=True)
logger.debug("[_get_request_instance] instance_ids: {}".format(instance_ids))
logger.debug("[_get_request_instance] instance_requests: {}".format(instance_requests))
logger.info("instance_ids: {}".format(instance_ids))
logger.info("instance_requests: {}".format(instance_requests))
for (instance_id, requests) in zip(instance_ids, instance_requests):
for request_id in requests:
self.request_instance[request_id] = instance_id
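The new `_get_instance_deployment_states` and the polling in the watcher rely on Ray's state API and llumnix's naming helpers; the imports they presumably need look roughly like this (the exact import paths in `manager.py` are assumptions, while the names themselves appear in this diff and in the tests below):

```python
# Assumed imports for the new state queries in manager.py (not shown in this diff).
from ray.util.state import list_actors, list_placement_groups

from llumnix.utils import (  # module path is an assumption
    get_instance_name,
    get_placement_group_name,
    get_server_name,
)
```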
26 changes: 16 additions & 10 deletions tests/unit_test/global_scheduler/test_manager.py
@@ -316,7 +316,7 @@ def test_update_instance_info_loop_and_migrate(ray_env, manager):
else:
assert num_migrate_in == 0 and num_migrate_out == 0

def test_init_server_and_instance_and_clear_instance_ray_resources(ray_env):
def test_init_server_and_get_instance_deployment_states_and_instance_and_clear_instance_ray_resources(ray_env):
manager, _, _, engine_args, _ = init_manager_with_launch_mode(LaunchMode.LOCAL)
instance_id = random_uuid()
pg = ray.get(manager._init_placement_group.remote(get_placement_group_name(instance_id),
@@ -333,8 +333,11 @@ def test_init_server_and_instance_and_clear_instance_ray_resources(ray_env):
num_instances = ray.get(manager.scale_up.remote(instance_id, instance))
assert num_instances == 1

pg_created, server_alive, instance_alive = ray.get(manager._get_instance_deployment_states.remote(instance_id))
assert pg_created and server_alive and instance_alive

# test clear_instance_ray_resources
ray.get(manager._clear_instance_ray_resources.remote(instance_id))
ray.get(manager._clear_instance_ray_states.remote(instance_id))
# wait for remove and kill
time.sleep(1.0)
pg_exists = is_placement_group_exists(get_placement_group_name(instance_id))
@@ -344,25 +347,28 @@ def test_init_server_and_instance_and_clear_instance_ray_resources(ray_env):
instance_exists = is_actor_exists(get_instance_name(instance_id))
assert not instance_exists

pg_created, server_alive, instance_alive = ray.get(manager._get_instance_deployment_states.remote(instance_id))
assert not pg_created and not server_alive and not instance_alive

@pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
def test_auto_scale_up_loop_and_get_curr_deployment(ray_env, request_output_queue_type):
def test_auto_scale_up_loop_and_get_cluster_deployment(ray_env, request_output_queue_type):
manager, _, _, _, _ = init_manager_with_launch_mode(LaunchMode.GLOBAL, request_output_queue_type)
time.sleep(30.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
instance_ids = [actor_name_dict['name'].split("_")[-1] for actor_name_dict in actor_names_dict
if actor_name_dict['name'].startswith(INSTANCE_NAME_PREFIX)]
assert len(instance_ids) == 4
ray.get(manager._clear_instance_ray_resources.remote(instance_ids[0]))
ray.get(manager._clear_instance_ray_resources.remote(instance_ids[1]))
ray.get(manager._clear_instance_ray_states.remote(instance_ids[0]))
ray.get(manager._clear_instance_ray_states.remote(instance_ids[1]))
time.sleep(30.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

@pytest.mark.parametrize("request_output_queue_type", ['rayqueue', 'zmq'])
@@ -371,7 +377,7 @@ def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_ou
time.sleep(30.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4

actor_names_dict = ray.util.list_named_actors(all_namespaces=True)
@@ -382,8 +388,8 @@ def test_check_deployment_states_loop_and_auto_scale_up_loop(ray_env, request_ou
kill_server(instance_ids[1])
kill_instance(instance_ids[2])
# Wait for check deployment states, scale down instance and auto scale up.
time.sleep(120.0)
time.sleep(90.0)
num_instances = ray.get(manager.scale_up.remote([], []))
assert num_instances == 4
curr_pgs, curr_servers, curr_instances = ray.get(manager.get_curr_deployment.remote())
curr_pgs, curr_servers, curr_instances = ray.get(manager._get_cluster_deployment.remote())
assert len(curr_pgs) == 4 and len(curr_servers) == 4 and len(curr_instances) == 4
